# tts_service.py
# Source path: ragflow_python/asr-monitor-test/app/tts_service.py
# (original file: 1775 lines, 67 KiB)
import logging
import binascii
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta
import threading, time, queue, uuid, time, array
from threading import Lock, Thread
from concurrent.futures import ThreadPoolExecutor
import base64, gzip
import os, io, re, json
from io import BytesIO
from typing import Optional, Dict, Any
import asyncio, httpx
from collections import deque
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi.responses import StreamingResponse, JSONResponse, Response
# Default audio parameters for the TTS pipeline.
TTS_SAMPLERATE = 44100 # alternatives tried: 22050, 16000
FORMAT = "mp3"
CHANNELS = 1 # mono
SAMPLE_WIDTH = 2 # 16-bit = 2 bytes per sample
tts_router = APIRouter()
# logger = logging.getLogger(__name__)
class MillisecondsFormatter(logging.Formatter):
    """Log formatter rendering timestamps as HH:MM:SS.mmm (local time)."""

    def formatTime(self, record, datefmt=None):
        # Convert the record's creation time to a local struct_time,
        # format hours/minutes/seconds, then append 3-digit milliseconds.
        local_time = self.converter(record.created)
        base = time.strftime("%H:%M:%S", local_time)
        return "%s.%03d" % (base, int(record.msecs))
def configure_logging():
    """Install a console handler with millisecond timestamps on the root logger.

    Clears any previously registered handlers so log lines are not duplicated
    when this module is re-imported.
    """
    formatter = MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")
    root = logging.getLogger()
    root.handlers = []
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    root.setLevel(logging.INFO)
    root.addHandler(console)
# Apply the logging configuration once at import time.
configure_logging()
class StreamSessionManager:
    """Per-session manager for streaming TTS.

    Each session owns a text buffer, a task queue, an audio output queue and
    a background worker thread that splits incoming text into sentences and
    forwards them to the TTS backend.  Audio chunks produced by the backend
    are pushed into the session's ``buffer`` queue and drained asynchronously
    by ``get_tts_buffer_data``.
    """

    def __init__(self):
        # session_id -> {'tts_model': obj, 'buffer': queue, 'task_queue': Queue, ...}
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size worker pool
        self.gc_interval = 300  # reap idle sessions after 5 minutes
        self.streaming_call_timeout = 15  # seconds of inactivity before ending a streaming call
        self.gc_tts = 3  # 3s (unused here; kept for callers elsewhere)
        self.sentence_timeout = 1.5  # flush buffered text after 1.5s without new input
        self.sentence_endings = set('。?!;.?!;')  # CJK + ASCII sentence terminators
        # Matches runs of sentence-ending punctuation (full-width included),
        # optionally followed by a closing quote, when not followed by more
        # punctuation of the same class.
        self.sentence_pattern = re.compile(
            r'([,,。?!;.?!;!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;!;…])'
        )

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', session_id=None, streaming_call=False):
        """Register a new TTS session and start its worker thread.

        Returns the session id (a fresh UUID when not supplied).
        """
        if not session_id:
            session_id = str(uuid.uuid4())
        with self.lock:
            tts_instance = tts_model

            # Audio chunks produced by the TTS backend land here.
            def on_data(data: bytes):
                if data:
                    try:
                        self.sessions[session_id]['last_active'] = time.time()
                        self.sessions[session_id]['buffer'].put(data)
                    except queue.Full:
                        logging.warning(f"Audio buffer full for session {session_id}")

            # Wire the streaming callback into the TTS model.
            tts_instance.setup_tts(on_data)
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio output queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # completion event
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                "text_buffer": "",  # accumulated, not-yet-synthesized text
                "last_text_time": time.time(),  # when text last arrived
                "streaming_call": streaming_call,
                "tts_stream_started": False  # whether the stream has been started
            }
        # One worker thread per session.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue *text* for synthesis on an existing session (no-op if absent)."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session: return
            # Accumulate text and remember arrival time for timeout flushing.
            session['text_buffer'] += text
            session['last_text_time'] = time.time()
            # Also enqueue on the task queue (non-blocking).
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Worker loop (one per session).

        Splits buffered text into sentences, batches them (<= 300 chars) and
        forwards each batch to the TTS backend until the session times out or
        is closed.
        """
        session = self.sessions.get(session_id)
        if not session or not session['active']:
            return
        gen_tts_audio_func = self._generate_audio
        if session.get('streaming_call'):
            gen_tts_audio_func = self._stream_audio
        while session['active']:
            current_time = time.time()
            text_to_process = ""
            # Grab the whole buffer under the lock.
            with self.lock:
                if session['text_buffer']:
                    text_to_process = session['text_buffer']
                    session['text_buffer'] = ""  # clear the buffer
            if text_to_process:
                # Split off complete sentences.
                complete_sentences, remaining_text = self._split_and_extract(text_to_process)
                # Push any incomplete tail back to the front of the buffer.
                if remaining_text:
                    with self.lock:
                        session['text_buffer'] = remaining_text + session['text_buffer']
                # Merge complete sentences into batches and synthesize.
                if complete_sentences:
                    # Batch sentences up to 300 characters per TTS call.
                    buffer = []
                    current_length = 0
                    for sentence in complete_sentences:
                        sent_length = len(sentence)
                        if current_length + sent_length <= 300:
                            buffer.append(sentence)
                            current_length += sent_length
                        else:
                            # Flush the current batch before starting a new one.
                            if buffer:
                                gen_tts_audio_func(session_id, "".join(buffer))
                            buffer = [sentence]
                            current_length = sent_length
                    # Flush the final batch.
                    if buffer:
                        gen_tts_audio_func(session_id, "".join(buffer))
            # Flush text that sat in the buffer past the sentence timeout.
            if current_time - session['last_text_time'] > self.sentence_timeout:
                with self.lock:
                    if session['text_buffer']:
                        gen_tts_audio_func(session_id, session['text_buffer'])
                        session['text_buffer'] = ""
            if current_time - session['last_active'] > self.streaming_call_timeout:
                if session.get('streaming_call'):
                    session['tts_model'].end_streaming_call()
                    session['streaming_call'] = False
            # Session-level garbage collection after prolonged inactivity.
            if current_time - session['last_active'] > self.gc_interval:
                with self.lock:
                    if session['text_buffer']:
                        gen_tts_audio_func(session_id, session['text_buffer'])
                        session['text_buffer'] = ""
                self.close_session(session_id)
                break
            # Sleep to avoid busy-waiting.
            time.sleep(0.05)  # 50ms poll interval

    def _generate_audio(self, session_id, text):
        """Synthesize *text* via a one-shot TTS call (non-streaming mode)."""
        session = self.sessions.get(session_id)
        if not session: return
        logging.info(f"_generate_audio:{text}")
        first_chunk = True
        try:
            """
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav':
                    if first_chunk:
                        chunk_len = len(chunk)
                        if chunk_len > 2048:
                            session['buffer'].put(audio_fade_in(chunk, 1024))
                        else:
                            session['buffer'].put(audio_fade_in(chunk, chunk_len))
                        first_chunk = False
                    else:
                        session['buffer'].put(chunk)
                else:
                    session['buffer'].put(chunk)
            """
            session['tts_model'].text_tts_call(text)
            session['last_active'] = time.time()
            session['audio_chunk_count'] = session['audio_chunk_count'] + 1
            if session['tts_chunk_data_valid'] is False:
                # First successful backend response: frontend can be notified.
                session['tts_chunk_data_valid'] = True
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")

    def _stream_audio(self, session_id, text):
        """Forward *text* to the TTS backend in streaming mode."""
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            # Push the text fragment into the open streaming call.
            session['tts_model'].streaming_call(text)
            session['last_active'] = time.time()
        except Exception as e:
            logging.error(f"Error in streaming_call: {str(e)}")
            session['buffer'].put(f"ERROR:{str(e)}".encode())

    async def get_tts_buffer_data(self, session_id):
        """Async generator yielding audio chunks from the session buffer.

        Bridges the synchronous ``queue.Queue`` via ``run_in_executor``;
        gives up after 10 seconds without any new data.
        """
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        buffer = session['buffer']  # a synchronous queue.Queue
        last_data_time = time.time()  # when data last arrived
        while session['active']:
            try:
                # Block in a worker thread; bail out after 10 seconds.
                data = await asyncio.wait_for(
                    asyncio.get_event_loop().run_in_executor(None, buffer.get),
                    timeout=10.0
                )
                last_data_time = time.time()
                yield data
            except asyncio.TimeoutError:
                # No new data within 10s: stop only if truly idle.
                if time.time() - last_data_time >= 10.0:
                    break
                else:
                    continue
            except asyncio.CancelledError:
                logging.info(f"Session {session_id} stream cancelled")
                break
            except Exception as e:
                logging.error(f"Error in get_tts_buffer_data: {e}")
                break

    def close_session(self, session_id):
        """Mark a session inactive and schedule resource cleanup."""
        with self.lock:
            if session_id in self.sessions:
                try:
                    # if self.sessions[session_id].get('streaming_call'):
                    #     self.sessions[session_id]['tts_model'].end_streaming_call()
                    logging.info(f"Ended streaming for session {session_id}")
                except Exception as e:
                    logging.error(f"Error ending streaming call: {str(e)}")
                # Mark inactive so the worker loop and readers stop.
                self.sessions[session_id]['active'] = False
                # Deferred teardown (timer delay is 1 second).
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Final teardown: end any streaming call and drop the session."""
        with self.lock:
            if session_id in self.sessions:
                try:
                    self.sessions[session_id]['tts_model'].end_streaming_call()
                except:
                    pass  # best-effort close
                del self.sessions[session_id]

    def get_session(self, session_id):
        # Unlocked read; callers treat the result as a snapshot.
        return self.sessions.get(session_id)

    def _has_sentence_ending(self, text):
        """Return True when *text* appears to end with a sentence terminator."""
        if not text:
            return False
        # Look for terminators among the last 3 characters (full-width included).
        if any(char in self.sentence_endings for char in text[-3:]):
            return True
        # NOTE(review): this iterates whole *lines* of the split, so the
        # membership test only succeeds for single-character lines — confirm
        # whether per-character checking was intended.
        if '\n' in text and any(char in self.sentence_endings for char in text.split('\n')[-2:-1]):
            return True
        return False

    def _split_and_extract(self, text):
        """
        Enhanced sentence splitter.

        Returns: (list of complete sentences, remaining incomplete text).
        """
        # NOTE(review): the second tuple element is an empty string (likely a
        # full-width comma lost to an encoding round-trip); str.startswith("")
        # is always True, which makes this guard fire for every input —
        # verify against version control.
        if text.startswith((",", "")):
            return [text[0]], text[1:]
        # 1. Find all candidate sentence-ending positions.
        matches = list(self.sentence_pattern.finditer(text))
        if not matches:
            return [], text  # no terminator found
        # 2. Walk the matches, collecting complete sentences.
        last_end = 0
        complete_sentences = []
        for match in matches:
            end_pos = match.end()
            sentence = text[last_end:end_pos].strip()
            # Skip empty fragments.
            if not sentence:
                last_end = end_pos
                continue
            # Accept when long enough or containing a hard terminator.
            if len(sentence) > 6 or any(char in "。.?!!" for char in sentence):
                complete_sentences.append(sentence)
                last_end = end_pos
            else:
                # Short fragment: accept only with a hard terminator.
                if any(char in "。.?!!" for char in sentence):
                    complete_sentences.append(sentence)
                    last_end = end_pos
        # 3. Whatever follows the last accepted sentence is the remainder.
        remaining_text = text[last_end:].strip()
        return complete_sentences, remaining_text
