2025-05-15 15:26:06 +08:00
|
|
|
|
import logging
|
|
|
|
|
|
import binascii
|
|
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
from timeit import default_timer as timer
|
|
|
|
|
|
import datetime
|
|
|
|
|
|
from datetime import timedelta
|
2025-07-10 22:04:44 +08:00
|
|
|
|
import threading, time, queue, uuid, time, array
|
2025-05-15 15:26:06 +08:00
|
|
|
|
from threading import Lock, Thread
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
import base64, gzip
|
2025-07-10 22:04:44 +08:00
|
|
|
|
import os, io, re, json
|
2025-05-15 15:26:06 +08:00
|
|
|
|
from io import BytesIO
|
|
|
|
|
|
from typing import Optional, Dict, Any
|
2025-07-10 22:04:44 +08:00
|
|
|
|
import asyncio, httpx
|
|
|
|
|
|
from collections import deque
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
|
2025-05-15 15:26:06 +08:00
|
|
|
|
from fastapi import FastAPI, UploadFile, File, Form, Header
|
2025-07-10 22:04:44 +08:00
|
|
|
|
from fastapi.responses import StreamingResponse, JSONResponse, Response
|
|
|
|
|
|
|
|
|
|
|
|
# Audio output parameters for synthesized speech.
TTS_SAMPLERATE = 44100  # Hz (22050 and 16000 were used previously)
FORMAT = "mp3"  # container/codec of the produced audio
CHANNELS = 1  # mono
SAMPLE_WIDTH = 16 // 8  # bytes per sample (16-bit PCM)
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI router for this module's TTS endpoints.
# NOTE(review): routes are presumably attached to it outside the visible code.
tts_router = APIRouter()
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-05-26 21:38:46 +08:00
|
|
|
|
# logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MillisecondsFormatter(logging.Formatter):
    """Logging formatter whose ``%(asctime)s`` ends in a 3-digit millisecond suffix."""

    def formatTime(self, record, datefmt=None):
        """Format the creation time of *record* as ``<time>.<mmm>``.

        Args:
            record: the ``logging.LogRecord`` being formatted.
            datefmt: optional ``time.strftime`` format for the part before the
                milliseconds. Defaults to ``"%H:%M:%S"``; earlier versions
                ignored this argument entirely.

        Returns:
            The formatted timestamp string, e.g. ``"13:05:42.017"``.
        """
        # Convert the epoch timestamp with the formatter's converter
        # (time.localtime unless overridden).
        ct = self.converter(record.created)
        t = time.strftime(datefmt or "%H:%M:%S", ct)
        # record.msecs is a float in [0, 1000); truncate to 3 digits.
        return f"{t}.{int(record.msecs):03d}"
|
|
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
2025-05-26 21:38:46 +08:00
|
|
|
|
# Global logging setup.
def configure_logging():
    """Send root-logger output to the console with millisecond timestamps.

    Previously installed root handlers are dropped (not closed), the root level
    is set to INFO, and a single StreamHandler using MillisecondsFormatter with
    the layout "time - LEVEL - message" is installed.
    """
    handler = logging.StreamHandler()
    handler.setFormatter(
        MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")
    )

    root = logging.getLogger()
    root.handlers = []  # discard any existing root handlers
    root.setLevel(logging.INFO)
    root.addHandler(handler)