stream_manager = StreamSessionManager()
def allowed_file(filename):
    """Return True when *filename* carries an accepted image extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {'png', 'jpg', 'jpeg', 'gif'}
# Audio/text cache shared across handlers.  NOTE(review): these three names
# are re-declared again further down this module; the later bindings replace
# these objects.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600 # entries expire after 10 minutes
# WebSocket 连接管理
class ConnectionManager:
    """Tracks active WebSocket connections and provides crash-safe send
    helpers that clean up dead connections automatically."""

    def __init__(self):
        # connection_id -> WebSocket
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept the socket and register it under *connection_id*."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logging.info(f"新连接建立: {connection_id}")

    async def disconnect(self, connection_id: str, code=1000, reason: str = ""):
        """Close and forget a connection; close errors are swallowed.

        Fix: the original used a bare ``except:`` which also swallowed
        KeyboardInterrupt/SystemExit; narrowed to ``except Exception``.
        """
        if connection_id in self.active_connections:
            try:
                # Best-effort close; the peer may already be gone.
                await self.active_connections[connection_id].close(code=code, reason=reason)
            except Exception:
                pass  # ignore close errors
            finally:
                del self.active_connections[connection_id]

    def is_connected(self, connection_id: str) -> bool:
        """Return True while *connection_id* is registered."""
        return connection_id in self.active_connections

    async def _safe_send(self, connection_id: str, send_func, *args):
        """Shared safe-send path: verifies the connection exists and is in
        CONNECTED state, performs the send, and drops the connection on any
        failure.  Returns True on success, False otherwise."""
        # 1. Unknown connection id.
        if connection_id not in self.active_connections:
            logging.warning(f"尝试向不存在的连接发送数据: {connection_id}")
            return False
        websocket = self.active_connections[connection_id]
        try:
            # 2. Connection must be in CONNECTED state before sending.
            if websocket.client_state.name != "CONNECTED":
                logging.warning(f"连接 {connection_id} 状态为 {websocket.client_state.name}")
                await self.disconnect(connection_id)
                return False
            # 3. Perform the actual send.
            await send_func(websocket, *args)
            return True
        except (WebSocketDisconnect, RuntimeError) as e:
            # 4. Peer disconnected mid-send.
            logging.info(f"发送时检测到断开连接: {connection_id}, {str(e)}")
            await self.disconnect(connection_id)
            return False
        except Exception as e:
            # 5. Any other failure: log and drop the connection.
            logging.error(f"发送数据出错: {connection_id}, {str(e)}")
            await self.disconnect(connection_id)
            return False

    async def send_bytes(self, connection_id: str, data: bytes):
        """Safely send binary data."""
        return await self._safe_send(
            connection_id,
            lambda ws, d: ws.send_bytes(d),
            data
        )

    async def send_text(self, connection_id: str, message: str):
        """Safely send text data."""
        return await self._safe_send(
            connection_id,
            lambda ws, m: ws.send_text(m),
            message
        )

    async def send_json(self, connection_id: str, data: dict):
        """Safely send JSON data."""
        return await self._safe_send(
            connection_id,
            lambda ws, d: ws.send_json(d),
            data
        )
manager = ConnectionManager()
def generate_mp3_header(
    sample_rate: int,
    bitrate_kbps: int,
    channels: int = 1,
    layer: str = "III"  # supports "I" / "II" / "III"
) -> bytes:
    """
    Dynamically build a 4-byte MP3 frame header for Layer I/II/III.

    :param sample_rate: sample rate (8000, 16000, 22050, 44100, 48000)
    :param bitrate_kbps: bitrate in kbps
    :param channels: channel count (1: mono, 2: stereo)
    :param layer: MPEG layer ("I", "II", "III")
    :return: the 4 header bytes, big-endian
    :raises ValueError: on unsupported sample rate / layer / bitrate / channels

    Fix: 48000 Hz was accepted by the validation set below but the
    version-selection chain had no branch for it and raised; MPEG-1 uses
    sample-rate index 0b01 for 48 kHz, added below.
    """
    # ----------------------------------
    # Parameter validation
    # ----------------------------------
    valid_sample_rates = {8000, 16000, 22050, 44100, 48000}
    if sample_rate not in valid_sample_rates:
        raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}")
    valid_layers = {"I", "II", "III"}
    if layer not in valid_layers:
        raise ValueError(f"不支持的层,可选:{valid_layers}")
    # ----------------------------------
    # MPEG version and sample-rate index
    # ----------------------------------
    if sample_rate == 44100:
        mpeg_version = 0b11  # MPEG-1
        sample_rate_index = 0b00
    elif sample_rate == 48000:
        # FIX: previously fell through to the error branch below.
        mpeg_version = 0b11  # MPEG-1
        sample_rate_index = 0b01
    elif sample_rate == 22050:
        mpeg_version = 0b10  # MPEG-2
        sample_rate_index = 0b00
    elif sample_rate == 16000:
        mpeg_version = 0b10  # MPEG-2
        sample_rate_index = 0b10
    elif sample_rate == 8000:
        mpeg_version = 0b00  # MPEG-2.5
        sample_rate_index = 0b10
    else:
        raise ValueError("采样率与版本不匹配")
    # ----------------------------------
    # Bitrate table selection per layer
    # ----------------------------------
    # Layer code mapping: I: 0b11, II: 0b10, III: 0b01
    layer_code = {
        "I": 0b11,
        "II": 0b10,
        "III": 0b01
    }[layer]
    # Bitrate tables keyed by (mpeg_version, layer_code).
    # NOTE(review): several index values here do not match the canonical
    # ISO/IEC 11172-3 bitrate index tables — confirm against the consumer of
    # these headers before "correcting" them; behavior is preserved as-is.
    bitrate_tables = {
        # -------------------------------
        # MPEG-1 (0b11)
        # -------------------------------
        # Layer I
        (0b11, 0b11): {
            32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011,
            160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111,
            288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011,
            416: 0b1100, 448: 0b1101
        },
        # Layer II
        (0b11, 0b10): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011,
            320: 0b1100, 384: 0b1101
        },
        # Layer III
        (0b11, 0b01): {
            32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011,
            64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111,
            128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011,
            256: 0b0100, 320: 0b0101
        },
        # -------------------------------
        # MPEG-2 (0b10) / MPEG-2.5 (0b00)
        # -------------------------------
        # Layer I
        (0b10, 0b11): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
            224: 0b1100, 256: 0b1101
        },
        (0b00, 0b11): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
            224: 0b1100, 256: 0b1101
        },
        # Layer II
        (0b10, 0b10): {
            8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
            40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
            80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
            144: 0b1100, 160: 0b1101
        },
        (0b00, 0b10): {
            8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
            40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
            80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
            144: 0b1100, 160: 0b1101
        },
        # Layer III
        (0b10, 0b01): {
            8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
            40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111,
            80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011,
            144: 0b0100, 160: 0b0101
        },
        (0b00, 0b01): {
            8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
            40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111
        }
    }
    # Select the table for this version/layer combination.
    key = (mpeg_version, layer_code)
    if key not in bitrate_tables:
        raise ValueError(f"不支持的版本和层组合: MPEG={mpeg_version}, Layer={layer}")
    bitrate_table = bitrate_tables[key]
    if bitrate_kbps not in bitrate_table:
        raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}")
    bitrate_index = bitrate_table[bitrate_kbps]
    # ----------------------------------
    # Channel mode
    # ----------------------------------
    if channels == 1:
        channel_mode = 0b11  # mono
    elif channels == 2:
        channel_mode = 0b00  # stereo
    else:
        raise ValueError("声道数必须为1或2")
    # ----------------------------------
    # Assemble the header fields
    # ----------------------------------
    sync = 0x7FF << 21  # 11-bit sync word (0b11111111111)
    version = mpeg_version << 19  # 2-bit MPEG version
    layer_bits = layer_code << 17  # layer code: I:0b11, II:0b10, III:0b01
    # NOTE(review): in the MP3 spec, protection_bit == 1 means "no CRC";
    # 0 announces a 16-bit CRC after the header. Kept at 0 to preserve
    # existing output — confirm intended.
    protection = 0 << 16
    bitrate_bits = bitrate_index << 12
    sample_rate_bits = sample_rate_index << 10
    padding = 0 << 9  # no padding
    private = 0 << 8
    mode = channel_mode << 6
    mode_ext = 0 << 4  # mode extension (unused for mono)
    copyright = 0 << 3
    original = 0 << 2
    emphasis = 0b00  # no emphasis
    frame_header = (
        sync |
        version |
        layer_bits |
        protection |
        bitrate_bits |
        sample_rate_bits |
        padding |
        private |
        mode |
        mode_ext |
        copyright |
        original |
        emphasis
    )
    return frame_header.to_bytes(4, byteorder='big')
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in over the first *fade_length* samples.

    *audio_data* is assumed to be 16-bit mono little-endian PCM — TODO
    confirm against callers (the disabled chunk loop above passes WAV data).

    Fix: *fade_length* is now clamped to the number of available samples,
    so a fade longer than the chunk no longer raises IndexError.

    :param audio_data: raw PCM bytes (even length; 2 bytes per sample)
    :param fade_length: number of leading samples to ramp from 0 to full
    :return: the faded PCM as bytes
    """
    # Interpret the bytes as signed 16-bit samples.
    samples = array.array('h', audio_data)
    # Clamp so short chunks cannot overflow the sample buffer.
    fade_length = min(fade_length, len(samples))
    for i in range(fade_length):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)
    return samples.tobytes()
def parse_markdown_json(json_string):
    """Extract and parse a ```json fenced code block from a markdown string.

    Returns {'success': True, 'data': parsed} on success, otherwise
    {'success': False, 'data': error-description}.
    """
    fence = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if not fence:
        return {'success': False, 'data': 'not a valid markdown json string'}
    try:
        parsed = json.loads(fence[1])
    except json.JSONDecodeError as err:
        return {'success': False, 'data': str(err)}
    return {'success': True, 'data': parsed}
# NOTE(review): duplicate re-declaration of the cache globals defined earlier
# in this module; these bindings replace the previous objects.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600 # entries expire after 10 minutes
# 全角字符到半角字符的映射
def fullwidth_to_halfwidth(s):
    """Convert full-width (zenkaku) characters in *s* to half-width ASCII.

    NOTE(review): the original literal mapping table was mojibake-damaged —
    its full-width keys were all collapsed to the empty string by an encoding
    round-trip, leaving a dict of duplicate '' keys.  This rebuilds the
    standard mapping: U+FF01..U+FF5E map to ASCII 0x21..0x7E by a fixed
    offset, U+3000 (ideographic space) maps to a space, plus the common CJK
    punctuation the original table appeared to cover.  Confirm against
    version control that no extra mappings were lost.
    """
    # CJK punctuation outside the full-width ASCII block.
    cjk_punct = {
        '\u3002': '.',   # 。 ideographic full stop
        '\u3001': ',',   # 、 ideographic comma
        '\u201c': '"',   # left double quote
        '\u201d': '"',   # right double quote
        '\u2018': "'",   # left single quote
        '\u2019': "'",   # right single quote
        '\u2014': '-',   # em dash
        '\u2026': '.',   # ellipsis
    }
    converted = []
    for char in s:
        code = ord(char)
        if 0xFF01 <= code <= 0xFF5E:
            # Full-width ASCII block: shift down by 0xFEE0.
            converted.append(chr(code - 0xFEE0))
        elif code == 0x3000:
            converted.append(' ')  # ideographic space
        else:
            converted.append(cjk_punct.get(char, char))
    return ''.join(converted)