# Apply the logging configuration once, when this module is imported.
configure_logging()
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
class StreamSessionManager:
    """Manage streaming text-to-speech sessions.

    Each session keeps its TTS model, an audio output queue (``buffer``) and a
    text buffer. ``create_session`` starts one daemon worker thread per session
    (``_process_tasks``), which splits buffered text into sentences and hands
    them to the TTS model. Audio delivered through the model's ``setup_tts``
    callback is queued in ``buffer`` and read back by ``get_tts_buffer_data``.
    """

    def __init__(self):
        # {session_id: session state dict (keys are created in create_session)}
        self.sessions = {}
        # Guards self.sessions and each session's 'text_buffer'.
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size pool; not used by this class's methods
        self.gc_interval = 300  # close a session after 5 minutes without activity
        self.streaming_call_timeout = 15  # end an idle streaming TTS call after 15 s
        self.gc_tts = 3  # 3 s; not read by this class's methods
        self.sentence_timeout = 1.5  # flush unterminated text 1.5 s after the last input
        # Max seconds the TTS audio callback waits for room in a full audio queue
        # before dropping the chunk. A blocking put without a timeout could hang
        # the callback forever once nobody is reading the queue.
        self.buffer_put_timeout = 5.0
        self.sentence_endings = set('。?!;.?!;')  # Chinese and English sentence terminators
        # A run of full-/half-width clause or sentence punctuation, optionally
        # followed by a closing quote, that is followed by whitespace, the end of
        # the text, or a non-punctuation character.
        self.sentence_pattern = re.compile(
            r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])'
        )

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', session_id=None, streaming_call=False):
        """Register a new session for *tts_model* and start its worker thread.

        Args:
            tts_model: TTS engine. This class calls its ``setup_tts(callback)``,
                ``text_tts_call(text)`` or ``streaming_call(text)`` (depending on
                *streaming_call*), and ``end_streaming_call()``.
            sample_rate: stored in the session state.
            stream_format: stored in the session state.
            session_id: id to use; a random UUID4 string when falsy.
            streaming_call: send text through ``_stream_audio`` instead of
                ``_generate_audio``.

        Returns:
            The session id.
        """
        if not session_id:
            session_id = str(uuid.uuid4())

        def on_data(data: bytes):
            # Audio callback handed to the TTS model.
            if not data:
                return
            session = self.sessions.get(session_id)
            if session is None:
                # Session already cleaned up: drop late audio instead of raising
                # KeyError inside the TTS model's callback.
                return
            session['last_active'] = time.time()
            try:
                session['buffer'].put(data, timeout=self.buffer_put_timeout)
            except queue.Full:
                logging.warning(f"Audio buffer full for session {session_id}")

        with self.lock:
            tts_model.setup_tts(on_data)
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                "text_buffer": "",  # text waiting to be split into sentences
                "last_text_time": time.time(),  # when text was last appended
                "streaming_call": streaming_call,
                "tts_stream_started": False,  # whether the TTS stream has been started
            }

        # One worker thread per session.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Append *text* to the session's text buffer (no-op for unknown sessions)."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            session['text_buffer'] += text
            session['last_text_time'] = time.time()
            # NOTE(review): nothing in this class reads 'task_queue'. It is an
            # unbounded Queue, so queue.Full is never raised, and queued text is
            # only released when the session is cleaned up.
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _take_text_buffer(self, session):
        """Atomically return the session's pending text and clear the buffer."""
        with self.lock:
            text = session['text_buffer']
            session['text_buffer'] = ""
        return text

    @staticmethod
    def _merge_sentences(sentences, max_chars=300):
        """Yield *sentences* concatenated into chunks of at most *max_chars* characters.

        A single sentence longer than *max_chars* is yielded on its own.
        """
        pending = []
        pending_len = 0
        for sentence in sentences:
            if pending and pending_len + len(sentence) > max_chars:
                yield "".join(pending)
                pending = []
                pending_len = 0
            pending.append(sentence)
            pending_len += len(sentence)
        if pending:
            yield "".join(pending)

    def _process_tasks(self, session_id):
        """Per-session worker loop; runs in its own daemon thread.

        Every 50 ms it does the following:
          * sends complete sentences from the text buffer to TTS, merged into
            calls of at most 300 characters;
          * flushes unterminated text after ``sentence_timeout`` seconds without
            new input;
          * ends an idle streaming call after ``streaming_call_timeout`` seconds;
          * closes the session after ``gc_interval`` seconds of inactivity.

        TTS calls are made without holding ``self.lock``, so ``append_text`` is
        never blocked behind a network round-trip. ``close_session``, which takes
        the lock itself, is also called without it.
        """
        session = self.sessions.get(session_id)
        if not session or not session['active']:
            return
        if session.get('streaming_call'):
            synthesize = self._stream_audio
        else:
            synthesize = self._generate_audio

        while session['active']:
            current_time = time.time()

            text_to_process = self._take_text_buffer(session)
            if text_to_process:
                complete_sentences, remaining_text = self._split_and_extract(text_to_process)
                if remaining_text:
                    # Put the unfinished tail back in front of any text that arrived meanwhile.
                    with self.lock:
                        session['text_buffer'] = remaining_text + session['text_buffer']
                for chunk in self._merge_sentences(complete_sentences):
                    synthesize(session_id, chunk)

            # Flush text that never received a terminator once input has been quiet long enough.
            if current_time - session['last_text_time'] > self.sentence_timeout:
                pending = self._take_text_buffer(session)
                if pending:
                    synthesize(session_id, pending)

            if current_time - session['last_active'] > self.streaming_call_timeout:
                if session.get('streaming_call'):
                    try:
                        session['tts_model'].end_streaming_call()
                    except Exception as e:
                        logging.error(f"Error ending streaming call: {str(e)}")
                    session['streaming_call'] = False

            # Garbage-collect idle sessions (flush leftovers first).
            if current_time - session['last_active'] > self.gc_interval:
                pending = self._take_text_buffer(session)
                if pending:
                    synthesize(session_id, pending)
                self.close_session(session_id)
                break

            time.sleep(0.05)  # 50 ms poll interval to avoid busy-waiting

    def _generate_audio(self, session_id, text):
        """Synthesize *text* with the session's TTS model (non-streaming mode).

        Audio presumably arrives through the ``on_data`` callback installed in
        ``create_session``. On failure, an ``"ERROR:<message>"`` string is queued
        in the audio buffer.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        logging.info(f"_generate_audio:{text}")
        try:
            session['tts_model'].text_tts_call(text)
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1
            # The TTS backend has answered, so the client may be notified.
            session['tts_chunk_data_valid'] = True
        except Exception as e:
            logging.error(f"Error in text_tts_call: {str(e)}")
            # NOTE(review): this queues a str while _stream_audio queues bytes;
            # confirm what the buffer consumer expects.
            session['buffer'].put(f"ERROR:{str(e)}")

    def _stream_audio(self, session_id, text):
        """Send *text* to the TTS service through its streaming call.

        On failure, ``b"ERROR:<message>"`` is queued in the audio buffer.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            session['tts_model'].streaming_call(text)
            session['last_active'] = time.time()
        except Exception as e:
            logging.error(f"Error in streaming_call: {str(e)}")
            session['buffer'].put(f"ERROR:{str(e)}".encode())

    async def get_tts_buffer_data(self, session_id, idle_timeout=10.0, poll_interval=0.5):
        """Asynchronously yield items from the session's audio queue.

        Iteration stops when the session becomes inactive, when no data arrives
        for *idle_timeout* seconds, or on an unexpected error. The blocking
        ``queue.Queue.get`` runs in the default executor with a short
        *poll_interval* timeout. This avoids leaving an executor thread blocked
        forever that would later swallow an item.

        Raises:
            ValueError: if *session_id* is unknown.
            asyncio.CancelledError: re-raised after logging when the consumer is
                cancelled.
        """
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        buffer = session['buffer']  # queue.Queue
        loop = asyncio.get_running_loop()
        last_data_time = time.time()

        while session['active']:
            try:
                data = await loop.run_in_executor(None, buffer.get, True, poll_interval)
            except queue.Empty:
                if time.time() - last_data_time >= idle_timeout:
                    break
                continue
            except asyncio.CancelledError:
                logging.info(f"Session {session_id} stream cancelled")
                raise
            except Exception as e:
                logging.error(f"Error in get_tts_buffer_data: {e}")
                break
            last_data_time = time.time()
            yield data

    def close_session(self, session_id):
        """Mark the session inactive and schedule its cleanup 1 second later."""
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return
            logging.info(f"Ended streaming for session {session_id}")
            session['active'] = False
        # Delay cleanup by 1 s so readers and the worker observe 'active' == False first.
        threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Remove the session and end its TTS stream (best effort, outside the lock)."""
        with self.lock:
            session = self.sessions.pop(session_id, None)
        if session is None:
            return
        try:
            session['tts_model'].end_streaming_call()
        except Exception as e:
            logging.warning(f"Error ending streaming call: {str(e)}")

    def get_session(self, session_id):
        """Return the session state dict, or None if unknown."""
        return self.sessions.get(session_id)

    def _has_sentence_ending(self, text):
        """Return True if *text* appears to end a sentence.

        True when one of the last three characters is a sentence terminator, or
        when the line before the last newline ends with one.
        """
        if not text:
            return False
        if any(char in self.sentence_endings for char in text[-3:]):
            return True
        if '\n' in text:
            previous_line = text.split('\n')[-2].rstrip()
            if previous_line and previous_line[-1] in self.sentence_endings:
                return True
        return False

    def _split_and_extract(self, text):
        """Split *text* into complete sentences and an unfinished remainder.

        Pieces that end only in clause punctuation (e.g. a comma) and are at
        most 6 characters long are not emitted alone; they are merged into the
        following piece.

        Returns:
            ``(complete_sentences, remaining_text)``. The remainder keeps its
            trailing whitespace so that streamed words are not glued together
            when it is prepended to the next input.
        """
        # Special case: text starting with a comma emits that comma on its own.
        if text.startswith((",", ",")):
            return [text[0]], text[1:]

        matches = list(self.sentence_pattern.finditer(text))
        if not matches:
            return [], text  # no terminator found

        last_end = 0
        complete_sentences = []
        for match in matches:
            end_pos = match.end()
            sentence = text[last_end:end_pos].strip()
            if not sentence:
                last_end = end_pos
                continue
            # Accept long pieces, or any piece containing a real sentence terminator.
            if len(sentence) > 6 or any(char in "。.?!?!" for char in sentence):
                complete_sentences.append(sentence)
                last_end = end_pos

        remaining_text = text[last_end:].lstrip()
        return complete_sentences, remaining_text


stream_manager = StreamSessionManager()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def allowed_file(filename):
    """Return True if *filename* has an image extension (png/jpg/jpeg/gif, case-insensitive)."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1]
    return extension.lower() in {'png', 'jpg', 'jpeg', 'gif'}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In-memory text/audio cache; access is guarded by cache_lock (see clean_audio_cache).
CACHE_EXPIRE_SECONDS = 10 * 60  # entries expire after 10 minutes
audio_text_cache = dict()
cache_lock = Lock()
|
|
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
# WebSocket connection management
class ConnectionManager:
    """Track accepted WebSocket connections by id and send to them safely."""

    def __init__(self):
        self.active_connections = {}  # {connection_id: WebSocket}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept *websocket* and register it under *connection_id*."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logging.info(f"新连接建立: {connection_id}")

    async def disconnect(self, connection_id: str, code=1000, reason: str = ""):
        """Unregister *connection_id* and close its socket (close errors are ignored).

        The entry is removed *before* awaiting ``close()``. Otherwise two
        concurrent callers could both pass the membership check, and the second
        ``del`` would raise KeyError.
        """
        websocket = self.active_connections.pop(connection_id, None)
        if websocket is None:
            return
        try:
            await websocket.close(code=code, reason=reason)
        except Exception as e:
            # Best effort: the peer may already be gone.
            logging.debug(f"Ignoring error while closing connection {connection_id}: {e}")

    def is_connected(self, connection_id: str) -> bool:
        """Return True if *connection_id* is still registered."""
        return connection_id in self.active_connections

    async def _safe_send(self, connection_id: str, send_func, *args):
        """Run ``send_func(websocket, *args)`` if the connection is usable.

        Returns:
            True on success. False if the connection is unknown, not in the
            CONNECTED state, or the send failed; in the last two cases the
            connection is also disconnected.
        """
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            logging.warning(f"尝试向不存在的连接发送数据: {connection_id}")
            return False

        try:
            if websocket.client_state.name != "CONNECTED":
                logging.warning(f"连接 {connection_id} 状态为 {websocket.client_state.name}")
                await self.disconnect(connection_id)
                return False

            await send_func(websocket, *args)
            return True

        except (WebSocketDisconnect, RuntimeError) as e:
            # The peer disconnected while sending.
            logging.info(f"发送时检测到断开连接: {connection_id}, {str(e)}")
            await self.disconnect(connection_id)
            return False
        except Exception as e:
            logging.error(f"发送数据出错: {connection_id}, {str(e)}")
            await self.disconnect(connection_id)
            return False

    async def send_bytes(self, connection_id: str, data: bytes):
        """Safely send binary data; returns True on success."""
        return await self._safe_send(
            connection_id,
            lambda ws, d: ws.send_bytes(d),
            data
        )

    async def send_text(self, connection_id: str, message: str):
        """Safely send a text message; returns True on success."""
        return await self._safe_send(
            connection_id,
            lambda ws, m: ws.send_text(m),
            message
        )

    async def send_json(self, connection_id: str, data: dict):
        """Safely send a JSON-serializable object; returns True on success."""
        return await self._safe_send(
            connection_id,
            lambda ws, d: ws.send_json(d),
            data
        )


manager = ConnectionManager()
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-05-26 21:38:46 +08:00
|
|
|
|
def generate_mp3_header(
    sample_rate: int,
    bitrate_kbps: int,
    channels: int = 1,
    layer: str = "III"  # "I" / "II" / "III"
) -> bytes:
    """Build a 4-byte MPEG audio frame header (Layer I/II/III).

    Args:
        sample_rate: sampling frequency in Hz. Accepted values are the MPEG-1
            rates (44100, 48000, 32000), the MPEG-2 rates (22050, 24000, 16000)
            and the MPEG-2.5 rates (11025, 12000, 8000). The MPEG version is
            derived from the rate.
        bitrate_kbps: bitrate in kbps; must be in the standard table for the
            selected version and layer.
        channels: 1 (mono) or 2 (stereo).
        layer: "I", "II" or "III".

    Returns:
        The 4 header bytes, big-endian. The header has no CRC (protection bit
        set), no padding, and no private, copyright, original or emphasis flags.

    Raises:
        ValueError: on an unsupported sample rate, layer, bitrate or channel count.
    """
    # sample rate -> (MPEG version bits, sampling-frequency index)
    sample_rate_codes = {
        44100: (0b11, 0b00), 48000: (0b11, 0b01), 32000: (0b11, 0b10),  # MPEG-1
        22050: (0b10, 0b00), 24000: (0b10, 0b01), 16000: (0b10, 0b10),  # MPEG-2
        11025: (0b00, 0b00), 12000: (0b00, 0b01), 8000: (0b00, 0b10),   # MPEG-2.5
    }
    if sample_rate not in sample_rate_codes:
        raise ValueError(f"不支持的采样率,可选:{sorted(sample_rate_codes)}")

    layer_codes = {"I": 0b11, "II": 0b10, "III": 0b01}
    if layer not in layer_codes:
        raise ValueError(f"不支持的层,可选:{set(layer_codes)}")

    mpeg_version, sample_rate_index = sample_rate_codes[sample_rate]
    layer_code = layer_codes[layer]

    # Standard bitrates (kbps) for bitrate indices 1..14.
    # Index 0 means "free format" and index 15 is invalid; neither is emitted.
    if mpeg_version == 0b11:  # MPEG-1
        bitrates = {
            "I": (32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448),
            "II": (32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384),
            "III": (32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320),
        }[layer]
    elif layer == "I":  # MPEG-2 / 2.5, Layer I
        bitrates = (32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256)
    else:  # MPEG-2 / 2.5, Layers II and III share one table
        bitrates = (8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160)

    if bitrate_kbps not in bitrates:
        raise ValueError(f"不支持的比特率,可选:{list(bitrates)}")
    bitrate_index = bitrates.index(bitrate_kbps) + 1

    if channels == 1:
        channel_mode = 0b11  # single channel
    elif channels == 2:
        channel_mode = 0b00  # stereo
    else:
        raise ValueError("声道数必须为1或2")

    frame_header = (
        (0x7FF << 21)               # 11-bit frame sync
        | (mpeg_version << 19)      # MPEG version ID
        | (layer_code << 17)        # layer description
        | (1 << 16)                 # protection bit: 1 = NOT protected (no CRC follows)
        | (bitrate_index << 12)
        | (sample_rate_index << 10)
        | (0 << 9)                  # padding
        | (0 << 8)                  # private bit
        | (channel_mode << 6)
        | (0 << 4)                  # mode extension (joint stereo only)
        | (0 << 3)                  # copyright
        | (0 << 2)                  # original
        | 0b00                      # emphasis: none
    )
    return frame_header.to_bytes(4, byteorder='big')
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
# ------------------------------------------------
|
|
|
|
|
|
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit PCM audio.

    Args:
        audio_data: raw signed 16-bit samples in native byte order (the byte
            length must be even). Assumes mono; for interleaved channels the
            ramp is applied per sample, not per frame.
        fade_length: number of *samples* the ramp spans. Sample ``i`` is scaled
            by ``i / fade_length``. If the data holds fewer samples than that,
            only the available ones are faded (instead of raising IndexError).

    Returns:
        The faded audio as bytes, the same length as *audio_data*.
    """
    samples = array.array('h', audio_data)
    fade_count = min(fade_length, len(samples))
    for i in range(fade_count):
        samples[i] = int(samples[i] * (i / fade_length))
    return samples.tobytes()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_markdown_json(json_string):
    """Extract and parse the first fenced ```json code block in *json_string*.

    Returns:
        ``{'success': True, 'data': <parsed object>}`` on success.
        ``{'success': False, 'data': <error text>}`` when the JSON is invalid.
        ``{'success': False, 'data': 'not a valid markdown json string'}`` when
        no ```json block is found.
    """
    found = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if found is None:
        return {'success': False, 'data': 'not a valid markdown json string'}
    try:
        payload = json.loads(found.group(1))
    except json.JSONDecodeError as err:
        return {'success': False, 'data': str(err)}
    return {'success': True, 'data': payload}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# audio_text_cache, cache_lock and CACHE_EXPIRE_SECONDS are defined once, earlier
# in this module. A second definition here used to rebind them to a fresh dict
# and a fresh Lock. Code holding the earlier objects would then see different
# state from code using the new ones, so the duplicate has been removed.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Map full-width (and some other CJK) punctuation characters to half-width forms.
def fullwidth_to_halfwidth(s):
    """Return *s* with every character that has an entry in the map below replaced.

    Characters without an entry are kept unchanged. The mapping dict is rebuilt
    on every call.
    """
    # NOTE(review): several keys in this literal look alike (the corner brackets,
    # '、', '・', ':'). They may be distinct code points (full-width vs. half-width
    # forms). If any are actually identical, the later entry silently overrides
    # the earlier one in the dict literal — verify the intended mappings.
    full_to_half_map = {
        '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
        '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
        '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
        '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
        '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
        '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
        '、': '、', '・': '・', ':': ':'
    }
    return ''.join(full_to_half_map.get(char, char) for char in s)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* at punctuation/whitespace and pack the words into chunks.

    Punctuation and whitespace act only as separators; they are dropped, and
    words inside a chunk are joined with single spaces. Words are added to a
    chunk while its length stays at most *chunk_size*. A single word longer than
    *chunk_size* becomes its own chunk, and no empty chunks are produced (the
    previous version emitted ``''`` when the first word was too long).

    Returns:
        A list of chunk strings; empty if *text* contains no words.
    """
    punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
    tokens = [token for token in re.split(punctuation_pattern, text) if token]

    chunks = []
    current_chunk = ''  # always ends with a space when non-empty
    for token in tokens:
        # len(current_chunk) + len(token) == length of the stripped chunk after adding token.
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += token + ' '
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_text_from_markdown(markdown_text):
    """Strip Markdown markup from *markdown_text* and return plain text.

    Removes fenced and inline code, ATX heading lines, emphasized spans, images,
    links and HTML tags. Full-width punctuation is converted with
    ``fullwidth_to_halfwidth`` and every whitespace run collapses to one space.
    """
    text = markdown_text
    # Fenced code blocks go first: the inline-code pattern would otherwise match
    # between the fences' backticks and leave "```" debris behind.
    text = re.sub(r'```[\s\S]*?```', '', text)
    # Inline code spans.
    text = re.sub(r'`[^`]+`', '', text)
    # ATX heading lines ("# Title" .. "###### Title"), one line at a time.
    # The previous pattern r'#\s*[^#]+' also swallowed every following line up
    # to the next '#', and any "#..." inside ordinary text.
    text = re.sub(r'^[ \t]{0,3}#{1,6}(?:[ \t].*)?$', '', text, flags=re.MULTILINE)
    # Bold/italic spans.
    # NOTE(review): this removes the emphasized words themselves, not just the
    # * / _ markers — confirm that is intended.
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Images before links: the link pattern would otherwise consume the
    # "[alt](url)" part of "![alt](url)" and leave a stray "!".
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Links (including their text).
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalize full-width punctuation.
    text = fullwidth_to_halfwidth(text)
    # Collapse runs of whitespace.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress *original_data* and return it as a Base64 text string.

    The Base64 text has no line breaks, matching Android's ``Base64.NO_WRAP``
    on the client side.
    """
    compressed = gzip.compress(original_data)
    encoded = base64.b64encode(compressed)
    return encoded.decode('utf-8')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clean_audio_cache():
    """Drop cache entries older than CACHE_EXPIRE_SECONDS and close their streams.

    Expired entries are removed while holding ``cache_lock``. Their
    ``'audio_stream'`` objects are closed afterwards, outside the lock. A
    failing ``close()`` is logged rather than aborting the sweep, because the
    exception would otherwise propagate into the background cleaner loop.
    """
    expired_entries = []
    with cache_lock:
        now = time.time()
        expired_keys = [
            key for key, entry in audio_text_cache.items()
            if now - entry['created_at'] > CACHE_EXPIRE_SECONDS
        ]
        for key in expired_keys:
            entry = audio_text_cache.pop(key, None)
            if entry is not None:
                expired_entries.append(entry)

    for entry in expired_entries:
        stream = entry.get('audio_stream')
        if stream:
            try:
                stream.close()
            except Exception as e:
                logging.warning(f"Failed to close cached audio stream: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start_background_cleaner(interval_seconds=180):
    """Start a daemon thread that calls clean_audio_cache() every *interval_seconds*.

    Exceptions raised by a cleanup pass are logged and the loop keeps running.
    Previously one failure ended the thread silently and the cache was never
    cleaned again.

    Returns:
        The started ``threading.Thread``.
    """
    def cleaner_loop():
        while True:
            time.sleep(interval_seconds)  # default: every 3 minutes
            try:
                clean_audio_cache()
            except Exception:
                logging.exception("Audio cache cleanup failed")

    cleaner_thread = Thread(target=cleaner_loop, daemon=True)
    cleaner_thread.start()
    return cleaner_thread
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_qwen_chat():
    """Manual smoke test: ask qwen-plus "你是谁?" via DashScope and print the reply.

    Makes a network call and prints either the answer or the error details.
    """
    # `Generation` is not imported anywhere in the visible part of this module,
    # so import it here; dashscope is already a dependency of this file.
    from dashscope import Generation

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # Without the environment variable configured, replace this with a
        # Bailian API key: api_key = "sk-xxx"
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )

    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码:{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
|
|
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
# DashScope API key. Prefer the DASHSCOPE_API_KEY environment variable.
# NOTE(review): a key is committed to source control below as a fallback;
# rotate it and drop the literal once deployments set the env var.
ALI_KEY = os.environ.get("DASHSCOPE_API_KEY", "sk-a47a3fb5f4a94f66bbaf713779101c75")
|
2025-07-10 22:04:44 +08:00
|
|
|
|
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
|
|
|
|
|
from dashscope.audio.tts import (
|
|
|
|
|
|
ResultCallback as TTSResultCallback,
|
|
|
|
|
|
SpeechSynthesizer as TTSSpeechSynthesizer,
|
|
|
|
|
|
SpeechSynthesisResult as TTSSpeechSynthesisResult,
|
|
|
|
|
|
)
|
|
|
|
|
|
# cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本
|
|
|
|
|
|
from dashscope.audio.tts_v2 import (
|
|
|
|
|
|
ResultCallback as CosyResultCallback,
|
|
|
|
|
|
SpeechSynthesizer as CosySpeechSynthesizer,
|
|
|
|
|
|
AudioFormat,
|
|
|
|
|
|
)
|
|
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
|
|
|
|
|
class QwenTTS:
    """DashScope text-to-speech wrapper for the legacy TTS API and CosyVoice.

    ``model_name`` may carry an ``@...`` suffix, which is stripped. A name of the
    form ``"cosyvoice-v1/<voice>"`` selects the CosyVoice (``tts_v2``) back end
    with ``<voice>``. Any other name is passed to the legacy ``tts`` synthesizer.
    """

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun"):
        """Configure the DashScope API key and parse ``model_name``.

        Args:
            key: DashScope API key. Assigned to the process-global ``dashscope.api_key``.
            format: CosyVoice output container ("mp3", "pcm" or "wav"); see get_audio_format.
            sample_rate: CosyVoice output sample rate in Hz.
            model_name: ``"<model>[/<voice>][@<suffix>]"``.
        """
        import dashscope
        import ssl

        logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}")  # cyx
        self.model_name = model_name.split('@')[0]

        dashscope.api_key = key
        # SECURITY(review): this disables TLS certificate verification for every stdlib
        # HTTPS connection in the whole process, not only for DashScope.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        self.first_chunk = True
        if '/' in self.model_name:
            # Split "<model>/<voice>" into its two parts.
            parts = self.model_name.split('/', 1)
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]

    class Callback(TTSResultCallback):
        """Legacy-TTS callback that buffers audio frames for ``_run`` to consume.

        ``None`` in ``dque`` is the end-of-stream sentinel, pushed on completion
        and on error. ``error`` holds the error text when synthesis failed.
        """

        def __init__(self) -> None:
            self.dque = deque()
            self.error = None

        def _run(self):
            """Yield buffered frames until the sentinel, then raise if an error was recorded."""
            while True:
                if not self.dque:
                    # Short sleep instead of sleep(0), which busy-spins a CPU core.
                    time.sleep(0.005)
                    continue
                val = self.dque.popleft()
                if val is None:
                    break
                if val:  # skip empty frames rather than treating them as end-of-stream
                    yield val
            if self.error is not None:
                raise RuntimeError(self.error)

        def on_open(self):
            pass

        def on_complete(self):
            self.dque.append(None)

        def on_error(self, response: SpeechSynthesisResponse):
            print("Qwen tts error", str(response))
            # Record the failure and wake the consumer; without a sentinel _run() never returns.
            self.error = str(response)
            self.dque.append(None)
            raise RuntimeError(str(response))

        def on_close(self):
            pass

        def on_event(self, result: TTSSpeechSynthesisResult):
            if result.get_audio_frame() is not None:
                self.dque.append(result.get_audio_frame())

    # --------------------------

    class Callback_Cosy(CosyResultCallback):
        """CosyVoice (tts_v2) callback.

        If ``data_callback`` is given, every non-empty audio chunk is passed to it,
        ``None`` signals completion and ``b"ERROR:..."`` signals failure. Otherwise
        chunks are buffered in ``dque`` and consumed by ``_run``, which uses the
        same sentinel/error protocol as ``Callback``.
        """

        def __init__(self, data_callback=None) -> None:
            self.dque = deque()
            self.data_callback = data_callback
            self.error = None

        def _run(self):
            """Yield buffered chunks until the sentinel, then raise if an error was recorded."""
            while True:
                if not self.dque:
                    # Short sleep instead of sleep(0), which busy-spins a CPU core.
                    time.sleep(0.005)
                    continue
                val = self.dque.popleft()
                if val is None:
                    break
                if val:
                    yield val
            if self.error is not None:
                raise RuntimeError(self.error)

        def on_open(self):
            logging.info("Qwen CosyVoice tts open ")

        def on_complete(self):
            self.dque.append(None)
            if self.data_callback:
                self.data_callback(None)  # end-of-stream signal

        def on_error(self, response: SpeechSynthesisResponse):
            print("Qwen tts error", str(response))
            # Record the failure and wake a _run() consumer (if any) so it cannot hang.
            self.error = str(response)
            self.dque.append(None)
            if self.data_callback:
                self.data_callback(f"ERROR:{str(response)}".encode())
            raise RuntimeError(str(response))

        def on_close(self):
            logging.info("Qwen CosyVoice tts close")

        def on_event(self, message):
            # Per-message events are ignored; audio arrives through on_data.
            pass

        def on_data(self, data: bytes) -> None:
            """Forward a non-empty audio chunk to ``data_callback`` or buffer it."""
            if len(data) > 0:
                if self.data_callback:
                    self.data_callback(data)
                else:
                    self.dque.append(data)

    # --------------------------

    def tts(self, text):
        """Synthesize ``text`` and yield audio chunks as they arrive.

        Legacy models are asked for WAV. CosyVoice uses ``self.format`` and
        ``self.sample_rate``.

        Raises:
            RuntimeError: ``"**ERROR**: ..."`` when synthesis fails, either at start-up
                or while streaming.
        """
        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")  # cyx
        self.callback = None
        try:
            if self.is_cosyvoice is False:
                self.callback = self.Callback()
                TTSSpeechSynthesizer.call(model=self.model_name,
                                          text=text,
                                          callback=self.callback,
                                          format="wav")  # format="mp3")
            else:
                self.callback = self.Callback_Cosy()
                audio_format = self.get_audio_format(self.format, self.sample_rate)
                self.synthesizer = CosySpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,
                    callback=self.callback,
                    format=audio_format
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}")  # cyx
            if self.callback is None:
                raise RuntimeError(f"**ERROR**: {e}") from e
            # The SDK may never deliver on_complete/on_error after a failed start, so
            # push a sentinel ourselves to guarantee the loop below terminates. A
            # duplicate sentinel is harmless: the consumer stops at the first one.
            if self.callback.error is None:
                self.callback.error = str(e)
            self.callback.dque.append(None)

        try:
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def setup_tts(self, on_data):
        """Create a CosyVoice synthesizer whose audio is delivered to ``on_data``.

        ``on_data`` receives bytes chunks, ``None`` on completion, and
        ``b"ERROR:..."`` on failure (see Callback_Cosy). Returns the synthesizer.

        Raises:
            NotImplementedError: if the model is not a CosyVoice model.
        """
        if not self.is_cosyvoice:
            raise NotImplementedError("Only CosyVoice supported")

        self.callback = self.Callback_Cosy(on_data)
        format_val = self.get_audio_format(self.format, self.sample_rate)
        logging.info(f"setup_tts {self.voice} {format_val}")
        self.synthesizer = CosySpeechSynthesizer(
            model='cosyvoice-v1',
            voice=self.voice,
            callback=self.callback,
            format=format_val
        )
        return self.synthesizer

    def text_tts_call(self, text):
        """Run a one-shot ``call`` on the synthesizer created by setup_tts, if any."""
        if self.synthesizer:
            self.synthesizer.call(text)

    def streaming_call(self, text):
        """Send an incremental text piece to the synthesizer, if one exists."""
        if self.synthesizer:
            self.synthesizer.streaming_call(text)

    def end_streaming_call(self):
        """Tell the synthesizer (if any) that no more streaming text will follow."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()

    def get_audio_format(self, format: str, sample_rate: int):
        """Map ``(sample_rate, format)`` to a CosyVoice ``AudioFormat``.

        Combinations that are not listed fall back to 16 kHz mono MP3.
        """
        from dashscope.audio.tts_v2 import AudioFormat

        logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            # 16 kHz entries: without them a (16000, 'wav') request silently got MP3.
            (16000, 'mp3'): AudioFormat.MP3_16000HZ_MONO_128KBPS,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (16000, 'wav'): AudioFormat.WAV_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT,
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
|
|
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
import threading
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
import time
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
|
|
import threading
|
|
|
|
|
|
import uuid
|
|
|
|
|
|
import time
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import logging
|
|
|
|
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
|
|
from collections import deque
|
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnifiedTTSEngine:
    """Registry of TTS jobs keyed by a UUID ``audio_stream_id``.

    Each task is a dict holding the request parameters plus runtime state:
    ``status`` ('pending' -> 'processing'), ``data_queue`` (deque of audio
    chunks), ``event`` (set by the synthesizer callback on completion/error),
    ``completed``, ``error`` and, after merging, ``buffer`` (a BytesIO with the
    whole audio). Synthesis runs on ``executor``. Tasks older than
    ``cache_expire`` seconds are removed by a self-rescheduling 30 s timer.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.tasks = {}
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.cache_expire = 300  # keep tasks for 5 minutes
        # Timer that periodically purges expired tasks.
        self.cleanup_timer = None
        self.start_cleanup_timer()

    def _cleanup_old_tasks(self):
        """Remove every task created more than ``cache_expire`` seconds ago."""
        now = time.time()
        with self.lock:
            expired_ids = [task_id for task_id, task in self.tasks.items()
                           if now - task['created_at'] > self.cache_expire]
            for task_id in expired_ids:
                self._remove_task(task_id)

    def _remove_task(self, task_id):
        """Drop ``task_id`` from the registry (called with ``self.lock`` held).

        A future that has not started yet is cancelled. ``Future.cancel`` cannot
        stop a future that is already running; that worker keeps going until it
        finishes, and its task dict is then released once no longer referenced.
        """
        if task_id in self.tasks:
            task = self.tasks.pop(task_id)
            if 'future' in task and not task['future'].done():
                task['future'].cancel()

    def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False):
        """Register a TTS task and return its ``audio_stream_id``.

        When ``delay_gen_audio`` is false, synthesis starts immediately and this
        call blocks until it finishes or 300 s pass (see ``_start_tts_task``).
        Otherwise synthesis starts on the first ``get_audio_stream`` call.
        """
        self._cleanup_old_tasks()
        audio_stream_id = str(uuid.uuid4())

        task_data = {
            'id': audio_stream_id,
            'text': text,
            'format': format,
            'sample_rate': sample_rate,
            'model_name': model_name,
            'key': key,
            'delay_gen_audio': delay_gen_audio,
            'created_at': time.time(),
            'status': 'pending',
            'data_queue': deque(),
            'event': threading.Event(),
            'completed': False,
            'error': None
        }
        with self.lock:
            self.tasks[audio_stream_id] = task_data

        if not delay_gen_audio:
            self._start_tts_task(audio_stream_id)

        return audio_stream_id

    def _start_tts_task(self, audio_stream_id):
        """Submit a 'pending' task to the executor and mark it 'processing'.

        For non-delayed tasks, wait up to 300 s for synthesis to finish, then merge
        the chunks into ``task['buffer']``. Timeouts and errors are recorded on the
        task instead of being raised.
        """
        # Bind the futures TimeoutError explicitly; the module only does
        # ``from concurrent.futures import ThreadPoolExecutor``, which does not
        # bind the name ``concurrent``.
        from concurrent.futures import TimeoutError as FuturesTimeoutError

        task = self.tasks.get(audio_stream_id)
        if not task or task['status'] != 'pending':
            return
        logging.info(f"已经启动 start tts task {audio_stream_id}")
        task['status'] = 'processing'

        future = self.executor.submit(self._run_tts_sync, audio_stream_id)
        task['future'] = future

        # Non-delayed tasks are waited on here so the merged buffer is ready.
        if not task.get('delay_gen_audio', True):
            try:
                future.result(timeout=300)
                logging.info(f"TTS任务 {audio_stream_id} 已完成")
                self._merge_audio_data(audio_stream_id)
            except FuturesTimeoutError:
                task['error'] = "TTS生成超时"
                task['completed'] = True
                logging.error(f"TTS任务 {audio_stream_id} 超时")
            except Exception as e:
                task['error'] = f"ERROR:{str(e)}"
                task['completed'] = True
                logging.exception(f"TTS任务执行异常: {str(e)}")

    def _run_tts_sync(self, audio_stream_id):
        """Run synthesis for one task on an executor thread.

        Audio chunks are appended to ``data_queue``. ``completed``/``error`` are set
        when the synthesizer signals the end or a failure, or after 300 s without
        a completion signal.
        """
        task = self.tasks.get(audio_stream_id)
        if not task:
            return

        try:
            tts = QwenTTS(
                key=task['key'],
                format=task['format'],
                sample_rate=task['sample_rate'],
                model_name=task['model_name']
            )

            def data_handler(data):
                # Protocol from QwenTTS.Callback_Cosy: None = done,
                # b"ERROR..." = failure, anything else = audio bytes.
                if data is None:
                    task['completed'] = True
                    task['event'].set()
                    logging.info(f"--data_handler on_complete")
                elif data.startswith(b"ERROR"):
                    task['error'] = data.decode()
                    task['completed'] = True
                    task['event'].set()
                else:
                    task['data_queue'].append(data)

            synthesizer = tts.setup_tts(data_handler)
            synthesizer.call(task['text'])

            # Wait for completion or time out after 5 minutes.
            if not task['event'].wait(timeout=300):
                task['error'] = "TTS generation timeout"
                task['completed'] = True

            logging.info(f"--tts task event set error = {task['error']}")

        except Exception as e:
            task['error'] = f"ERROR:{str(e)}"
            task['completed'] = True

    def _merge_audio_data(self, audio_stream_id):
        """Concatenate a completed task's chunks into ``task['buffer']`` and clear the queue."""
        task = self.tasks.get(audio_stream_id)
        if not task or not task.get('completed'):
            return

        try:
            logging.info(f"开始合并音频数据: {audio_stream_id}")
            buffer = io.BytesIO()
            for data_chunk in task['data_queue']:
                buffer.write(data_chunk)
            # Rewind so readers start at the beginning.
            buffer.seek(0)
            task['buffer'] = buffer
            logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节")
            # The chunks now live in the buffer; free the queue.
            task['data_queue'].clear()
        except Exception as e:
            logging.error(f"合并音频数据失败: {str(e)}")
            task['error'] = f"合并错误: {str(e)}"

    async def get_audio_stream(self, audio_stream_id):
        """Async generator that yields the task's audio chunks as they are produced.

        Starts a delayed task that is still 'pending'. After the last chunk it raises
        ``RuntimeError`` if the task recorded an error.

        Raises:
            RuntimeError: if the task is unknown or recorded an error.
        """
        task = self.tasks.get(audio_stream_id)
        if not task:
            raise RuntimeError("Audio stream not found")

        # Delayed tasks start on first read.
        if task['delay_gen_audio'] and task['status'] == 'pending':
            self._start_tts_task(audio_stream_id)

        while task['status'] == 'pending':
            await asyncio.sleep(0.1)

        while not task['completed'] or task['data_queue']:
            while task['data_queue']:
                data = task['data_queue'].popleft()
                yield data
            # Brief pause while waiting for more chunks.
            await asyncio.sleep(0.05)

        if task['error']:
            raise RuntimeError(task['error'])

    def start_cleanup_timer(self):
        """(Re)arm a daemon timer that runs ``cleanup_task`` in 30 seconds."""
        if self.cleanup_timer:
            self.cleanup_timer.cancel()

        self.cleanup_timer = threading.Timer(30.0, self.cleanup_task)
        self.cleanup_timer.daemon = True
        self.cleanup_timer.start()

    def cleanup_task(self):
        """Timer body: purge expired tasks, log failures, then re-arm the timer."""
        try:
            self._cleanup_old_tasks()
        except Exception as e:
            logging.error(f"清理任务时出错: {str(e)}")
        finally:
            self.start_cleanup_timer()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Global TTS engine instance shared by the TTS routes in this module