def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* on punctuation/whitespace and regroup the resulting
    tokens into space-joined chunks no longer than *chunk_size* characters.

    :param text: input text
    :param chunk_size: soft maximum length of each returned chunk
    :return: list of chunk strings (never containing leading/trailing spaces)
    """
    separator_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    # Tokenize, dropping empty fragments produced by leading/trailing separators.
    words = [w for w in re.split(separator_pattern, text) if w]
    chunks = []
    pending = ''
    for word in words:
        if len(pending) + len(word) <= chunk_size:
            # Room left in the current chunk: append with a trailing space.
            pending += (word + ' ')
        else:
            # Current chunk is full: emit it and start a new one.
            chunks.append(pending.strip())
            pending = word + ' '
    # Emit the final partial chunk, if any.
    if pending:
        chunks.append(pending.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    """Strip markdown syntax from *markdown_text*, normalize full-width
    punctuation and collapse whitespace, producing plain text for TTS.

    The removal order matters and is preserved from the original: links are
    stripped before images, so image syntax loses only its bracketed part
    and leaves the leading '!' behind.
    """
    removal_patterns = (
        r'#\s*[^#]+',                         # headings (and following text up to next '#')
        r'`[^`]+`',                           # inline code
        r'```[\s\S]*?```',                    # fenced code blocks
        r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})',   # bold/italic (content removed too)
        r'\[.*?\]\(.*?\)',                    # links
        r'!\[.*?\]\(.*?\)',                   # images
        r'<[^>]+>',                           # HTML tags
    )
    cleaned = markdown_text
    for pattern in removal_patterns:
        cleaned = re.sub(pattern, '', cleaned)
    # Normalize full-width punctuation to half-width.
    cleaned = fullwidth_to_halfwidth(cleaned)
    # Collapse runs of whitespace into single spaces.
    return re.sub(r'\s+', ' ', cleaned).strip()
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress *original_data* and return it Base64-encoded.

    The b64encode default produces no line breaks, matching Android's
    Base64.NO_WRAP on the receiving side.
    """
    # Step 1: gzip-compress into an in-memory buffer.
    with BytesIO() as scratch:
        with gzip.GzipFile(fileobj=scratch, mode='wb') as gz_stream:
            gz_stream.write(original_data)
        compressed = scratch.getvalue()
    # Step 2: Base64-encode (no wrapping).
    return base64.b64encode(compressed).decode('utf-8')
def clean_audio_cache():
    """Evict cache entries older than CACHE_EXPIRE_SECONDS, closing any
    audio stream object each expired entry holds."""
    with cache_lock:
        cutoff = time.time() - CACHE_EXPIRE_SECONDS
        stale_keys = [key for key, entry in audio_text_cache.items()
                      if entry['created_at'] < cutoff]
        for key in stale_keys:
            removed = audio_text_cache.pop(key, None)
            if removed and removed.get('audio_stream'):
                removed['audio_stream'].close()
def start_background_cleaner():
    """Spawn a daemon thread that purges expired audio-cache entries
    every 3 minutes."""
    def _loop():
        while True:
            time.sleep(180)  # run every 3 minutes
            clean_audio_cache()

    Thread(target=_loop, daemon=True).start()