# (create_tts_request / get_tts_audio). Constructing it at import time also
# starts its periodic cleanup timer (see UnifiedTTSEngine.__init__).
tts_engine = UnifiedTTSEngine()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def replace_domain(url: str) -> str:
    """Rewrite ``url`` so it points at the local backend ``http://localhost:9380``.

    Known public hosts of this service are replaced (first occurrence) together
    with a leading ``http://`` or ``https://`` when present. Any other absolute URL
    has its host swapped for the local address, keeping the path. A URL without
    a scheme is treated as a path relative to the local address. Uses plain
    string operations only (no urllib.parse).

    >>> replace_domain("https://ragflow.szzysztech.com/v1/a")
    'http://localhost:9380/v1/a'
    """
    local_base = "http://localhost:9380"
    # Scheme-qualified variants come first so the scheme is replaced together with
    # the host; matching only the bare host would yield "http://http://localhost...".
    domains_to_replace = [
        "http://1.13.185.116:9380",
        "https://1.13.185.116:9380",
        "http://ragflow.szzysztech.com",
        "https://ragflow.szzysztech.com",
        "1.13.185.116:9380",
        "ragflow.szzysztech.com",
    ]

    for domain in domains_to_replace:
        if domain in url:
            return url.replace(domain, local_base, 1)

    # Unknown host: swap whatever host follows the scheme.
    if "://" in url:
        _protocol, path = url.split("://", 1)
        slash_pos = path.find("/")
        if slash_pos > 0:
            return f"{local_base}{path[slash_pos:]}"
        # No path part: just the local address.
        return local_base

    # No scheme: treat as a path; strip leading slashes to avoid "//".
    return f"{local_base}/{url.lstrip('/')}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def proxy_aichat_audio_stream(client_id: str, audio_url: str):
    """Fetch ``audio_url`` and relay its body to websocket client ``client_id``.

    Bytes are forwarded chunk by chunk via ``manager.send_bytes``. Forwarding
    stops as soon as that call returns a falsy value. A non-200 upstream response,
    or any exception, is reported to the client as a JSON
    ``{"type": "error", "message": ...}`` text message instead of audio.
    """
    try:
        # NOTE(review): domain rewriting (replace_domain) is not applied; the URL is used unchanged.
        local_url = audio_url
        logging.info(f"代理音频流: {audio_url} -> {local_url}")

        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream("GET", local_url) as response:
                if response.status_code != 200:
                    # Don't relay an error page to the client as if it were audio.
                    logging.error(f"Audio proxy upstream returned HTTP {response.status_code} for {local_url}")
                    await manager.send_text(client_id, json.dumps({
                        "type": "error",
                        "message": f"音频流获取失败: HTTP {response.status_code}"
                    }))
                    return

                # Stream the audio bytes through.
                async for chunk in response.aiter_bytes():
                    if not await manager.send_bytes(client_id, chunk):
                        logging.warning(f"Audio proxy interrupted for {client_id}")
                        return
    except Exception as e:
        logging.error(f"Audio proxy failed: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"音频流获取失败: {str(e)}"
        }))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Proxy function - text stream
async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict):
    """Proxy a streaming LLM completion to websocket client ``client_id`` and feed TTS.

    POSTs ``payload`` to ``completions_url`` and relays each SSE ``data:`` line to
    the client as ``{"type": "text", "data": <line>}``. The first event whose
    ``data`` field is a dict gets an ``audio_stream_url`` pointing at a
    ``/tts_stream/<session>`` created up front, and every ``delta_ans`` text is
    appended to that session. A ``{"type": "end"}`` message follows the stream.
    Failures are reported as ``{"type": "error", "message": ...}``.

    TTS options come from ``payload``: ``tts_stream_format`` (default "mp3"),
    ``tts_sample_rate`` (default 48000) and ``tts_model``.
    """
    try:
        logging.info(f"代理文本流: completions_url={completions_url} {payload}")
        logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}")

        # SECURITY(review): hard-coded backend API token; move it to configuration.
        headers = {
            "Content-Type": "application/json",
            'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm'
        }
        # TTS instance that will synthesize the streamed answer.
        tts_model = QwenTTS(
            key=ALI_KEY,
            format=payload.get('tts_stream_format', 'mp3'),
            sample_rate=payload.get('tts_sample_rate', 48000),
            model_name=payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
        )

        # Streaming TTS session that receives the answer text incrementally.
        # NOTE(review): the session is never explicitly finalized in this function;
        # confirm stream_manager detects end-of-text on its own.
        tts_stream_session_id = stream_manager.create_session(
            tts_model=tts_model,
            sample_rate=payload.get('tts_sample_rate', 48000),
            stream_format=payload.get('tts_stream_format', 'mp3'),
            session_id=None,
            streaming_call=True
        )
        tts_stream_session_id_sent = False
        # Long read timeout (5 minutes) for slow LLM responses.
        timeout = httpx.Timeout(300.0, connect=60.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            # Streamed request so SSE events are relayed as they arrive.
            async with client.stream(
                "POST",
                completions_url,
                json=payload,
                headers=headers
            ) as response:
                logging.info(f"响应状态: HTTP {response.status_code}")

                if response.status_code != 200:
                    # Read the (non-streamed) error body.
                    error_content = await response.aread()
                    error_msg = f"后端错误: HTTP {response.status_code}"
                    # errors="replace": the 200-byte cut may split a multi-byte UTF-8 character.
                    error_msg += f" - {error_content[:200].decode('utf-8', errors='replace')}" if error_content else ""
                    await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg}))
                    return

                # Only SSE responses are processed event by event.
                content_type = response.headers.get("content-type", "").lower()
                if "text/event-stream" not in content_type:
                    logging.warning("非流式响应,转发完整内容")
                    full_content = await response.aread()
                    await manager.send_text(client_id, json.dumps({
                        "type": "text",
                        "data": full_content.decode('utf-8', errors='replace')
                    }))
                    return

                logging.info("开始处理SSE流")
                event_count = 0
                async for line in response.aiter_lines():
                    # Skip blank lines and SSE comment lines.
                    if not line or line.startswith(':'):
                        continue

                    if line.startswith("data:"):
                        data_str = line[5:].strip()
                        if data_str:
                            try:
                                # Parse the event and extract the incremental answer text.
                                data_obj = json.loads(data_str)
                                delta_text = None
                                inner = data_obj.get('data') if isinstance(data_obj, dict) else None
                                if isinstance(inner, dict):
                                    delta_text = inner.get('delta_ans', "")
                                    if tts_stream_session_id_sent is False:
                                        inner['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}"
                                        data_str = json.dumps(data_obj)
                                        tts_stream_session_id_sent = True

                                # The {"type": "text", "data": ...} envelope is what the
                                # websocket front end parses.
                                await manager.send_text(client_id, json.dumps({
                                    "type": "text",
                                    "data": data_str
                                }))
                                if delta_text:
                                    stream_manager.append_text(tts_stream_session_id, delta_text)
                                event_count += 1
                            except Exception as e:
                                logging.error(f"事件发送失败: {str(e)}")

                    # Yield to the event loop between lines.
                    await asyncio.sleep(0.001)

                logging.info(f"SSE流处理完成,事件数: {event_count}")
                await manager.send_text(client_id, json.dumps({"type": "end"}))

    except httpx.ReadTimeout:
        logging.error("读取后端服务超时")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": "后端服务响应超时"
        }))
    except httpx.ConnectError as e:
        logging.error(f"连接后端服务失败: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"无法连接到后端服务: {str(e)}"
        }))
    except Exception as e:
        logging.exception(f"文本代理失败: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"文本流获取失败: {str(e)}"
        }))