def test_qwen_chat():
    """Manual smoke-test helper: sends a fixed prompt to Qwen via DashScope
    and prints the reply (network I/O; not an automated test).

    Fix: ``Generation`` was referenced without ever being imported, so this
    helper raised NameError on call.  Imported lazily here, matching the
    lazy ``import dashscope`` style used elsewhere in this module.
    """
    from dashscope import Generation  # lazy third-party import
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx",
        api_key=ALI_KEY,
        model="qwen-plus",  # 模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )
    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码:{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import (
ResultCallback as TTSResultCallback,
SpeechSynthesizer as TTSSpeechSynthesizer,
SpeechSynthesisResult as TTSSpeechSynthesisResult,
)
# cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本
from dashscope.audio.tts_v2 import (
ResultCallback as CosyResultCallback,
SpeechSynthesizer as CosySpeechSynthesizer,
AudioFormat,
)
class QwenTTS:
    """Wrapper around DashScope TTS.

    Supports both the legacy ``tts`` SpeechSynthesizer API and the
    ``tts_v2`` CosyVoice API.  A ``model_name`` of the form
    "cosyvoice-v1/<voice>[@provider]" selects the CosyVoice engine with the
    given voice; anything else uses the legacy API.
    """

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun"):
        # Lazy third-party imports keep module import cost low.
        import dashscope
        import ssl
        logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx
        # Strip any "@provider" suffix from the model name.
        self.model_name = model_name.split('@')[0]
        dashscope.api_key = key
        # NOTE(review): disables TLS certificate verification process-wide.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        self.first_chunk = True
        if '/' in self.model_name:
            parts = self.model_name.split('/', 1)
            # "cosyvoice-v1/<voice>" selects the CosyVoice engine and voice.
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]

    class Callback(TTSResultCallback):
        """Callback for the legacy TTS API; buffers audio frames in a deque.
        A None entry marks end-of-stream for the _run() generator."""

        def __init__(self) -> None:
            self.dque = deque()

        def _run(self):
            # Generator draining the deque; terminates on a None entry.
            while True:
                if not self.dque:
                    time.sleep(0)
                    continue
                val = self.dque.popleft()
                if val:
                    yield val
                else:
                    break

        def on_open(self):
            pass

        def on_complete(self):
            # Sentinel signalling end-of-stream to _run().
            self.dque.append(None)

        def on_error(self, response: SpeechSynthesisResponse):
            print("Qwen tts error", str(response))
            raise RuntimeError(str(response))

        def on_close(self):
            pass

        def on_event(self, result: TTSSpeechSynthesisResult):
            if result.get_audio_frame() is not None:
                self.dque.append(result.get_audio_frame())

    # --------------------------
    class Callback_Cosy(CosyResultCallback):
        """Callback for CosyVoice (tts_v2).

        Audio bytes are forwarded to *data_callback* when provided
        (streaming mode), otherwise buffered in an internal deque and
        drained via _run().
        """

        def __init__(self, data_callback=None) -> None:
            self.dque = deque()
            self.data_callback = data_callback

        def _run(self):
            # Generator draining the deque; terminates on a None entry.
            while True:
                if not self.dque:
                    time.sleep(0)
                    continue
                val = self.dque.popleft()
                if val:
                    yield val
                else:
                    break

        def on_open(self):
            logging.info("Qwen CosyVoice tts open ")
            pass

        def on_complete(self):
            self.dque.append(None)
            if self.data_callback:
                self.data_callback(None)  # end-of-stream signal

        def on_error(self, response: SpeechSynthesisResponse):
            print("Qwen tts error", str(response))
            if self.data_callback:
                self.data_callback(f"ERROR:{str(response)}".encode())
            raise RuntimeError(str(response))

        def on_close(self):
            logging.info("Qwen CosyVoice tts close")
            pass

        """ canceled for test — CosyVoice large speech model
        def on_event(self, result: SpeechSynthesisResult):
            if result.get_audio_frame() is not None:
                self.dque.append(result.get_audio_frame())
        """
        def on_event(self, message):
            # Non-audio events are ignored.
            pass

        # CosyVoice delivers raw audio bytes through this hook.
        def on_data(self, data: bytes) -> None:
            if len(data) > 0:
                if self.data_callback:
                    self.data_callback(data)
                else:
                    self.dque.append(data)

    # --------------------------
    def tts(self, text):
        """One-shot synthesis generator: yields audio chunks for *text*."""
        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
        # text = self.normalize_text(text)
        try:
            # Legacy API path for non-CosyVoice models.
            if self.is_cosyvoice is False:
                self.callback = self.Callback()
                TTSSpeechSynthesizer.call(model=self.model_name,
                                          text=text,
                                          callback=self.callback,
                                          format="wav")  # format="mp3")
            else:
                self.callback = self.Callback_Cosy()
                format = self.get_audio_format(self.format, self.sample_rate)
                self.synthesizer = CosySpeechSynthesizer(
                    model='cosyvoice-v1',
                    # voice="longyuan", #"longfei",
                    voice=self.voice,
                    callback=self.callback,
                    format=format
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}") # cyx
        # -----------------------------------
        try:
            # Drain whatever audio the callback buffered.
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def setup_tts(self, on_data):
        """Configure a streaming synthesizer delivering audio to *on_data*;
        returns the configured synthesizer.  CosyVoice models only."""
        if not self.is_cosyvoice:
            raise NotImplementedError("Only CosyVoice supported")
        self.callback = self.Callback_Cosy(on_data)
        format_val = self.get_audio_format(self.format, self.sample_rate)
        logging.info(f"setup_tts {self.voice} {format_val}")
        self.synthesizer = CosySpeechSynthesizer(
            model='cosyvoice-v1',
            voice=self.voice,  # voice="longyuan", #"longfei",
            callback=self.callback,
            format=format_val
        )
        return self.synthesizer

    def text_tts_call(self, text):
        # One-shot call on the previously configured synthesizer.
        if self.synthesizer:
            self.synthesizer.call(text)

    def streaming_call(self, text):
        # Incremental text push in streaming mode.
        if self.synthesizer:
            self.synthesizer.streaming_call(text)

    def end_streaming_call(self):
        # Signal end of the streamed text so the backend flushes audio.
        if self.synthesizer:
            self.synthesizer.streaming_complete()

    def get_audio_format(self, format: str, sample_rate: int):
        """Map (sample_rate, format) to a DashScope AudioFormat constant;
        falls back to MP3 16 kHz / 128 kbps for unknown combinations."""
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
import threading
import uuid
import time
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
import threading
import uuid
import time
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from collections import deque
from io import BytesIO
class UnifiedTTSEngine:
    """Task-based TTS engine.

    Each request becomes a task whose audio chunks accumulate in a deque and
    can be streamed back via ``get_audio_stream``.  Tasks expire after
    ``cache_expire`` seconds and are reaped by a recurring daemon timer.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.tasks = {}  # audio_stream_id -> task dict
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.cache_expire = 300  # tasks live for 5 minutes
        # Recurring cleanup timer (daemon thread).
        self.cleanup_timer = None
        self.start_cleanup_timer()

    def _cleanup_old_tasks(self):
        """Drop tasks older than cache_expire (acquires self.lock)."""
        now = time.time()
        with self.lock:
            expired_ids = [task_id for task_id, task in self.tasks.items()
                           if now - task['created_at'] > self.cache_expire]
            for task_id in expired_ids:
                self._remove_task(task_id)

    def _remove_task(self, task_id):
        """Remove a task and cancel its background future if still pending.

        Expected to be called with self.lock held (see _cleanup_old_tasks).
        Remaining resources (queues, buffers) are reclaimed by GC once the
        task dict is unreferenced; pool threads return to the executor.
        """
        if task_id in self.tasks:
            task = self.tasks.pop(task_id)
            if 'future' in task and not task['future'].done():
                task['future'].cancel()

    def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False):
        """Create a TTS task and return its audio-stream id.

        When *delay_gen_audio* is False, synthesis starts immediately (and
        _start_tts_task blocks until it finishes); otherwise it starts on
        the first get_audio_stream() call.
        """
        self._cleanup_old_tasks()
        audio_stream_id = str(uuid.uuid4())
        task_data = {
            'id': audio_stream_id,
            'text': text,
            'format': format,
            'sample_rate': sample_rate,
            'model_name': model_name,
            'key': key,
            'delay_gen_audio': delay_gen_audio,
            'created_at': time.time(),
            'status': 'pending',
            'data_queue': deque(),       # audio chunks from the TTS callback
            'event': threading.Event(),  # set when synthesis completes/fails
            'completed': False,
            'error': None
        }
        with self.lock:
            self.tasks[audio_stream_id] = task_data
        # Non-delayed tasks start right away.
        if not delay_gen_audio:
            self._start_tts_task(audio_stream_id)
        return audio_stream_id

    def _start_tts_task(self, audio_stream_id):
        """Submit the task to the thread pool; for non-delayed tasks, block
        until synthesis finishes and merge the audio into one buffer.

        Fix: the original handler caught ``concurrent.futures.TimeoutError``
        but the file only does ``from concurrent.futures import
        ThreadPoolExecutor``, so the except clause itself raised NameError;
        the exception is imported locally here.
        """
        from concurrent.futures import TimeoutError as FuturesTimeoutError
        task = self.tasks.get(audio_stream_id)
        if not task or task['status'] != 'pending':
            return
        # Fix: original log call was missing the f-prefix and printed the
        # placeholder literally.
        logging.info(f"已经启动 start tts task {audio_stream_id}")
        task['status'] = 'processing'
        # Run the blocking synthesis on the pool.
        future = self.executor.submit(self._run_tts_sync, audio_stream_id)
        task['future'] = future
        # Non-delayed mode: wait for completion, then merge the audio.
        if not task.get('delay_gen_audio', True):
            try:
                # Wait up to 5 minutes for the TTS run.
                future.result(timeout=300)
                logging.info(f"TTS任务 {audio_stream_id} 已完成")
                self._merge_audio_data(audio_stream_id)
            except FuturesTimeoutError:
                task['error'] = "TTS生成超时"
                task['completed'] = True
                logging.error(f"TTS任务 {audio_stream_id} 超时")
            except Exception as e:
                task['error'] = f"ERROR:{str(e)}"
                task['completed'] = True
                logging.exception(f"TTS任务执行异常: {str(e)}")

    def _run_tts_sync(self, audio_stream_id):
        """Blocking TTS generation; runs on the executor thread."""
        task = self.tasks.get(audio_stream_id)
        if not task:
            return
        try:
            # Build the TTS backend for this task.
            tts = QwenTTS(
                key=task['key'],
                format=task['format'],
                sample_rate=task['sample_rate'],
                model_name=task['model_name']
            )

            def data_handler(data):
                # None => finished; b"ERROR..." => failure; else audio chunk.
                if data is None:
                    task['completed'] = True
                    task['event'].set()
                    logging.info(f"--data_handler on_complete")
                elif data.startswith(b"ERROR"):
                    task['error'] = data.decode()
                    task['completed'] = True
                    task['event'].set()
                else:
                    task['data_queue'].append(data)

            # Wire the callback and run the synthesis call.
            synthesizer = tts.setup_tts(data_handler)
            synthesizer.call(task['text'])
            # Wait for completion (5-minute safety timeout).
            if not task['event'].wait(timeout=300):
                task['error'] = "TTS generation timeout"
                task['completed'] = True
            logging.info(f"--tts task event set error = {task['error']}")
        except Exception as e:
            task['error'] = f"ERROR:{str(e)}"
            task['completed'] = True

    def _merge_audio_data(self, audio_stream_id):
        """Concatenate a completed task's audio chunks into a BytesIO buffer
        stored on the task (the chunk queue is cleared afterwards)."""
        task = self.tasks.get(audio_stream_id)
        if not task or not task.get('completed'):
            return
        try:
            logging.info(f"开始合并音频数据: {audio_stream_id}")
            buffer = io.BytesIO()
            for data_chunk in task['data_queue']:
                buffer.write(data_chunk)
            # Rewind so readers start from the beginning.
            buffer.seek(0)
            task['buffer'] = buffer
            logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节")
            # Free the raw chunks now that they are merged.
            task['data_queue'].clear()
        except Exception as e:
            logging.error(f"合并音频数据失败: {str(e)}")
            task['error'] = f"合并错误: {str(e)}"

    async def get_audio_stream(self, audio_stream_id):
        """Async generator yielding audio chunks for the given task.

        Starts delayed tasks on first use; raises RuntimeError if the task
        is unknown or finished with an error.
        """
        task = self.tasks.get(audio_stream_id)
        if not task:
            raise RuntimeError("Audio stream not found")
        # Delayed task not yet started: kick it off now.
        if task['delay_gen_audio'] and task['status'] == 'pending':
            self._start_tts_task(audio_stream_id)
            # Wait until the worker picks it up.
            while task['status'] == 'pending':
                await asyncio.sleep(0.1)
        # Stream chunks until the task completes and the queue drains.
        while not task['completed'] or task['data_queue']:
            while task['data_queue']:
                data = task['data_queue'].popleft()
                yield data
            # Brief pause while waiting for more data.
            await asyncio.sleep(0.05)
        # Surface any failure to the consumer.
        if task['error']:
            raise RuntimeError(task['error'])

    def start_cleanup_timer(self):
        """(Re)arm the recurring 30-second cleanup timer."""
        if self.cleanup_timer:
            self.cleanup_timer.cancel()
        self.cleanup_timer = threading.Timer(30.0, self.cleanup_task)
        self.cleanup_timer.daemon = True  # don't block interpreter exit
        self.cleanup_timer.start()

    def cleanup_task(self):
        """Timer body: purge expired tasks, then re-arm the timer."""
        try:
            self._cleanup_old_tasks()
        except Exception as e:
            logging.error(f"清理任务时出错: {str(e)}")
        finally:
            self.start_cleanup_timer()
# 全局 TTS 引擎实例
tts_engine = UnifiedTTSEngine()
def replace_domain(url: str) -> str:
    """Rewrite known external hosts in *url* to http://localhost:9380
    using plain string manipulation (no urllib.parse).

    Unknown URLs with a scheme have their host swapped for the local base;
    bare paths are simply prefixed with it.
    """
    known_hosts = (
        "http://1.13.185.116:9380",
        "https://ragflow.szzysztech.com",
        "1.13.185.116:9380",
        "ragflow.szzysztech.com",
    )
    # Direct substitution for the hosts we know about (first match wins).
    for host in known_hosts:
        if host in url:
            return url.replace(host, "http://localhost:9380", 1)
    if "://" not in url:
        # No scheme at all: treat the whole string as a path.
        return f"http://localhost:9380/{url}"
    # Generic fallback: strip scheme+host, keep the path.
    _, remainder = url.split("://", 1)
    first_slash = remainder.find("/")
    if first_slash > 0:
        return f"http://localhost:9380{remainder[first_slash:]}"
    # Host only, no path component.
    return "http://localhost:9380"
async def proxy_aichat_audio_stream(client_id: str, audio_url: str):
    """Stream an upstream audio URL to the client's WebSocket, chunk by chunk.

    Forwarding stops as soon as a send fails (client disconnected).  On any
    other failure a JSON error frame is pushed to the client instead of
    raising to the caller.
    """
    try:
        # NOTE(review): despite the original comment, the URL is forwarded
        # unchanged here (replace_domain is not applied) — confirm intent.
        local_url = audio_url
        logging.info(f"代理音频流: {audio_url} -> {local_url}")
        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream("GET", local_url) as response:
                # Relay the body incrementally to the WebSocket client.
                async for chunk in response.aiter_bytes():
                    delivered = await manager.send_bytes(client_id, chunk)
                    if not delivered:
                        logging.warning(f"Audio proxy interrupted for {client_id}")
                        return
    except Exception as e:
        logging.error(f"Audio proxy failed: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"音频流获取失败: {str(e)}"
        }))
# Proxy helper - text stream
async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict):
    """Proxy the LLM completion SSE stream to the client's WebSocket while
    feeding incremental answer text into a streaming TTS session.

    Compatible with the pre-existing Flask implementation.  The first SSE
    event forwarded to the client is augmented with ``audio_stream_url`` so
    the frontend can open the matching TTS audio stream.
    """
    try:
        logging.info(f"代理文本流: completions_url={completions_url} {payload}")
        logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}")
        headers = {
            "Content-Type": "application/json",
            'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm'
        }
        # Build the TTS instance that will synthesize the reply audio.
        tts_model = QwenTTS(
            key=ALI_KEY,
            format=payload.get('tts_stream_format', 'mp3'),
            sample_rate=payload.get('tts_sample_rate', 48000),
            model_name=payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
        )
        # Create the streaming TTS session that will receive delta text.
        tts_stream_session_id = stream_manager.create_session(
            tts_model=tts_model,
            sample_rate=payload.get('tts_sample_rate', 48000),
            stream_format=payload.get('tts_stream_format', 'mp3'),
            session_id=None,
            streaming_call=True
        )
        # Tracks whether audio_stream_url has been injected into an event yet.
        tts_stream_session_id_sent = False
        # Long timeout: 5 minutes total, 60 s to connect.
        timeout = httpx.Timeout(300.0, connect=60.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            # Key point: use streaming request mode.
            async with client.stream(
                "POST",
                completions_url,
                json=payload,
                headers=headers
            ) as response:
                logging.info(f"响应状态: HTTP {response.status_code}")
                if response.status_code != 200:
                    # Read the error body (non-streaming) and report it.
                    error_content = await response.aread()
                    error_msg = f"后端错误: HTTP {response.status_code}"
                    error_msg += f" - {error_content[:200].decode()}" if error_content else ""
                    await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg}))
                    return
                # Verify the backend actually returned an SSE stream.
                content_type = response.headers.get("content-type", "").lower()
                if "text/event-stream" not in content_type:
                    logging.warning("非流式响应,转发完整内容")
                    full_content = await response.aread()
                    await manager.send_text(client_id, json.dumps({
                        "type": "text",
                        "data": full_content.decode('utf-8')
                    }))
                    return
                logging.info("开始处理SSE流")
                event_count = 0
                # Process the SSE stream line by line via the async iterator.
                async for line in response.aiter_lines():
                    # Skip blank lines and SSE comment lines.
                    if not line or line.startswith(':'):
                        continue
                    # Handle SSE data events.
                    if line.startswith("data:"):
                        data_str = line[5:].strip()
                        if data_str:  # ignore empty payloads
                            try:
                                # Parse and extract the incremental answer text.
                                data_obj = json.loads(data_str)
                                delta_text = None
                                if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict):
                                    delta_text = data_obj.get('data', None).get('delta_ans', "")
                                    if tts_stream_session_id_sent is False:
                                        # Inject the TTS stream URL into the first event only.
                                        data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}"
                                        data_str = json.dumps(data_obj)
                                        tts_stream_session_id_sent = True
                                # Forward the (possibly augmented) raw event.
                                await manager.send_text(client_id, json.dumps({
                                    "type": "text",
                                    "data": data_str
                                }))
                                # The {"type":"text","data":...} envelope is what the
                                # frontend WebSocket client expects to parse.
                                if delta_text:
                                    # Feed the delta text into the TTS session.
                                    stream_manager.append_text(tts_stream_session_id, delta_text)
                                event_count += 1
                            except Exception as e:
                                logging.error(f"事件发送失败: {str(e)}")
                    # Yield briefly to keep the loop cooperative (avoid CPU spin).
                    await asyncio.sleep(0.001)
                logging.info(f"SSE流处理完成事件数: {event_count}")
                # Signal end of stream to the client.
                await manager.send_text(client_id, json.dumps({"type": "end"}))
    except httpx.ReadTimeout:
        logging.error("读取后端服务超时")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": "后端服务响应超时"
        }))
    except httpx.ConnectError as e:
        logging.error(f"连接后端服务失败: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"无法连接到后端服务: {str(e)}"
        }))
    except Exception as e:
        logging.exception(f"文本代理失败: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"文本流获取失败: {str(e)}"
        }))
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Stream the local test file ``./test.mp3`` in 1 KiB chunks."""
    def audio_generator():
        # Yield the file incrementally; errors are logged, not raised,
        # so a broken read just ends the stream.
        try:
            with open('./test.mp3', 'rb') as fh:
                for chunk in iter(lambda: fh.read(1024), b''):
                    yield chunk
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")
    return StreamingResponse(
        audio_generator(),
        media_type="audio/mpeg",
        headers={
            "Cache-Control": "no-store",
            "Accept-Ranges": "bytes"
        }
    )
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Wrap raw PCM bytes in a WAV container (header + frames).

    Args:
        pcm_data: Raw 16-bit mono little-endian PCM samples.
        sample_rate: Sample rate in Hz to record in the header.

    Returns:
        The complete WAV file as bytes.
    """
    # Fix: `wave` is used here but never imported at module level, which
    # made this function raise NameError at runtime.
    import wave
    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono, matching module CHANNELS
            wav_file.setsampwidth(2)   # 16-bit, matching module SAMPLE_WIDTH
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer.read()
def generate_silence_header(duration_ms: int = 500, sample_rate: int = None,
                            sample_width: int = None, channels: int = None) -> bytes:
    """Generate silent PCM padding used to pre-buffer streamed MP3 audio.

    Args:
        duration_ms: Length of silence in milliseconds.
        sample_rate: Samples per second; defaults to module TTS_SAMPLERATE.
        sample_width: Bytes per sample; defaults to module SAMPLE_WIDTH.
        channels: Channel count; defaults to module CHANNELS.

    Returns:
        Zero-filled bytes of the computed length.
    """
    # Generalized: the hard-coded module constants are now overridable
    # keyword defaults; the original zero-argument behavior is unchanged.
    if sample_rate is None:
        sample_rate = TTS_SAMPLERATE
    if sample_width is None:
        sample_width = SAMPLE_WIDTH
    if channels is None:
        channels = CHANNELS
    num_samples = int(sample_rate * duration_ms / 1000)
    return b'\x00' * (num_samples * sample_width * channels)
# ------------------------ API routes ------------------------
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Validate a TTS request body and create an audio-generation task.

    Returns JSON with the stream/WebSocket URLs for fetching the audio.

    Raises:
        HTTPException: 400 for invalid parameters, 500 for internal errors.
    """
    # Fix: HTTPException is not imported in the module header; import locally.
    from fastapi import HTTPException
    try:
        data = await request.json()
        logging.info(f"Creating TTS request: {data}")
        # Parameter validation.
        text = data.get("text", "").strip()
        if not text:
            raise HTTPException(400, detail="Text cannot be empty")
        format = data.get("tts_stream_format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="Unsupported audio format")
        sample_rate = data.get("tts_sample_rate", 48000)
        if sample_rate not in [8000, 16000, 22050, 44100, 48000]:
            raise HTTPException(400, detail="Unsupported sample rate")
        model_name = data.get("model_name", "cosyvoice-v1/longxiaochun")
        delay_gen_audio = data.get('delay_gen_audio', False)
        # Create the TTS task.
        audio_stream_id = tts_engine.create_tts_task(
            text=text,
            format=format,
            sample_rate=sample_rate,
            model_name=model_name,
            key=ALI_KEY,
            delay_gen_audio=delay_gen_audio
        )
        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}",  # WebSocket URL, added 2025-06-22
                "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat()
            }
        )
    except HTTPException:
        # Fix: re-raise validation errors so clients receive the intended
        # 4xx; previously the generic handler below rewrapped them as 500.
        raise
    except Exception as e:
        logging.error(f"Request failed: {str(e)}")
        raise HTTPException(500, detail="Internal server error")