|
|
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
|
|
|
|
|
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Stream the local test file ``./test.mp3`` as ``audio/mpeg`` in 1 KiB chunks.

    Read errors are logged and simply end the stream.
    """
    def audio_generator():
        try:
            with open('./test.mp3', 'rb') as mp3_file:
                # iter() with a b'' sentinel stops at EOF.
                for piece in iter(lambda: mp3_file.read(1024), b''):
                    yield piece
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")

    response_headers = {
        "Cache-Control": "no-store",
        "Accept-Ranges": "bytes",
    }
    return StreamingResponse(audio_generator(), media_type="audio/mpeg", headers=response_headers)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Wrap raw PCM samples in a WAV (RIFF) container.

    The header is fixed to mono, 16-bit samples at ``sample_rate`` Hz.

    Args:
        pcm_data: Raw little-endian 16-bit mono PCM bytes.
        sample_rate: Frame rate written into the header, in Hz.

    Returns:
        The complete WAV file as bytes.
    """
    # The module header does not import ``wave``; import it here so the function works.
    import wave

    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # mono
            wav_file.setsampwidth(2)  # 16-bit
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        return wav_buffer.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return ``duration_ms`` milliseconds of zero-valued PCM silence.

    The sample count is ``int(TTS_SAMPLERATE * duration_ms / 1000)``; each sample
    takes ``SAMPLE_WIDTH`` bytes per channel across ``CHANNELS`` channels.
    NOTE(review): the original comment describes this as MP3 streaming
    pre-buffering, but the bytes are raw PCM, not MP3 frames.
    """
    frame_count = int(TTS_SAMPLERATE * duration_ms / 1000)
    bytes_per_frame = SAMPLE_WIDTH * CHANNELS
    return b'\x00' * (frame_count * bytes_per_frame)
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-05-26 21:38:46 +08:00
|
|
|
|
# ------------------------ API路由 ------------------------
|
2025-05-15 15:26:06 +08:00
|
|
|
|
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Create a TTS task for the JSON body's ``text`` and return its stream URLs.

    Body fields: ``text`` (required, non-empty), ``tts_stream_format``
    ("mp3"/"wav"/"pcm", default "mp3"), ``tts_sample_rate`` (8000/16000/22050/
    44100/48000, default 48000), ``model_name`` and ``delay_gen_audio``.

    Raises:
        HTTPException: 400 for invalid input, 500 for any other failure.
    """
    # Bound locally: HTTPException is not among the fastapi names imported in the module header.
    from fastapi import HTTPException

    try:
        data = await request.json()
        logging.info(f"Creating TTS request: {data}")

        # Parameter validation
        text = (data.get("text") or "").strip()
        if not text:
            raise HTTPException(400, detail="Text cannot be empty")

        format = data.get("tts_stream_format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="Unsupported audio format")

        sample_rate = data.get("tts_sample_rate", 48000)
        if sample_rate not in [8000, 16000, 22050, 44100, 48000]:
            raise HTTPException(400, detail="Unsupported sample rate")

        model_name = data.get("model_name", "cosyvoice-v1/longxiaochun")
        delay_gen_audio = data.get('delay_gen_audio', False)

        # create_tts_task blocks (up to 300 s) when delay_gen_audio is false, so run
        # it off the event loop to keep other requests and websockets responsive.
        loop = asyncio.get_running_loop()
        audio_stream_id = await loop.run_in_executor(
            None,
            lambda: tts_engine.create_tts_task(
                text=text,
                format=format,
                sample_rate=sample_rate,
                model_name=model_name,
                key=ALI_KEY,
                delay_gen_audio=delay_gen_audio
            )
        )

        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}",  # WebSocket URL (added 2025-06-22)
                "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat()
            }
        )

    except HTTPException:
        # Let 4xx validation errors reach the client instead of becoming a 500 below.
        raise
    except Exception as e:
        logging.error(f"Request failed: {str(e)}")
        raise HTTPException(500, detail="Internal server error")