# Module-level thread pool with the default worker count.
# NOTE(review): not referenced anywhere in this file's visible code —
# confirm whether it is still needed.
executor = ThreadPoolExecutor()
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
    chat_id: str,
    audio_stream_id: str,
    range: str = Header(None)
):
    """Serve a TTS task's audio: either replay the completed buffer (with
    optional HTTP Range support) or stream chunks live while generation runs.

    NOTE(review): `HTTPException` is raised below but does not appear to be
    imported in the visible module header — confirm.
    """
    try:
        # Look up the task record.
        task = tts_engine.tasks.get(audio_stream_id)
        if not task:
            # Return a friendly error payload instead of raising.
            return JSONResponse(
                status_code=404,
                content={
                    "error": "Audio stream not found",
                    "message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired",
                    "suggestion": "Please create a new TTS request and try again"
                }
            )
        # Map the task's format to an HTTP media type.
        format = task['format']
        media_type = {
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "pcm": f"audio/L16; rate={task['sample_rate']}; channels=1"
        }[format]
        # If the task finished with a full buffer, Range requests are possible.
        logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}")
        # Generator that replays a completed buffer in 4 KiB chunks.
        def buffer_read(buffer):
            content_length = buffer.getbuffer().nbytes
            remaining = content_length
            chunk_size = 4096
            buffer.seek(0)
            while remaining > 0:
                read_size = min(remaining, chunk_size)
                data = buffer.read(read_size)
                if not data:
                    break
                yield data
                remaining -= len(data)
        if task.get('completed') and task.get('buffer') is not None:
            buffer = task['buffer']
            total_size = buffer.getbuffer().nbytes
            # Force small files (< 120 KB) down the plain streaming path to
            # avoid problems with 206 partial responses.
            if total_size < 1024 * 120:
                range = None
            if range:
                # Honor the HTTP Range request against the full buffer.
                return handle_range_request(range, buffer, total_size, media_type)
            else:
                return StreamingResponse(
                    buffer_read(buffer),
                    media_type=media_type,
                    headers={
                        "Accept-Ranges": "bytes",
                        "Cache-Control": "no-store",
                        "Transfer-Encoding": "chunked"
                    })
        # Task still generating: stream chunks as they are produced.
        logging.info("tts_engine.get_audio_stream--0")
        return StreamingResponse(
            tts_engine.get_audio_stream(audio_stream_id),
            media_type=media_type,
            headers={
                "Accept-Ranges": "bytes",
                "Cache-Control": "no-store",
                "Transfer-Encoding": "chunked"
            }
        )
    except Exception as e:
        logging.error(f"Audio streaming failed: {str(e)}")
        raise HTTPException(500, detail="Audio generation error")
def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str):
    """Serve an HTTP Range request (RFC 7233) over a completed audio buffer.

    Args:
        range_header: Raw ``Range`` header value, e.g. ``"bytes=0-1023"``.
        buffer: In-memory buffer holding the full audio payload.
        total_size: Total payload size in bytes.
        media_type: Content type for the response.

    Returns:
        StreamingResponse with status 206 (or 200 when the full range
        was requested).

    Raises:
        HTTPException: 416 for an unsatisfiable range, 500 for parse errors.
    """
    # Fix: HTTPException is used below but never imported at module level.
    from fastapi import HTTPException
    try:
        # Parse the header: "bytes=0-1023", "bytes=500-" or "bytes=-500".
        range_type, range_spec = range_header.split('=')
        if range_type != 'bytes':
            raise ValueError("Unsupported range type")
        start_str, end_str = range_spec.split('-')
        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        else:
            # Fix: suffix ranges ("bytes=-N" = last N bytes) previously
            # crashed on int('') and surfaced as a 500.
            suffix_len = int(end_str)
            start = max(total_size - suffix_len, 0)
            end = total_size - 1
        logging.info(f"handle_range_request--1 {start_str}-{end_str} {end}")
        # Validate the requested window.
        if start >= total_size or end >= total_size:
            raise HTTPException(status_code=416, headers={
                "Content-Range": f"bytes */{total_size}"
            })
        content_length = end - start + 1
        # 206 for a partial window, 200 when the whole payload was requested.
        status_code = 206
        if start == 0 and end == total_size - 1:
            status_code = 200
        # Position the buffer at the window start.
        buffer.seek(start)
        def content_generator():
            # Stream the requested window in 4 KiB chunks.
            remaining = content_length
            chunk_size = 4096
            while remaining > 0:
                read_size = min(remaining, chunk_size)
                data = buffer.read(read_size)
                if not data:
                    break
                yield data
                remaining -= len(data)
        return StreamingResponse(
            content_generator(),
            status_code=status_code,
            media_type=media_type,
            headers={
                "Content-Range": f"bytes {start}-{end}/{total_size}",
                "Content-Length": str(content_length),
                "Accept-Ranges": "bytes",
                "Cache-Control": "public, max-age=3600"
            }
        )
    except HTTPException:
        # Fix: let deliberate HTTP errors (the 416 above) propagate instead
        # of being rewrapped as 500 by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}")