|
|
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
2025-05-26 21:38:46 +08:00
|
|
|
|
# Module-level thread pool with the default worker count.
# NOTE(review): not referenced by the code visible in this chunk (UnifiedTTSEngine
# uses its own pool); confirm it is used elsewhere or remove it.
executor = ThreadPoolExecutor()
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
|
2025-05-26 21:38:46 +08:00
|
|
|
|
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
        chat_id: str,
        audio_stream_id: str,
        range: str = Header(None)
):
    """Serve the audio of a TTS task created by ``create_tts_request``.

    Unknown or expired IDs get a 404 JSON body. A completed task with a merged
    buffer is served from memory; ``Range`` is honoured only for buffers of at
    least 120 KB. Otherwise chunks are streamed live from the engine.

    Raises:
        HTTPException: propagated unchanged from range handling, or 500 on any
            other failure.
    """
    # Bound locally: HTTPException is not among the fastapi names imported in the module header.
    from fastapi import HTTPException

    try:
        task = tts_engine.tasks.get(audio_stream_id)
        if not task:
            # Return a friendly error body rather than raising.
            return JSONResponse(
                status_code=404,
                content={
                    "error": "Audio stream not found",
                    "message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired",
                    "suggestion": "Please create a new TTS request and try again"
                }
            )

        format = task['format']
        media_type = {
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "pcm": f"audio/L16; rate={task['sample_rate']}; channels=1"
        }[format]

        logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}")

        def buffer_read(buffer):
            # Snapshot the bytes instead of seek()/read() on the shared BytesIO:
            # concurrent requests for the same task would otherwise move each
            # other's read position (StreamingResponse iterates in worker threads).
            payload = buffer.getvalue()
            chunk_size = 4096
            for offset in range(0, len(payload), chunk_size):
                yield payload[offset:offset + chunk_size]

        if task.get('completed') and task.get('buffer') is not None:
            buffer = task['buffer']
            total_size = buffer.getbuffer().nbytes
            # Force plain streaming for small files (under 120 KB) to avoid 206 response issues.
            if total_size < 1024 * 120:
                range = None

            if range:
                return handle_range_request(range, buffer, total_size, media_type)
            return StreamingResponse(
                buffer_read(buffer),
                media_type=media_type,
                headers={
                    "Accept-Ranges": "bytes",
                    "Cache-Control": "no-store",
                    "Transfer-Encoding": "chunked"
                })

        # Task still running (or never merged): stream chunks as they are produced.
        logging.info("tts_engine.get_audio_stream--0")
        return StreamingResponse(
            tts_engine.get_audio_stream(audio_stream_id),
            media_type=media_type,
            headers={
                "Accept-Ranges": "bytes",
                "Cache-Control": "no-store",
                "Transfer-Encoding": "chunked"
            }
        )

    except HTTPException:
        # Keep deliberate HTTP errors (e.g. 416 from range handling) intact.
        raise
    except Exception as e:
        logging.error(f"Audio streaming failed: {str(e)}")
        raise HTTPException(500, detail="Audio generation error")
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
def _parse_byte_range(range_header: str, total_size: int) -> Optional[tuple]:
    """Resolve a single ``Range`` header value against a resource of ``total_size`` bytes.

    Returns:
        ``(start, end)`` as inclusive byte offsets.  ``end`` is clamped to
        ``total_size - 1`` (RFC 7233 section 2.1).  Suffix ranges
        (``bytes=-N``) select the last ``N`` bytes.  Returns ``None`` when the
        header should be ignored: the unit is not ``bytes``, several ranges are
        listed, the syntax is malformed, or ``last < first``.  RFC 7233 allows
        a server to ignore such a header and send the full representation.

    Raises:
        ValueError: the range is well formed but unsatisfiable (it starts at or
            past the end of the data, or asks for an empty suffix).
    """
    match = re.fullmatch(
        r"\s*bytes\s*=\s*(\d*)\s*-\s*(\d*)\s*",
        range_header or "",
        re.IGNORECASE | re.ASCII,
    )
    if match is None:
        return None
    first, last = match.groups()

    if not first:
        if not last:
            return None  # "bytes=-" names no range at all
        suffix_length = int(last)
        if suffix_length == 0 or total_size <= 0:
            raise ValueError(f"unsatisfiable suffix range {range_header!r}")
        return max(total_size - suffix_length, 0), total_size - 1

    start = int(first)
    end = int(last) if last else total_size - 1
    if last and end < start:
        return None  # syntactically invalid range -> ignore the header
    if start >= total_size:
        raise ValueError(f"range start {start} is beyond size {total_size}")
    return start, min(end, total_size - 1)