async def websocket_tts_endpoint(
    websocket: WebSocket,
    chat_id: str,
    audio_stream_id: str
):
    """WebSocket endpoint that pushes TTS audio (or proxied chat data) to one
    client, dispatching on the ``X-Tts-Type`` request header:

    - ``AiChatTts``: replay buffered TTS audio for an existing stream session.
    - ``AiChatText``: proxy the chat completion SSE stream (client must send
      the request payload right after connecting).
    - anything else: stream audio straight from the TTS engine task.
    """
    # Read the routing parameter from the handshake headers.
    headers = websocket.headers
    service_type = headers.get("x-tts-type")  # note: header names arrive lower-cased
    # audio_url = headers.get("x-audio-url")
    """
    前端示例
    websocketConnection = uni.connectSocket({
        url: url,
        header: {
            'Authorization': token,
            'X-Tts-Type': 'AiChat', //'Ask' // 自定义参数1
            'X-Device-Type': 'mobile', // 自定义参数2
            'X-User-ID': '12345' // 自定义参数3
        },
        success: () => {
            console.log('WebSocket connected');
        },
        fail: (err) => {
            console.error('WebSocket connection failed:', err);
        }
    });
    """
    # Unique per-connection id used by the connection manager.
    connection_id = str(uuid.uuid4())
    await manager.connect(websocket, connection_id)
    # NOTE(review): completed_successfully is assigned but never read in this
    # function — confirm whether it can be removed.
    completed_successfully = False
    try:
        # Route to the appropriate audio/text source by service type.
        if service_type == "AiChatTts":
            # Audio proxy service: replay the buffered TTS stream session.
            audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}"
            # await proxy_aichat_audio_stream(connection_id, audio_url)
            sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate')
            await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate})
            async for data in stream_manager.get_tts_buffer_data(audio_stream_id):
                if not await manager.send_bytes(connection_id, data):
                    break
            completed_successfully = True
        elif service_type == "AiChatText":
            # Text proxy service: the frontend must send the request payload
            # after connecting, before the LLM completion proxy starts.
            payload = await websocket.receive_json()
            completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions"
            await proxy_aichat_text_stream(connection_id, completions_url, payload)
            completed_successfully = True
        else:
            # Default: stream audio directly from the engine's generator.
            async for data in tts_engine.get_audio_stream(audio_stream_id):
                if not await manager.send_bytes(connection_id, data):
                    logging.warning(f"Send failed, connection closed: {connection_id}")
                    break
            completed_successfully = True
        # Only signal completion if the socket is still open.
        if manager.is_connected(connection_id):
            # Send the completion frame.
            await manager.send_json(connection_id, {"status": "completed"})
            # Short delay so the frame is flushed before closing.
            await asyncio.sleep(0.1)
            # Proactively close the WebSocket from the server side.
            await manager.disconnect(connection_id, code=1000, reason="Audio stream completed")
    except WebSocketDisconnect:
        logging.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logging.error(f"WebSocket TTS error: {str(e)}")
        if manager.is_connected(connection_id):
            await manager.send_json(connection_id, {"error": str(e)})
    finally:
        pass
        # await manager.disconnect(connection_id)
def cleanup_cache():
    """Drop cache entries older than CACHE_EXPIRE_SECONDS (under the lock)."""
    with cache_lock:
        # Collect expired keys first so the dict is never mutated while
        # it is being iterated.
        check_time = datetime.datetime.now()
        stale_keys = []
        for cache_key, entry in audio_text_cache.items():
            age_seconds = (check_time - entry["created_at"]).total_seconds()
            if age_seconds > CACHE_EXPIRE_SECONDS:
                stale_keys.append(cache_key)
        for cache_key in stale_keys:
            logging.info(f"del audio_text_cache= {audio_text_cache[cache_key]}")
            del audio_text_cache[cache_key]
# Start the background cleanup thread at application startup
# start_background_cleaner()