def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str):
    """Serve one HTTP byte range of an in-memory audio buffer (RFC 7233).

    Args:
        range_header: raw ``Range`` request header value, e.g. ``"bytes=0-1023"``.
        buffer: buffer holding the complete payload.  Its read position is not
            used or changed; the bytes are taken from ``buffer.getvalue()``.
        total_size: size of the representation in bytes (expected to equal
            the buffer's size).
        media_type: ``Content-Type`` of the response.

    Returns:
        * ``206`` with ``Content-Range`` for a proper sub-range;
        * ``200`` when the range covers the whole payload or the header is
          ignored (malformed, multi-range, or non-``bytes`` unit);
        * an empty ``416`` with ``Content-Range: bytes */<total_size>`` when the
          range cannot be satisfied.
    """
    try:
        byte_range = _parse_byte_range(range_header, total_size)
    except ValueError:
        # Returned rather than raised: a caller wrapping this call in a broad
        # ``except Exception`` would otherwise turn the 416 into a 500.
        return Response(
            status_code=416,
            headers={"Content-Range": f"bytes */{total_size}", "Accept-Ranges": "bytes"},
        )

    start, end = byte_range if byte_range is not None else (0, total_size - 1)
    logging.info(f"handle_range_request--1 {range_header!r} -> {start}-{end}/{total_size}")

    content_length = end - start + 1
    is_full = start == 0 and end == total_size - 1

    # Slice a snapshot instead of seek()/read() on the shared BytesIO: concurrent
    # responses over the same buffer would otherwise move each other's position.
    payload = buffer.getvalue()[start:end + 1]

    def content_generator(chunk_size: int = 4096):
        """Yield the selected byte range in chunks of at most ``chunk_size``."""
        for offset in range(0, len(payload), chunk_size):
            yield payload[offset:offset + chunk_size]

    headers = {
        "Content-Length": str(content_length),
        "Accept-Ranges": "bytes",
        "Cache-Control": "public, max-age=3600",
    }
    if not is_full:
        headers["Content-Range"] = f"bytes {start}-{end}/{total_size}"

    return StreamingResponse(
        content_generator(),
        status_code=200 if is_full else 206,
        media_type=media_type,
        headers=headers,
    )
|
2025-05-15 15:26:06 +08:00
|
|
|
|
|
2025-07-10 22:04:44 +08:00
|
|
|
|
|
|
|
|
|
|
@tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}")
async def websocket_tts_endpoint(
        websocket: WebSocket,
        chat_id: str,
        audio_stream_id: str
):
    """Stream TTS audio (or proxied chat text) to a client over a WebSocket.

    The ``X-Tts-Type`` request header selects the source:

    * ``"AiChatTts"``: first sends ``{"command": "sample_rate", "params": <rate>}``,
      where the rate is the ``'sample_rate'`` entry of
      ``stream_manager.get_session(audio_stream_id)``.  It then forwards every
      chunk yielded by ``stream_manager.get_tts_buffer_data(audio_stream_id)``
      as a binary frame.
    * ``"AiChatText"``: waits for one JSON message from the client.  It hands
      that payload, together with the local
      ``/api/v1/chats/{chat_id}/completions`` URL, to
      ``proxy_aichat_text_stream``.
    * anything else (including a missing header): forwards the chunks yielded
      by ``tts_engine.get_audio_stream(audio_stream_id)`` as binary frames.

    Forwarding stops early when ``manager.send_bytes`` returns a falsy value.
    Afterwards, if ``manager.is_connected`` still reports the connection,
    ``{"status": "completed"}`` is sent.  The handler then sleeps 0.1 s and
    closes via ``manager.disconnect(..., code=1000)``.  ``WebSocketDisconnect``
    is only logged.  Any other exception is logged and, if still connected,
    sent to the client as ``{"error": <message>}``.

    NOTE(review): on the disconnect and error paths ``manager.disconnect`` is
    never called (the call in ``finally`` is commented out), so the connection
    registered by ``manager.connect`` may stay registered.  Confirm whether
    ``manager`` cleans up elsewhere.

    Args:
        websocket: the client connection, registered with ``manager`` under a
            fresh UUID.
        chat_id: chat whose completions endpoint is proxied in ``"AiChatText"`` mode.
        audio_stream_id: identifier of the audio stream / TTS session to forward.
    """
    # Read the custom request headers.
    headers = websocket.headers
    service_type = headers.get("x-tts-type") # note: header names are lower-cased
    # audio_url = headers.get("x-audio-url")
    # The string below is a no-op expression statement (not the docstring).
    # It is a front-end usage example (uni.connectSocket) showing how the custom headers are sent.
    # NOTE(review): the example's 'AiChat' value matches neither "AiChatTts" nor
    # "AiChatText", so it would take the default tts_engine branch below.
    """
    前端示例
    websocketConnection = uni.connectSocket({
        url: url,
        header: {
            'Authorization': token,
            'X-Tts-Type': 'AiChat', //'Ask' // 自定义参数1
            'X-Device-Type': 'mobile', // 自定义参数2
            'X-User-ID': '12345' // 自定义参数3
        },
        success: () => {
            console.log('WebSocket connected');
        },
        fail: (err) => {
            console.error('WebSocket connection failed:', err);
        }
    });
    """
    # Unique id under which this connection is registered with `manager`.
    connection_id = str(uuid.uuid4())
    # logging.info(f"---dale-- websocket connection_id = {connection_id} chat_id={chat_id}")
    await manager.connect(websocket, connection_id)

    # Set at the end of each branch below but never read in this function.
    completed_successfully = False
    try:
        # Route to a different data source depending on the X-Tts-Type header.
        if service_type == "AiChatTts":
            # Audio proxy: forward buffered TTS audio held by stream_manager.
            # audio_url is currently unused (the proxy call below is commented out).
            audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}"
            # await proxy_aichat_audio_stream(connection_id, audio_url)
            # Tell the client the sample rate before any audio bytes arrive.
            sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate')
            await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate})
            async for data in stream_manager.get_tts_buffer_data(audio_stream_id):
                # A falsy return from send_bytes ends forwarding.
                if not await manager.send_bytes(connection_id, data):
                    break
            completed_successfully = True

        elif service_type == "AiChatText":
            # Text proxy: relay an LLM chat completion.
            # Wait for the client's initial request payload; when proxying an LLM chat
            # the front end must send it right after connecting.
            payload = await websocket.receive_json()
            completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions"
            await proxy_aichat_text_stream(connection_id, completions_url, payload)
            completed_successfully = True
        else:
            # Pull the audio stream directly from the TTS engine's async generator.
            async for data in tts_engine.get_audio_stream(audio_stream_id):
                if not await manager.send_bytes(connection_id, data):
                    logging.warning(f"Send failed, connection closed: {connection_id}")
                    break

            completed_successfully = True

        # Check the connection is still registered before signalling completion.
        if manager.is_connected(connection_id):
            # Send the completion signal.
            await manager.send_json(connection_id, {"status": "completed"})

            # Short delay, intended to let the completion message reach the client.
            await asyncio.sleep(0.1)

            # Close the WebSocket from the server side.
            await manager.disconnect(connection_id, code=1000, reason="Audio stream completed")
    except WebSocketDisconnect:
        logging.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logging.error(f"WebSocket TTS error: {str(e)}")
        if manager.is_connected(connection_id):
            await manager.send_json(connection_id, {"error": str(e)})
    finally:
        pass
        # await manager.disconnect(connection_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
2025-05-15 15:26:06 +08:00
|
|
|
|
def cleanup_cache():
    """Evict entries older than ``CACHE_EXPIRE_SECONDS`` from ``audio_text_cache``.

    Each cache value must carry a ``"created_at"`` datetime.  Its age is
    measured against one ``datetime.datetime.now()`` reading taken inside the
    lock.  The whole scan-and-evict runs while holding ``cache_lock``, and every
    evicted value is logged just before it is removed.
    """
    with cache_lock:
        reference_time = datetime.datetime.now()

        # Collect first, delete afterwards: the dict must not change size mid-iteration.
        stale_keys = []
        for cache_key, entry in audio_text_cache.items():
            age_seconds = (reference_time - entry["created_at"]).total_seconds()
            if age_seconds > CACHE_EXPIRE_SECONDS:
                stale_keys.append(cache_key)

        for cache_key in stale_keys:
            logging.info(f"del audio_text_cache= {audio_text_cache[cache_key]}")
            del audio_text_cache[cache_key]
|
|
|
|
|
|
|
|
|
|
|
|
# 应用启动时启动清理线程
|
2025-07-10 22:04:44 +08:00
|
|
|
|
# start_background_cleaner()
|