# --- Residue from a web file viewer (commented out so this file parses) ---
# File: ragflow_python/asr-monitor-test/bk/tts_service.py (2329 lines, 91 KiB, Python)
# The viewer warned that this file contains Unicode characters that can be
# confused with other characters; several string literals below were garbled
# (fullwidth punctuation reduced to empty strings) when the file was copied.
import logging
import binascii
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta
import threading, time, queue, uuid, time, array
from threading import Lock, Thread
from concurrent.futures import ThreadPoolExecutor
import base64, gzip
import os, io, re, json
from io import BytesIO
from typing import Optional, Dict, Any
import asyncio, httpx
from collections import deque
import websockets
import uuid
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi.responses import StreamingResponse, JSONResponse, Response
# Audio output defaults for the TTS pipeline.
TTS_SAMPLERATE = 44100 # 22050 # 16000
FORMAT = "mp3"
CHANNELS = 1  # mono
SAMPLE_WIDTH = 2  # 16-bit = 2 bytes per sample
tts_router = APIRouter()
# logger = logging.getLogger(__name__)
class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose timestamps look like ``HH:MM:SS.mmm``."""

    def formatTime(self, record, datefmt=None):
        """Render the record's creation time with millisecond precision.

        *datefmt* is accepted for interface compatibility but ignored.
        """
        local_time = self.converter(record.created)
        base = time.strftime("%H:%M:%S", local_time)
        millis = int(record.msecs)
        return f"{base}.{millis:03d}"
# Global logging configuration.
def configure_logging():
    """Install a console handler with millisecond timestamps on the root logger."""
    formatter = MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")
    root = logging.getLogger()
    # Drop any handlers configured earlier so output is not duplicated.
    root.handlers = []
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    root.setLevel(logging.INFO)
    root.addHandler(console)
# Run once at import time.
configure_logging()
class StreamSessionManager:
    """Manages streaming TTS sessions.

    Each session owns a TTS model instance, a thread-safe audio buffer, a text
    task queue and a dedicated worker thread (``_process_tasks``) that splits
    incoming text into sentences and synthesizes them in order.  Finished
    sentence audio is also kept in ``sentence_audio_store`` so it can be
    re-fetched by id.  A daemon thread periodically garbage-collects idle
    sessions and expired sentence audio.

    NOTE(review): this block was reconstructed from a paste with stripped
    indentation; nesting was inferred from the control flow.
    """

    def __init__(self):
        # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}}
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
        self.gc_interval = 300  # sweep every 5 minutes
        self.streaming_call_timeout = 10  # 10 s
        self.gc_tts = 3  # 3 s
        self.sentence_timeout = 2  # flush buffered text after 2000 ms of silence
        self.sentence_endings = set('。?!;.?!;')  # CJK + ASCII sentence terminators
        # Matches runs of sentence terminators (fullwidth included), optionally
        # followed by a closing quote.  NOTE(review): some characters in this
        # literal appear garbled by an encoding/copy step — verify them.
        self.sentence_pattern = re.compile(
            r'([,,。?!;.?!;!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;!;…])'
        )
        # {sentence_id: {'data': bytes, 'text': str, 'created_at': float}}
        self.sentence_audio_store = {}
        self.sentence_lock = threading.Lock()
        threading.Thread(target=self._cleanup_expired, daemon=True).start()

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', session_id=None, streaming_call=False):
        """Create a session, wire the TTS callbacks and start its worker thread.

        Returns the (possibly generated) session id.
        """
        if not session_id:
            session_id = str(uuid.uuid4())
        with self.lock:
            # Reuse the caller-supplied TTS instance; set up streaming callbacks on it.
            tts_instance = tts_model

            # Audio-chunk callback: push synthesized bytes into the session buffer.
            def on_data(data: bytes):
                if data:
                    try:
                        self.sessions[session_id]['last_active'] = time.time()
                        self.sessions[session_id]['buffer'].put(data)
                        self.sessions[session_id]['audio_chunk_size'] += len(data)
                        #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
                    except queue.Full:
                        logging.warning(f"Audio buffer full for session {session_id}")
                # Disabled end-of-stream handling kept for reference (dead string):
                """
                elif data is None: # 结束信号
                # 仅对非流式引擎触发完成事件
                if not streaming_call:
                logging.info(f"StreamSessionManager on_data sentence_complete_event set")
                self.sessions[session_id]['sentence_complete_event'].set()
                self.sessions[session_id]['current_processing'] = False
                """

            # Event used to signal that the current sentence finished synthesizing.
            completion_event = threading.Event()
            # Bind the data callback and completion event to the TTS engine.
            tts_instance.setup_tts(on_data, completion_event)
            # Session state record.
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio chunk queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'audio_chunk_size': 0,
                'finished': threading.Event(),  # completion event object
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                "text_buffer": "",  # pending text awaiting sentence splitting
                "last_text_time": time.time(),  # when text last arrived
                "streaming_call": streaming_call,
                "tts_stream_started": False,  # whether the stream has started
                "current_processing": False,  # whether a sentence is being synthesized
                "sentence_complete_event": completion_event,  # threading.Event()
                'sentences': [],  # ordered sentence-id list
                'current_sentence_index': 0,
                'gc_interval': 300,  # idle timeout; shortened by finish_text_input()
            }
        # Start the per-session worker thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Append *text* to the session's buffer and enqueue it for processing."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session: return
            # Extend the text buffer and refresh the arrival timestamp.
            session['text_buffer'] += text
            session['last_text_time'] = time.time()
            # Non-blocking enqueue into the task queue.
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def finish_text_input(self, session_id):
        """Mark text input as finished so the worker times the session out sooner."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            # All text has arrived; shorten the idle-timeout window.
            session['gc_interval'] = 100

    def _process_tasks(self, session_id):  # 20250718 updated
        """Per-session worker loop: split buffered text into sentences and synthesize them."""
        session = self.sessions.get(session_id)
        if not session or not session['active']:
            return
        # Choose the synthesis function by engine type.
        # NOTE(review): both branches currently select _generate_audio; the
        # streaming path (_stream_audio) is disabled.
        if session.get('streaming_call'):
            gen_tts_audio_func = self._generate_audio  # self._stream_audio
        else:
            gen_tts_audio_func = self._generate_audio
        while session['active']:
            current_time = time.time()
            text_to_process = ""
            # 1. Snapshot pending text under the lock.
            with self.lock:
                if session['text_buffer']:
                    text_to_process = session['text_buffer']
            # 2. Process the snapshot if no sentence is currently in flight.
            if text_to_process and not session['current_processing']:
                session['text_buffer'] = ""
                # Split off complete sentences.
                complete_sentences, remaining_text = self._split_and_extract(text_to_process)
                # Put the incomplete tail back in front of any newly arrived text.
                if remaining_text:
                    with self.lock:
                        session['text_buffer'] = remaining_text + session['text_buffer']
                # Merge and synthesize the complete sentences.
                if complete_sentences:
                    # Greedily merge sentences up to 300 characters per call.
                    buffer = []
                    current_length = 0
                    for sentence in complete_sentences:
                        sent_length = len(sentence)
                        if current_length + sent_length <= 300:
                            buffer.append(sentence)
                            current_length += sent_length
                        else:
                            # Flush the merged batch before starting a new one.
                            if buffer:
                                combined_text = "".join(buffer)
                                # Reset the completion event before synthesis.
                                session['sentence_complete_event'].clear()
                                session['current_processing'] = True
                                # Synthesize (blocks until the engine finishes).
                                gen_tts_audio_func(session_id, combined_text)
                                # Wait for completion.
                                if not session['sentence_complete_event'].wait(timeout=120.0):
                                    logging.warning(f"Timeout waiting for TTS completion: {combined_text}")
                                # Reset the processing flag.
                                time.sleep(5.0)
                                session['current_processing'] = False
                                logging.info(f"StreamSessionManager _process_tasks 转换结束!!!")
                            # Start a new batch with the sentence that overflowed.
                            buffer = [sentence]
                            current_length = sent_length
                    # Flush the final batch.
                    if buffer:
                        combined_text = "".join(buffer)
                        session['current_processing'] = True
                        # Synchronous synthesis; blocks until audio is produced.
                        gen_tts_audio_func(session_id, combined_text)
                        # Reset the processing flag.
                        time.sleep(1.0)
                        session['current_processing'] = False
            # 3. Flush text that has been idle longer than sentence_timeout.
            if current_time - session['last_text_time'] > self.sentence_timeout:
                with self.lock:
                    if session['text_buffer']:
                        # Synthesize the leftover text directly.
                        session['current_processing'] = True
                        gen_tts_audio_func(session_id, session['text_buffer'])
                        session['text_buffer'] = ""
                        session['current_processing'] = False
            # 4. Session idle timeout check.
            if current_time - session['last_active'] > session['gc_interval']:
                # Flush any remaining text before closing.
                with self.lock:
                    if session['text_buffer']:
                        session['current_processing'] = True
                        gen_tts_audio_func(session_id, session['text_buffer'])
                        session['text_buffer'] = ""
                        session['current_processing'] = False
                # Close the session (outside the lock: close_session re-acquires it).
                logging.info(f"--_process_tasks-- timeout {session['last_active']} {session['gc_interval']}")
                self.close_session(session_id)
                break
            # 5. Sleep to avoid busy-waiting.
            time.sleep(0.05)  # 50 ms poll interval

    def _generate_audio(self, session_id, text):  # 20250718 updated
        """Synthesize *text* sequentially (non-streaming engines) and store the result."""
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            #logging.info(f"StreamSessionManager _generate_audio--0 {text}")
            # Per-sentence in-memory stream.
            audio_stream = io.BytesIO()

            # Callback that only collects bytes for the sentence store.
            # NOTE(review): defined but not wired below — sentence audio is saved
            # from audio_stream, which this callback would have filled.
            def on_data_sentence(data: bytes):
                if data:
                    audio_stream.write(data)

            # Callback that forwards chunks to the session's live buffer.
            def on_data_whole(data: bytes):
                if data:
                    try:
                        session['last_active'] = time.time()
                        session['buffer'].put({'type': 'arraybuffer', 'data': data})
                        #session['buffer'].put(data)
                        session['audio_chunk_size'] += len(data)
                        #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
                    except queue.Full:
                        logging.warning(f"Audio buffer full for session {session_id}")

            # Reset the completion event and rebind callbacks for this sentence.
            session['sentence_complete_event'].clear()
            session['tts_model'].setup_tts(on_data=on_data_whole, completion_event=session['sentence_complete_event'])
            # Invoke the TTS engine.
            session['tts_model'].text_tts_call(text)
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1
            # Wait for the sentence to finish (30 s timeout).
            if not session['sentence_complete_event'].wait(timeout=30):
                logging.warning(f"Timeout generating audio for: {text[:20]}...")
            logging.info(f"StreamSessionManager _generate_audio 转换结束!!!")
            session['buffer'].put({'type': 'sentence_end', 'data': ""})
            # Collect the synthesized bytes.
            audio_data = audio_stream.getvalue()
            audio_stream.close()
            # Persist to the per-sentence store.
            self.add_sentence_audio(session_id, text, audio_data)
            if not session['tts_chunk_data_valid']:
                session['tts_chunk_data_valid'] = True
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}".encode())
            session['sentence_complete_event'].set()  # make sure waiters wake up

    def _stream_audio(self, session_id, text):
        """Feed *text* to a streaming TTS engine (no wait for completion)."""
        session = self.sessions.get(session_id)
        if not session:
            return
        # logging.info(f"Streaming text to TTS: {text}")
        try:
            # Reset the completion event.
            session['sentence_complete_event'].clear()
            # Incremental streaming call.
            session['tts_model'].streaming_call(text)
            session['last_active'] = time.time()
            # Streaming engines don't need to wait for completion.
            session['sentence_complete_event'].set()
        except Exception as e:
            logging.error(f"Error in streaming_call: {str(e)}")
            session['buffer'].put(f"ERROR:{str(e)}".encode())
            session['sentence_complete_event'].set()

    async def get_tts_buffer_data(self, session_id):
        """Async generator yielding TTS audio chunks from the session's sync queue.

        Stops after 10 seconds without new data, on cancellation, or on error.
        """
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")
        buffer = session['buffer']  # a synchronous queue.Queue
        last_data_time = time.time()  # time of the most recent chunk
        while session['active']:
            try:
                # Bridge the blocking queue.get into asyncio with a 10 s timeout.
                data = await asyncio.wait_for(
                    asyncio.get_event_loop().run_in_executor(None, buffer.get),
                    timeout=10.0
                )
                #logging.info(f"get_tts_buffer_data {data}")
                last_data_time = time.time()
                yield data
            except asyncio.TimeoutError:
                # No data for 10 s: give up; otherwise keep waiting.
                if time.time() - last_data_time >= 10.0:
                    break
                else:
                    continue
            except asyncio.CancelledError:
                logging.info(f"Session {session_id} stream cancelled")
                break
            except Exception as e:
                logging.error(f"Error in get_tts_buffer_data: {e}")
                break

    def close_session(self, session_id):
        """Deactivate a session, release TTS resources and schedule cleanup."""
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                session['active'] = False
                # Collect this session's sentence ids.
                # NOTE(review): expired_sentences is computed but never used —
                # the sentence audio is only reclaimed later by _cleanup_expired.
                with self.sentence_lock:
                    expired_sentences = [
                        sid for sid, data in self.sentence_audio_store.items()
                        if data.get('session_id') == session_id
                    ]
                # Wake any waiters.
                session['sentence_complete_event'].set()
                # Release streaming TTS resources (best effort).
                try:
                    if session.get('streaming_call'):
                        session['tts_model'].end_streaming_call()
                except:
                    pass
                # Remove the session shortly after, off this thread.
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        # Final removal of the session record.
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def _cleanup_expired(self):
        """Daemon loop: periodically reclaim idle sessions and stale sentence audio."""
        while True:
            time.sleep(30)
            now = time.time()
            # Reclaim sessions idle longer than gc_interval.
            with self.lock:
                expired_sessions = [
                    sid for sid, session in self.sessions.items()
                    if now - session['last_active'] > self.gc_interval
                ]
            for sid in expired_sessions:
                self._clean_session(sid)
            # Reclaim stale sentence audio.
            with self.sentence_lock:
                expired_sentences = [
                    sid for sid, data in self.sentence_audio_store.items()
                    if now - data['created_at'] > self.gc_interval
                ]
                for sid in expired_sentences:
                    del self.sentence_audio_store[sid]
            logging.info(f"清理资源: {len(expired_sessions)}会话, {len(expired_sentences)}句子")

    def get_session(self, session_id):
        # Raw session lookup (no lock; read-only access).
        return self.sessions.get(session_id)

    def _has_sentence_ending(self, text):
        """Heuristic: does *text* appear to end a sentence?"""
        if not text:
            return False
        # Any terminator among the last three characters.
        if any(char in self.sentence_endings for char in text[-3:]):
            return True
        # NOTE(review): the generator below iterates over a one-element LIST of
        # line strings, so `char` is a whole line — membership in the single-char
        # terminator set only matches one-character lines.  Confirm intent.
        if '\n' in text and any(char in self.sentence_endings for char in text.split('\n')[-2:-1]):
            return True
        return False

    def _split_and_extract(self, text):
        """Enhanced sentence splitter.

        Returns ``(complete_sentences, remaining_text)``.
        """
        # Special case for a leading comma.
        # NOTE(review): the second tuple element is an empty string (likely a
        # garbled fullwidth comma ','); as written startswith(("," , "")) is
        # ALWAYS true, so every input takes this branch — restore the character.
        if text.startswith((",", "")):
            return [text[0]], text[1:]
        # 1. Find all candidate sentence boundaries.
        matches = list(self.sentence_pattern.finditer(text))
        if not matches:
            return [], text  # no terminator found
        # 2. Walk the boundaries, collecting valid sentences.
        last_end = 0
        complete_sentences = []
        for match in matches:
            end_pos = match.end()
            sentence = text[last_end:end_pos].strip()
            # Skip empty fragments.
            if not sentence:
                last_end = end_pos
                continue
            # Accept if long enough or containing a strong terminator (and early in the text).
            if (len(sentence) > 6 or any(char in "。.?!!" for char in sentence)) and (last_end < 24):
                complete_sentences.append(sentence)
                last_end = end_pos
            else:
                # Short fragment: accept only if it carries a strong terminator.
                if any(char in "。.?!!" for char in sentence):
                    complete_sentences.append(sentence)
                    last_end = end_pos
        # 3. Whatever follows the last accepted boundary stays buffered.
        remaining_text = text[last_end:].strip()
        return complete_sentences, remaining_text

    def add_sentence_audio(self, session_id, sentence_text, audio_data: bytes):
        """Store one sentence's audio; returns its new sentence id (or None)."""
        with self.lock:
            if session_id not in self.sessions:
                return None
            # Unique id for this sentence.
            sentence_id = str(uuid.uuid4())
            # Persist the audio payload.
            with self.sentence_lock:
                self.sentence_audio_store[sentence_id] = {
                    'data': audio_data,
                    'text': sentence_text,
                    'created_at': time.time(),
                    'session_id': session_id,
                    'format': self.sessions[session_id]['stream_format']
                }
                logging.info(f" StreamSessionManager add_sentence_audio")
            # Record the sentence order on the session.
            self.sessions[session_id]['sentences'].append(sentence_id)
            return sentence_id

    def get_sentence_audio(self, sentence_id):
        """Return the raw audio bytes for *sentence_id*, or None."""
        with self.sentence_lock:
            if sentence_id not in self.sentence_audio_store:
                return None
            return self.sentence_audio_store[sentence_id]['data']

    def get_sentence_info(self, sentence_id):
        """Return the full stored record for *sentence_id*, or None."""
        with self.sentence_lock:
            if sentence_id not in self.sentence_audio_store:
                return None
            return self.sentence_audio_store[sentence_id]

    def get_next_sentence(self, session_id):
        """Advance the session's cursor and return the next sentence descriptor."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                return None
            if session['current_sentence_index'] < len(session['sentences']):
                sentence_id = session['sentences'][session['current_sentence_index']]
                session['current_sentence_index'] += 1
                return {
                    'id': sentence_id,
                    'url': f"/tts_sentence/{sentence_id}"  # virtual URL
                }
            return None

    def get_sentence_audio_data(self, session_id):
        # Returns the WHOLE sentence store (all sessions) when the session is active.
        with self.lock:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                return None
            return self.sentence_audio_store


# Module-level singleton used by the routes below.
stream_manager = StreamSessionManager()
def allowed_file(filename):
    """Return True if *filename* carries an allowed image extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {'png', 'jpg', 'jpeg', 'gif'}
# Cache of generated audio keyed by text, plus its lock and TTL.
# NOTE(review): these three names are re-declared verbatim further down the file
# (the later assignments win at import time) — consider removing one copy.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes
# WebSocket connection management.
class ConnectionManager:
    """Registry of live WebSocket connections with failure-safe send helpers."""

    def __init__(self):
        # connection_id -> WebSocket
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        """Accept the handshake and register the socket under *connection_id*."""
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logging.info(f"新连接建立: {connection_id}")

    async def disconnect(self, connection_id: str, code=1000, reason: str = ""):
        """Close and unregister a connection; close errors are swallowed."""
        websocket = self.active_connections.get(connection_id)
        if websocket is None:
            return
        try:
            # Best-effort close; the peer may already be gone.
            await websocket.close(code=code, reason=reason)
        except:
            pass
        finally:
            self.active_connections.pop(connection_id, None)

    def is_connected(self, connection_id: str) -> bool:
        """True while *connection_id* is registered."""
        return connection_id in self.active_connections

    async def _safe_send(self, connection_id: str, send_func, *args):
        """Run ``send_func(websocket, *args)``, dropping the connection on failure.

        Returns True when the send succeeded, False otherwise.
        """
        websocket = self.active_connections.get(connection_id)
        # 1. Unknown connection id.
        if websocket is None:
            logging.warning(f"尝试向不存在的连接发送数据: {connection_id}")
            return False
        try:
            # 2. Refuse to send on a socket that is not CONNECTED.
            if websocket.client_state.name != "CONNECTED":
                logging.warning(f"连接 {connection_id} 状态为 {websocket.client_state.name}")
                await self.disconnect(connection_id)
                return False
            # 3. Perform the actual send.
            await send_func(websocket, *args)
            return True
        except (WebSocketDisconnect, RuntimeError) as exc:
            # 4. Peer went away mid-send.
            logging.info(f"发送时检测到断开连接: {connection_id}, {str(exc)}")
            await self.disconnect(connection_id)
            return False
        except Exception as exc:
            # 5. Anything else: log and drop the connection.
            logging.error(f"发送数据出错: {connection_id}, {str(exc)}")
            await self.disconnect(connection_id)
            return False

    async def send_bytes(self, connection_id: str, data: bytes):
        """Safely send binary data."""
        async def _push(ws, payload):
            await ws.send_bytes(payload)
        return await self._safe_send(connection_id, _push, data)

    async def send_text(self, connection_id: str, message: str):
        """Safely send a text frame."""
        async def _push(ws, payload):
            await ws.send_text(payload)
        return await self._safe_send(connection_id, _push, message)

    async def send_json(self, connection_id: str, data: dict):
        """Safely send a JSON payload."""
        async def _push(ws, payload):
            await ws.send_json(payload)
        return await self._safe_send(connection_id, _push, data)


manager = ConnectionManager()
def generate_mp3_header(
    sample_rate: int,
    bitrate_kbps: int,
    channels: int = 1,
    layer: str = "III"  # supports "I" / "II" / "III"
) -> bytes:
    """Build a 4-byte MP3 frame header for Layer I/II/III.

    :param sample_rate: sample rate in Hz (8000, 16000, 22050, 44100, 48000)
    :param bitrate_kbps: bitrate in kbps (valid values depend on version/layer)
    :param channels: 1 = mono, 2 = stereo
    :param layer: MPEG audio layer, one of "I", "II", "III"
    :return: the 4 header bytes, big-endian
    :raises ValueError: on any unsupported parameter combination

    Fix vs. previous revision: 48000 Hz was accepted by validation but then
    rejected by the version branch; it now maps to MPEG-1 (index 0b01).
    """
    # ----------------------------------
    # Parameter validation
    # ----------------------------------
    valid_sample_rates = {8000, 16000, 22050, 44100, 48000}
    if sample_rate not in valid_sample_rates:
        raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}")
    valid_layers = {"I", "II", "III"}
    if layer not in valid_layers:
        raise ValueError(f"不支持的层,可选:{valid_layers}")
    # ----------------------------------
    # MPEG version and sample-rate index
    # ----------------------------------
    if sample_rate == 48000:
        mpeg_version = 0b11  # MPEG-1 (48000 is an MPEG-1 rate; previously raised)
        sample_rate_index = 0b01
    elif sample_rate == 44100:
        mpeg_version = 0b11  # MPEG-1
        sample_rate_index = 0b00
    elif sample_rate == 22050:
        mpeg_version = 0b10  # MPEG-2
        sample_rate_index = 0b00
    elif sample_rate == 16000:
        mpeg_version = 0b10  # MPEG-2
        sample_rate_index = 0b10
    elif sample_rate == 8000:
        mpeg_version = 0b00  # MPEG-2.5
        sample_rate_index = 0b10
    else:
        # Defensive: unreachable given the validation above.
        raise ValueError("采样率与版本不匹配")
    # ----------------------------------
    # Bitrate table selection per (version, layer)
    # ----------------------------------
    # Layer field encoding: I=0b11, II=0b10, III=0b01
    layer_code = {
        "I": 0b11,
        "II": 0b10,
        "III": 0b01
    }[layer]
    # Bitrate index tables (kbps -> 4-bit index) for every supported combination.
    bitrate_tables = {
        # -------------------------------
        # MPEG-1 (0b11)
        # -------------------------------
        # Layer I
        (0b11, 0b11): {
            32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011,
            160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111,
            288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011,
            416: 0b1100, 448: 0b1101
        },
        # Layer II
        (0b11, 0b10): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011,
            320: 0b1100, 384: 0b1101
        },
        # Layer III
        (0b11, 0b01): {
            32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011,
            64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111,
            128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011,
            256: 0b0100, 320: 0b0101
        },
        # -------------------------------
        # MPEG-2 (0b10) / MPEG-2.5 (0b00)
        # -------------------------------
        # Layer I
        (0b10, 0b11): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
            224: 0b1100, 256: 0b1101
        },
        (0b00, 0b11): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
            224: 0b1100, 256: 0b1101
        },
        # Layer II
        (0b10, 0b10): {
            8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
            40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
            80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
            144: 0b1100, 160: 0b1101
        },
        (0b00, 0b10): {
            8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
            40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
            80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
            144: 0b1100, 160: 0b1101
        },
        # Layer III
        (0b10, 0b01): {
            8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
            40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111,
            80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011,
            144: 0b0100, 160: 0b0101
        },
        (0b00, 0b01): {
            8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
            40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111
        }
    }
    # Select the table for this version/layer pair.
    key = (mpeg_version, layer_code)
    if key not in bitrate_tables:
        raise ValueError(f"不支持的版本和层组合: MPEG={mpeg_version}, Layer={layer}")
    bitrate_table = bitrate_tables[key]
    if bitrate_kbps not in bitrate_table:
        raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}")
    bitrate_index = bitrate_table[bitrate_kbps]
    # ----------------------------------
    # Channel mode
    # ----------------------------------
    if channels == 1:
        channel_mode = 0b11  # mono
    elif channels == 2:
        channel_mode = 0b00  # stereo
    else:
        raise ValueError("声道数必须为1或2")
    # ----------------------------------
    # Assemble the 32-bit header
    # ----------------------------------
    sync = 0x7FF << 21               # 11-bit sync word
    version = mpeg_version << 19     # 2-bit MPEG version
    layer_bits = layer_code << 17    # 2-bit layer field
    # NOTE(review): protection_bit=0 means "CRC follows" per the MPEG spec;
    # for a no-CRC frame this should be 1.  Kept at 0 to preserve the existing
    # on-wire output — confirm downstream decoders before changing.
    protection = 0 << 16
    bitrate_bits = bitrate_index << 12
    sample_rate_bits = sample_rate_index << 10
    padding = 0 << 9                 # no padding
    private = 0 << 8
    mode = channel_mode << 6
    mode_ext = 0 << 4                # mode extension (unused for mono)
    copyright_bit = 0 << 3
    original = 0 << 2
    emphasis = 0b00                  # no emphasis
    frame_header = (
        sync |
        version |
        layer_bits |
        protection |
        bitrate_bits |
        sample_rate_bits |
        padding |
        private |
        mode |
        mode_ext |
        copyright_bit |
        original |
        emphasis
    )
    return frame_header.to_bytes(4, byteorder='big')
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit mono PCM audio.

    :param audio_data: raw little-endian signed 16-bit PCM bytes (even length)
    :param fade_length: number of leading samples to ramp from 0 to full volume
    :return: the faded audio as bytes

    Fix vs. previous revision: a fade_length larger than the number of samples
    raised IndexError; the ramp is now clamped to the buffer size (the slope is
    still computed from fade_length, so in-range output is unchanged).
    """
    # Interpret the bytes as signed 16-bit samples.
    samples = array.array('h', audio_data)
    # Clamp so short buffers cannot index past the end.
    ramp = min(fade_length, len(samples))
    for i in range(ramp):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)
    # Back to raw bytes.
    return samples.tobytes()
def parse_markdown_json(json_string):
    """Extract and parse the first ```json fenced block from markdown text.

    Returns ``{'success': True, 'data': <parsed>}`` on success, otherwise
    ``{'success': False, 'data': <error message>}``.
    """
    fence = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if not fence:
        return {'success': False, 'data': 'not a valid markdown json string'}
    try:
        parsed = json.loads(fence.group(1))
    except json.JSONDecodeError as err:
        # Parsing failed: report the decoder error text.
        return {'success': False, 'data': str(err)}
    return {'success': True, 'data': parsed}
# NOTE(review): duplicate of the audio cache globals declared earlier in this
# file; these assignments win at import time — consider removing one copy.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes
# Fullwidth-to-halfwidth character conversion.
def fullwidth_to_halfwidth(s):
    """Convert fullwidth (zenkaku) punctuation in *s* to its halfwidth form.

    NOTE(review): the original explicit mapping's keys were destroyed by an
    encoding/copy step (they all rendered as empty strings, which collapsed the
    dict and made the function a no-op).  The visible halfwidth VALUES match the
    fullwidth ASCII range, so the map is reconstructed with the standard
    U+FEE0 offset plus common CJK punctuation — confirm against the original
    source before relying on exotic characters.

    :param s: input text
    :return: text with fullwidth punctuation replaced by halfwidth equivalents
    """
    # CJK punctuation without a direct offset-based ASCII counterpart.
    cjk_punct = {
        '\u3001': ',',   # 、 ideographic comma
        '\u3002': '.',   # 。 ideographic full stop
        '\u2014': '-',   # — em dash
        '\u2018': "'",   # ' left single quote
        '\u2019': "'",   # ' right single quote
        '\u201c': '"',   # " left double quote
        '\u201d': '"',   # " right double quote
    }
    out = []
    for char in s:
        code = ord(char)
        if 0xFF01 <= code <= 0xFF5E:
            # Fullwidth ASCII variants map by a fixed offset (U+FEE0).
            out.append(chr(code - 0xFEE0))
        else:
            out.append(cjk_punct.get(char, char))
    return ''.join(out)
def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* on punctuation/whitespace and regroup words into chunks.

    Words are joined with single spaces; a new chunk starts whenever adding the
    next word would push the current chunk past *chunk_size* characters.
    """
    separator_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    # Tokenize, discarding the empty strings re.split can produce.
    words = [w for w in re.split(separator_pattern, text) if w]
    chunks = []
    current = ''
    for word in words:
        if len(current) + len(word) <= chunk_size:
            # Still room: append the word to the running chunk.
            current += (word + ' ')
        else:
            # Chunk full: emit it and start a new one with this word.
            chunks.append(current.strip())
            current = word + ' '
    # Emit whatever is left over.
    if current:
        chunks.append(current.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    """Strip markdown syntax from *markdown_text* and normalize the result.

    Removes headings, code (inline and fenced), emphasis, links, images and
    HTML tags, converts fullwidth punctuation to halfwidth, then collapses
    whitespace.
    """
    # Patterns are applied in order; each match is deleted outright.
    stripping_rules = (
        r'#\s*[^#]+',                         # headings
        r'`[^`]+`',                           # inline code
        r'```[\s\S]*?```',                    # fenced code blocks
        r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})',   # bold / italic
        r'\[.*?\]\(.*?\)',                    # links
        r'!\[.*?\]\(.*?\)',                   # images
        r'<[^>]+>',                           # HTML tags
    )
    text = markdown_text
    for pattern in stripping_rules:
        text = re.sub(pattern, '', text)
    # Normalize punctuation width.
    # text = re.sub(r'[^\w\s]', '', text)
    text = fullwidth_to_halfwidth(text)
    # Collapse runs of whitespace.
    return re.sub(r'\s+', ' ', text).strip()
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress *original_data* and return it as a Base64 string.

    The output carries no line breaks, matching Android's ``Base64.NO_WRAP``.
    """
    # Step 1: gzip compression into an in-memory buffer.
    with BytesIO() as compressed_buffer:
        with gzip.GzipFile(fileobj=compressed_buffer, mode='wb') as gz:
            gz.write(original_data)
        payload = compressed_buffer.getvalue()
    # Step 2: Base64 encoding (b64encode never inserts newlines).
    return base64.b64encode(payload).decode('utf-8')
def clean_audio_cache():
    """Evict cache entries older than CACHE_EXPIRE_SECONDS, closing their streams."""
    with cache_lock:
        cutoff = time.time() - CACHE_EXPIRE_SECONDS
        # Collect stale keys first so the dict isn't mutated while scanning.
        stale_keys = [
            key for key, entry in audio_text_cache.items()
            if entry['created_at'] < cutoff
        ]
        for key in stale_keys:
            removed = audio_text_cache.pop(key, None)
            # Release the audio stream held by the evicted entry, if any.
            if removed and removed.get('audio_stream'):
                removed['audio_stream'].close()
def start_background_cleaner():
    """Spawn a daemon thread that purges the audio cache every 3 minutes."""
    def cleaner_loop():
        while True:
            time.sleep(180)  # run every 3 minutes
            clean_audio_cache()

    Thread(target=cleaner_loop, daemon=True).start()
def test_qwen_chat():
    """Smoke-test the DashScope Qwen chat endpoint and print the reply.

    Network-dependent helper; requires a valid ALI_KEY. Prints the model's
    answer on success, otherwise prints the error details.
    """
    # Fix: Generation was referenced without ever being imported in this file
    # (NameError at call time). dashscope is already a dependency here.
    from dashscope import Generation
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # 若没有配置环境变量请用百炼API Key将下行替换为api_key = "sk-xxx",
        api_key=ALI_KEY,
        model="qwen-plus",  # 模型列表https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )
    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import (
ResultCallback as TTSResultCallback,
SpeechSynthesizer as TTSSpeechSynthesizer,
SpeechSynthesisResult as TTSSpeechSynthesisResult,
)
# cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本
from dashscope.audio.tts_v2 import (
ResultCallback as CosyResultCallback,
SpeechSynthesizer as CosySpeechSynthesizer,
AudioFormat,
)
class QwenTTS:
def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun",
special_characters: Optional[Dict[str, str]] = None):
import dashscope
import ssl
logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx
self.model_name = model_name.split('@')[0]
dashscope.api_key = key
ssl._create_default_https_context = ssl._create_unverified_context # 禁用验证
self.synthesizer = None
self.callback = None
self.is_cosyvoice = False
self.voice = ""
self.format = format
self.sample_rate = sample_rate
self.first_chunk = True
if '/' in self.model_name:
parts = self.model_name.split('/', 1)
# 返回分离后的两个字符串parts[0], parts[1]
if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2':
self.is_cosyvoice = True
self.voice = parts[1]
self.completion_event = None # 新增:用于通知任务完成
# 特殊字符及其拼音映射
self.special_characters = special_characters or {
"": "chuang3",
"": "yue4"
# 可以添加更多特殊字符的映射
}
class Callback(TTSResultCallback):
def __init__(self,data_callback=None,completion_event=None) -> None:
self.dque = deque()
self.data_callback = data_callback
self.completion_event = completion_event # 新增完成事件引用
def _run(self):
while True:
if not self.dque:
time.sleep(0)
continue
val = self.dque.popleft()
if val:
yield val
else:
break
def on_open(self):
pass
def on_complete(self):
self.dque.append(None)
if self.data_callback:
self.data_callback(None) # 发送结束信号
# 通知任务完成
if self.completion_event:
self.completion_event.set()
def on_error(self, response: SpeechSynthesisResponse):
print("Qwen tts error", str(response))
raise RuntimeError(str(response))
def on_close(self):
pass
def on_event(self, result: TTSSpeechSynthesisResult):
data =result.get_audio_frame()
if data is not None:
if len(data) > 0:
if self.data_callback:
self.data_callback(data)
else:
self.dque.append(data)
#self.dque.append(result.get_audio_frame())
# --------------------------
class Callback_Cosy(CosyResultCallback):
def __init__(self, data_callback=None,completion_event=None) -> None:
self.dque = deque()
self.data_callback = data_callback
self.completion_event = completion_event # 新增完成事件引用
def _run(self):
while True:
if not self.dque:
time.sleep(0)
continue
val = self.dque.popleft()
if val:
yield val
else:
break
def on_open(self):
logging.info("Qwen CosyVoice tts open ")
pass
def on_complete(self):
self.dque.append(None)
if self.data_callback:
self.data_callback(None) # 发送结束信号
# 通知任务完成
if self.completion_event:
self.completion_event.set()
def on_error(self, response: SpeechSynthesisResponse):
print("Qwen tts error", str(response))
if self.data_callback:
self.data_callback(f"ERROR:{str(response)}".encode())
raise RuntimeError(str(response))
def on_close(self):
# print("---Qwen call back close") # cyx
logging.info("Qwen CosyVoice tts close")
pass
""" canceled for test 语音大模型CosyVoice
def on_event(self, result: SpeechSynthesisResult):
if result.get_audio_frame() is not None:
self.dque.append(result.get_audio_frame())
"""
def on_event(self, message):
# logging.info(f"recv speech synthsis message {message}")
pass
# 以下适合语音大模型CosyVoice
def on_data(self, data: bytes) -> None:
if len(data) > 0:
if self.data_callback:
self.data_callback(data)
else:
self.dque.append(data)
# --------------------------
def tts(self, text, on_data = None,completion_event=None):
# logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
# text = self.normalize_text(text)
print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
# text = self.normalize_text(text)
try:
# if self.model_name != 'cosyvoice-v1':
if self.is_cosyvoice is False:
self.callback = self.Callback(
data_callback=on_data,
completion_event=completion_event
)
TTSSpeechSynthesizer.call(model=self.model_name,
text=text,
callback=self.callback,
format=self.format) # format="mp3")
else:
self.callback = self.Callback_Cosy()
format = self.get_audio_format(self.format, self.sample_rate)
self.synthesizer = CosySpeechSynthesizer(
model='cosyvoice-v2',
# voice="longyuan", #"longfei",
voice=self.voice,
callback=self.callback,
format=format
)
self.synthesizer.call(text)
except Exception as e:
print(f"---dale---20 error {e}") # cyx
# -----------------------------------
try:
for data in self.callback._run():
# logging.info(f"dashcope return data {len(data)}")
yield data
# print(f"---Qwen return data {num_tokens_from_string(text)}")
# yield num_tokens_from_string(text)
except Exception as e:
raise RuntimeError(f"**ERROR**: {e}")
def setup_tts(self, on_data,completion_event=None):
"""设置 TTS 回调,返回配置好的 synthesizer"""
#if not self.is_cosyvoice:
# raise NotImplementedError("Only CosyVoice supported")
if self.is_cosyvoice:
# 创建 CosyVoice 回调
self.callback = self.Callback_Cosy(
data_callback=on_data,
completion_event=completion_event)
else:
self.callback = self.Callback(
data_callback=on_data,
completion_event=completion_event)
if self.is_cosyvoice:
format_val = self.get_audio_format(self.format, self.sample_rate)
# logging.info(f"Qwen setup_tts {self.voice} {format_val}")
self.synthesizer = CosySpeechSynthesizer(
model='cosyvoice-v1',
voice=self.voice, # voice="longyuan", #"longfei",
callback=self.callback,
format=format_val
)
return self.synthesizer
def apply_phoneme_tags(self, text: str) -> str:
"""
在文本中查找特殊字符并用<phoneme>标签包裹它们
"""
# 如果文本已经是SSML格式直接返回
if text.strip().startswith("<speak>") and text.strip().endswith("</speak>"):
return text
# 为特殊字符添加SSML标签
for char, pinyin in self.special_characters.items():
# 使用正则表达式确保只替换整个字符(避免部分匹配)
pattern = r'([^<]|^)' + re.escape(char) + r'([^>]|$)'
replacement = r'\1<phoneme alphabet="py" ph="' + pinyin + r'">' + char + r'</phoneme>\2'
text = re.sub(pattern, replacement, text)
# 如果文本中已有SSML标签直接返回
if "<speak>" in text:
return text
# 否则包裹在<speak>标签中
return f"<speak>{text}</speak>"
def text_tts_call(self, text):
    """One-shot synthesis of `text` on whichever engine this instance wraps.

    CosyVoice models get a freshly constructed synthesizer per call; other
    models go through the class-level TTSSpeechSynthesizer.call API.
    """
    if self.special_characters and self.is_cosyvoice is False:
        # Non-CosyVoice voices accept SSML, so tag special characters first.
        text = self.apply_phoneme_tags(text)
        # logging.info(f"Applied SSML phoneme tags to text: {text}")
    if self.synthesizer and self.is_cosyvoice:
        logging.info(f"Qwen text_tts_call {text} {self.is_cosyvoice}")
        audio_fmt = self.get_audio_format(self.format, self.sample_rate)
        self.synthesizer = CosySpeechSynthesizer(
            model='cosyvoice-v1',
            voice=self.voice,  # e.g. "longyuan" / "longfei"
            callback=self.callback,
            format=audio_fmt,
        )
        self.synthesizer.call(text)
    if self.is_cosyvoice is False:
        logging.info(f"Qwen text_tts_call {text}")
        TTSSpeechSynthesizer.call(
            model=self.model_name,
            text=text,
            callback=self.callback,
            format=self.format,
        )
def streaming_call(self, text):
    """Forward one streaming text fragment; no-op when no synthesizer is set."""
    if not self.synthesizer:
        return
    self.synthesizer.streaming_call(text)
def end_streaming_call(self):
    """Tell the synthesizer the streamed input is finished; no-op without one."""
    if not self.synthesizer:
        return
    self.synthesizer.streaming_complete()
def get_audio_format(self, format: str, sample_rate: int):
    """Map a (sample_rate, format) pair onto a dashscope AudioFormat value.

    Unknown combinations fall back to MP3_16000HZ_MONO_128KBPS.
    """
    from dashscope.audio.tts_v2 import AudioFormat
    logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
    lookup = {
        (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
        (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
        (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
        (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
        (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
        (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
        (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
        (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
        (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
        (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
        (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
        (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
        (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT,
    }
    key = (sample_rate, format)
    return lookup.get(key, AudioFormat.MP3_16000HZ_MONO_128KBPS)
class DoubaoTTS:
    """Streaming TTS client for ByteDance Doubao over the binary WebSocket
    protocol at wss://openspeech.bytedance.com/api/v1/tts/ws_binary.

    Exposes the same setup_tts()/text_tts_call() surface as QwenTTS so the
    UnifiedTTSEngine can use either engine interchangeably.
    """

    def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"):
        logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}")
        # Doubao credentials (appid, token, cluster, voice_type).
        # SECURITY NOTE(review): credentials are hard-coded instead of being
        # parsed from `key` — move them to configuration/environment.
        try:
            self.appid = "7282190702"
            self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9"
            self.cluster = "volcano_tts"
            # Alternatives seen in use: "zh_male_jieshuonansheng_mars_bigtts",
            # "zh_male_ruyaqingnian_mars_bigtts"
            self.voice_type = "zh_female_qingxinnvsheng_mars_bigtts"
        except Exception as e:
            raise ValueError(f"Invalid Doubao key format: {str(e)}")
        self.format = format
        self.sample_rate = sample_rate
        self.model_name = model_name
        self.callback = None
        self.ws = None
        self.loop = None
        self.task = None
        self.event = threading.Event()
        self.data_queue = deque()
        self.host = "openspeech.bytedance.com"
        self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary"
        # Fixed protocol header: version/header-size, full-client request,
        # JSON serialization + gzip compression flags.
        self.default_header = bytearray(b'\x11\x10\x11\x00')
        self.total_data_size = 0
        self.completion_event = None  # set via setup_tts(); signalled when the request finishes

    class Callback:
        """Adapter forwarding audio chunks to a caller-supplied handler."""

        def __init__(self, data_callback=None, completion_event=None):
            self.data_callback = data_callback
            self.data_queue = deque()  # fallback buffer when no handler is given
            self.completion_event = completion_event  # reference to the engine's completion event

        def on_data(self, data):
            if self.data_callback:
                self.data_callback(data)
            else:
                self.data_queue.append(data)
            # NOTE(review): this fires after EVERY chunk, not only at stream
            # end; completion is also signalled from _async_tts's finally
            # block, so this looks premature — confirm before relying on it.
            if self.completion_event:
                self.completion_event.set()

        def on_complete(self):
            # End-of-stream is signalled downstream with a None sentinel.
            if self.data_callback:
                self.data_callback(None)

        def on_error(self, error):
            if self.data_callback:
                self.data_callback(f"ERROR:{error}".encode())

    def setup_tts(self, on_data, completion_event):
        """Install the data callback; returns self (the async request is started later)."""
        self.callback = self.Callback(
            data_callback=on_data,
            completion_event=completion_event
        )
        return self

    def text_tts_call(self, text):
        """Synchronous entry point: run the async TTS request to completion."""
        self.total_data_size = 0
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        self.task = self.loop.create_task(self._async_tts(text))
        try:
            self.loop.run_until_complete(self.task)
        except Exception as e:
            logging.error(f"DoubaoTTS--0 call error: {e}")
            self.callback.on_error(str(e))

    async def _async_tts(self, text):
        """Send one synthesis request and stream audio frames back to the callback."""
        header = {"Authorization": f"Bearer; {self.token}"}
        request_json = {
            "app": {
                "appid": self.appid,
                "token": "access_token",  # fixed placeholder required by the API
                "cluster": self.cluster
            },
            "user": {
                "uid": str(uuid.uuid4())  # random per-request user id
            },
            "audio": {
                "voice_type": self.voice_type,
                "encoding": self.format,
                "speed_ratio": 1.0,
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": text,
                "text_type": "plain",
                "operation": "submit"  # "submit" selects streaming mode
            }
        }
        # Frame layout: 4-byte protocol header + 4-byte payload size + gzip(JSON).
        payload_bytes = str.encode(json.dumps(request_json))
        payload_bytes = gzip.compress(payload_bytes)
        full_client_request = bytearray(self.default_header)
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
        full_client_request.extend(payload_bytes)
        try:
            async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws:
                self.ws = ws
                await ws.send(full_client_request)
                # Receive audio frames until the server marks the stream done.
                while True:
                    res = await ws.recv()
                    done = self._parse_response(res)
                    if done:
                        self.callback.on_complete()
                        break
        except Exception as e:
            logging.error(f"DoubaoTTS--1 WebSocket error: {e}")
            self.callback.on_error(str(e))
        finally:
            # Always signal task completion, even on failure.
            if self.completion_event:
                self.completion_event.set()

    def _parse_response(self, res):
        """Parse one binary server frame; return True when the stream is finished."""
        # 4-byte protocol header; header size is in 4-byte units.
        header_size = res[0] & 0x0f
        message_type = res[1] >> 4
        payload = res[header_size * 4:]
        if message_type == 0xb:  # audio-only server response
            message_flags = res[1] & 0x0f
            if message_flags == 0:
                return False  # ACK frame, carries no audio
            sequence_number = int.from_bytes(payload[:4], "big", signed=True)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            audio_data = payload[8:8 + payload_size]
            if audio_data:
                self.total_data_size = self.total_data_size + len(audio_data)
                self.callback.on_data(audio_data)
                # logging.info(f"doubao _parse_response: {sequence_number} {len(audio_data)} {self.total_data_size}")
            # A negative sequence number marks the final frame.
            return sequence_number < 0
        elif message_type == 0xf:  # error response
            code = int.from_bytes(payload[:4], "big", signed=False)
            msg_size = int.from_bytes(payload[4:8], "big", signed=False)
            error_msg = payload[8:8 + msg_size]
            try:
                # The error body may or may not be gzip-compressed.
                error_msg = gzip.decompress(error_msg).decode()
            except Exception:  # fixed: bare `except:` also swallowed SystemExit/KeyboardInterrupt
                error_msg = error_msg.decode(errors='ignore')
            logging.error(f"DoubaoTTS error: {error_msg}")
            self.callback.on_error(error_msg)
            return False
        return False
class UnifiedTTSEngine:
    """Runs TTS jobs (Qwen or Doubao) on a thread pool and buffers the audio.

    Tasks are kept in ``self.tasks`` keyed by a generated audio_stream_id and
    expire after ``cache_expire`` seconds via a periodic cleanup timer.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.tasks = {}
        self.executor = ThreadPoolExecutor(max_workers=10)
        self.cache_expire = 300  # tasks expire after 5 minutes
        # Periodic timer that evicts expired tasks.
        self.cleanup_timer = None
        self.start_cleanup_timer()

    def _cleanup_old_tasks(self):
        """Evict tasks older than cache_expire seconds."""
        now = time.time()
        with self.lock:
            expired_ids = [task_id for task_id, task in self.tasks.items()
                           if now - task['created_at'] > self.cache_expire]
            for task_id in expired_ids:
                self._remove_task(task_id)

    def _remove_task(self, task_id):
        """Drop one task and cancel its background future if still running.

        Remaining resources (queues/buffers) are reclaimed by GC once the task
        dict is dropped; pool threads are returned to the executor.
        """
        if task_id in self.tasks:
            task = self.tasks.pop(task_id)
            if 'future' in task and not task['future'].done():
                task['future'].cancel()

    def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False):
        """Register a TTS task (synchronous) and return its audio_stream_id.

        When delay_gen_audio is False the task is started immediately (and,
        per _start_tts_task, waited on to completion).
        """
        self._cleanup_old_tasks()
        audio_stream_id = str(uuid.uuid4())
        task_data = {
            'id': audio_stream_id,
            'text': text,
            'format': format,
            'sample_rate': sample_rate,
            'model_name': model_name,
            'key': key,
            'delay_gen_audio': delay_gen_audio,
            'created_at': time.time(),
            'status': 'pending',
            'data_queue': deque(),
            'event': threading.Event(),
            'completed': False,
            'error': None
        }
        with self.lock:
            self.tasks[audio_stream_id] = task_data
        if not delay_gen_audio:
            self._start_tts_task(audio_stream_id)
        return audio_stream_id

    def _start_tts_task(self, audio_stream_id):
        """Move a pending task to 'processing' and submit it to the pool."""
        # fixed: only ThreadPoolExecutor is imported at module level, so the
        # previous `concurrent.futures.TimeoutError` was a NameError.
        from concurrent.futures import TimeoutError as FuturesTimeoutError
        task = self.tasks.get(audio_stream_id)
        if not task or task['status'] != 'pending':
            return
        logging.info(f"已经启动 start tts task {audio_stream_id}")  # fixed: missing f-prefix
        task['status'] = 'processing'
        # Run the synthesis on a background thread.
        future = self.executor.submit(self._run_tts_sync, audio_stream_id)
        task['future'] = future
        # Non-delayed tasks block here until synthesis finishes, then the
        # audio chunks are merged into one BytesIO buffer for range requests.
        if not task.get('delay_gen_audio', True):
            try:
                future.result(timeout=300)  # wait at most 5 minutes
                logging.info(f"TTS任务 {audio_stream_id} 已完成")
                self._merge_audio_data(audio_stream_id)
            except FuturesTimeoutError:
                task['error'] = "TTS生成超时"
                task['completed'] = True
                logging.error(f"TTS任务 {audio_stream_id} 超时")
            except Exception as e:
                task['error'] = f"ERROR:{str(e)}"
                task['completed'] = True
                logging.exception(f"TTS任务执行异常: {str(e)}")

    def _run_tts_sync(self, audio_stream_id):
        """Execute one synthesis job synchronously (runs on a pool thread)."""
        task = self.tasks.get(audio_stream_id)
        if not task:
            return
        tts = None  # fixed: guarded so `finally` cannot hit an unbound name
        try:
            completion_event = threading.Event()
            # Pick the engine from the model name; the frontend sends names
            # like "cosyvoice-v1/longhua@Tongyi-Qianwen".
            model_name_wo_brand = task['model_name'].split('@')[0]
            model_name_version = model_name_wo_brand.split('/')[0]
            if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']:
                # Doubao TTS
                tts = DoubaoTTS(
                    key=task['key'],
                    format=task['format'],
                    sample_rate=task['sample_rate'],
                    model_name=task['model_name']
                )
            else:
                # Tongyi Qianwen TTS
                tts = QwenTTS(
                    key=task['key'],
                    format=task['format'],
                    sample_rate=task['sample_rate'],
                    model_name=task['model_name']
                )

            def data_handler(data):
                # None = end of stream, b"ERROR..." = failure, else audio chunk.
                if data is None:
                    task['completed'] = True
                    task['event'].set()
                    logging.info(f"--data_handler on_complete")
                elif data.startswith(b"ERROR"):
                    task['error'] = data.decode()
                    task['completed'] = True
                    task['event'].set()
                else:
                    task['data_queue'].append(data)

            # Wire up and run the synthesis.
            synthesizer = tts.setup_tts(data_handler, completion_event)
            # synthesizer.call(task['text'])
            tts.text_tts_call(task['text'])
            # Wait for completion or time out.
            if not completion_event.wait(timeout=300):  # 5-minute cap
                task['error'] = "TTS generation timeout"
                task['completed'] = True
            logging.info(f"--tts task event set error = {task['error']}")
        except Exception as e:
            logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}")
            task['error'] = f"ERROR:{str(e)}"
            task['completed'] = True
        finally:
            # Make sure the engine's private event loop is released.
            logging.info("UnifiedTTSEngine _run_tts_sync finally")
            if tts is not None and hasattr(tts, 'loop') and tts.loop:
                tts.loop.close()

    def _merge_audio_data(self, audio_stream_id):
        """Concatenate all queued audio chunks into a single BytesIO buffer."""
        task = self.tasks.get(audio_stream_id)
        if not task or not task.get('completed'):
            return
        try:
            logging.info(f"开始合并音频数据: {audio_stream_id}")
            buffer = io.BytesIO()
            for data_chunk in task['data_queue']:
                buffer.write(data_chunk)
            # Rewind so readers start from the beginning.
            buffer.seek(0)
            task['buffer'] = buffer
            logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节")
            # Drop the raw chunk queue to save memory.
            task['data_queue'].clear()
        except Exception as e:
            logging.error(f"合并音频数据失败: {str(e)}")
            task['error'] = f"合并错误: {str(e)}"

    async def get_audio_stream(self, audio_stream_id):
        """Async generator yielding audio chunks as the task produces them.

        Starts a still-pending delayed task, then drains the chunk queue until
        the task completes; raises RuntimeError on a recorded task error.
        """
        task = self.tasks.get(audio_stream_id)
        if not task:
            raise RuntimeError("Audio stream not found")
        # Delayed task not started yet (status still 'pending'): start it now.
        if task['delay_gen_audio'] and task['status'] == 'pending':
            self._start_tts_task(audio_stream_id)
        total_audio_data_size = 0
        # Wait for the task to leave 'pending'.
        while task['status'] == 'pending':
            await asyncio.sleep(0.1)
        # Stream chunks as they arrive.
        while not task['completed'] or task['data_queue']:
            while task['data_queue']:
                data = task['data_queue'].popleft()
                total_audio_data_size += len(data)
                # logging.info(f"yield audio data {len(data)} {total_audio_data_size}")
                yield data
            # Briefly wait for new data.
            await asyncio.sleep(0.05)
        # Surface any recorded error to the consumer.
        if task['error']:
            raise RuntimeError(task['error'])

    def start_cleanup_timer(self):
        """(Re)start the periodic cleanup timer (fires every 30 seconds)."""
        if self.cleanup_timer:
            self.cleanup_timer.cancel()
        self.cleanup_timer = threading.Timer(30.0, self.cleanup_task)
        self.cleanup_timer.daemon = True  # don't block interpreter shutdown
        self.cleanup_timer.start()

    def cleanup_task(self):
        """Timer body: evict expired tasks, then re-arm the timer."""
        try:
            self._cleanup_old_tasks()
        except Exception as e:
            logging.error(f"清理任务时出错: {str(e)}")
        finally:
            self.start_cleanup_timer()
# Global TTS engine instance shared by the HTTP/WebSocket routes below.
tts_engine = UnifiedTTSEngine()
def replace_domain(url: str) -> str:
    """Rewrite the host part of *url* to the local backend (no urllib.parse).

    Known public hosts are substituted directly; otherwise the scheme and
    host are stripped and "http://localhost:9380" is prepended.
    """
    # Hosts that should be redirected to the local service.
    known_hosts = (
        "http://1.13.185.116:9380",
        "https://ragflow.szzysztech.com",
        "1.13.185.116:9380",
        "ragflow.szzysztech.com",
    )
    for host in known_hosts:
        if host in url:
            # Substitute the first occurrence of the host only.
            return url.replace(host, "http://localhost:9380", 1)
    # Unrecognised host: strip whatever scheme/host is present.
    scheme, sep, remainder = url.partition("://")
    if not sep:
        # No scheme at all: treat the whole string as a path.
        return f"http://localhost:9380/{url}"
    slash_at = remainder.find("/")
    if slash_at > 0:
        # Keep everything from the first path slash onwards.
        return f"http://localhost:9380{remainder[slash_at:]}"
    # Host with no path component.
    return "http://localhost:9380"
async def proxy_aichat_audio_stream(client_id: str, audio_url: str):
    """Proxy an external audio stream to websocket client *client_id*.

    Fetches *audio_url* with httpx and forwards each chunk as a binary frame;
    stops early if the client connection is gone, and reports failures back
    to the client as a JSON error message.
    """
    try:
        # Domain rewriting is currently disabled: the URL is used as-is.
        local_url = audio_url
        logging.info(f"代理音频流: {audio_url} -> {local_url}")
        async with httpx.AsyncClient(timeout=60.0) as client, \
                client.stream("GET", local_url) as response:
            async for chunk in response.aiter_bytes():
                delivered = await manager.send_bytes(client_id, chunk)
                if not delivered:
                    logging.warning(f"Audio proxy interrupted for {client_id}")
                    return
    except Exception as e:
        logging.error(f"Audio proxy failed: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"音频流获取失败: {str(e)}"
        }))
# Proxy function - text stream.
# In the WeChat mini-program the SSE mechanism used by the original APK does
# not work, so a WebSocket is used to relay the stream instead.
async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict):
    """Proxy the LLM text stream to the websocket client, feeding the delta
    text into a TTS stream session as it arrives (compatible with the
    existing Flask implementation)."""
    try:
        logging.info(f"代理文本流: completions_url={completions_url} {payload}")
        logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}")
        headers = {
            "Content-Type": "application/json",
            'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm'
        }
        tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
        # if 'longyuan' in tts_model_name:
        #     tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen"
        # Create the TTS instance.
        tts_model = QwenTTS(
            key=ALI_KEY,
            format=payload.get('tts_stream_format', 'mp3'),
            sample_rate=payload.get('tts_sample_rate', 48000),
            model_name=tts_model_name
        )
        # CosyVoice supports incremental streaming synthesis.
        streaming_call = False
        if tts_model.is_cosyvoice:
            streaming_call = True
        # Create the streaming TTS session.
        tts_stream_session_id = stream_manager.create_session(
            tts_model=tts_model,
            sample_rate=payload.get('tts_sample_rate', 48000),
            stream_format=payload.get('tts_stream_format', 'mp3'),
            session_id=None,
            streaming_call=streaming_call
        )
        # logging.info(f"---tts_stream_session_id = {tts_stream_session_id}")
        tts_stream_session_id_sent = False
        # Per-sentence TTS URL pushing is currently disabled.
        send_sentence_tts_url = False
        # Event marking that every generated sentence has been sent.
        all_sentences_sent = asyncio.Event()

        # Task: watch for newly generated sentences and push them to the client.
        async def send_new_sentences():
            """Watch for newly generated sentences and forward them."""
            try:
                while True:
                    # Fetch the next generated sentence, if any.
                    sentence_info = stream_manager.get_next_sentence(tts_stream_session_id)
                    if sentence_info:
                        logging.info(f"--proxy_aichat_text_stream 发送sentence_info\r\n")
                        # Send sentence metadata (id, text, audio URL).
                        await manager.send_json(client_id, {
                            "type": "tts_sentence",
                            "id": sentence_info['id'],
                            "text": stream_manager.get_sentence_info(sentence_info['id'])['text'],
                            "url": sentence_info['url']
                        })
                    else:
                        # Check whether the session ended with no more sentences.
                        session = stream_manager.get_session(tts_stream_session_id)
                        if not session or (not session['active']):
                            # and session['current_sentence_index'] >= len(session['sentences'])):
                            # Mark all sentences as sent.
                            all_sentences_sent.set()
                            break
                        # Wait for a new sentence to be generated.
                        await asyncio.sleep(0.1)
            except asyncio.CancelledError:
                logging.info("句子监听任务被取消")
            except Exception as e:
                logging.error(f"句子监听任务出错: {str(e)}")
                all_sentences_sent.set()

        if send_sentence_tts_url:
            # Start the sentence-watcher task.
            sentence_task = asyncio.create_task(send_new_sentences())
        # Use a long timeout (5 minutes) for the upstream request.
        timeout = httpx.Timeout(300.0, connect=60.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            # Key point: use streaming request mode.
            async with client.stream(
                "POST",
                completions_url,
                json=payload,
                headers=headers
            ) as response:
                logging.info(f"响应状态: HTTP {response.status_code}")
                if response.status_code != 200:
                    # Read the (non-streamed) error body.
                    error_content = await response.aread()
                    error_msg = f"后端错误: HTTP {response.status_code}"
                    error_msg += f" - {error_content[:200].decode()}" if error_content else ""
                    await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg}))
                    return
                # Verify this really is an SSE stream.
                content_type = response.headers.get("content-type", "").lower()
                if "text/event-stream" not in content_type:
                    logging.warning("非流式响应,转发完整内容")
                    full_content = await response.aread()
                    await manager.send_text(client_id, json.dumps({
                        "type": "text",
                        "data": full_content.decode('utf-8')
                    }))
                    return
                logging.info("开始处理SSE流")
                event_count = 0
                # Process the stream line by line with the async iterator.
                async for line in response.aiter_lines():
                    # Skip blank lines and SSE comment lines.
                    if not line or line.startswith(':'):
                        continue
                    # Handle an SSE data event.
                    if line.startswith("data:"):
                        data_str = line[5:].strip()
                        if data_str:  # ignore empty payloads
                            try:
                                # Parse the event and extract the delta text.
                                data_obj = json.loads(data_str)
                                delta_text = None
                                if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict):
                                    delta_text = data_obj.get('data', None).get('delta_ans', "")
                                    # Attach the audio stream URL to the first event only.
                                    if tts_stream_session_id_sent is False:
                                        logging.info(f"--proxy_aichat_text_stream 发送audio_stream_url")
                                        data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}"
                                        data_str = json.dumps(data_obj)
                                        tts_stream_session_id_sent = True
                                # Forward the (possibly augmented) raw event.
                                await manager.send_text(client_id, json.dumps({
                                    "type": "text",
                                    "data": data_str
                                }))
                                # The {"type":"text","data":...} envelope is what the
                                # frontend websocket parser expects.
                                if delta_text:
                                    # Append the delta to the TTS session manager.
                                    stream_manager.append_text(tts_stream_session_id, delta_text)
                                # logging.info(f"文本代理转发: {data_str}")
                                event_count += 1
                            except Exception as e:
                                logging.error(f"事件发送失败: {str(e)}")
                    # Keep the connection alive without spinning the CPU.
                    await asyncio.sleep(0.001)
                logging.info(f"SSE流处理完成事件数: {event_count}")
                # Send the end-of-text-stream signal.
                await manager.send_text(client_id, json.dumps({"type": "end"}))
                # Mark the text input as finished.
                if stream_manager.finish_text_input:
                    stream_manager.finish_text_input(tts_stream_session_id)
                if send_sentence_tts_url:
                    # Wait (up to 300s) until every sentence has been sent.
                    try:
                        await asyncio.wait_for(all_sentences_sent.wait(), timeout=300.0)
                        logging.info(f"所有TTS句子已发送")
                    except asyncio.TimeoutError:
                        logging.warning("等待TTS句子发送超时")
                    # Cancel the watcher task if it is still running.
                    if not sentence_task.done():
                        sentence_task.cancel()
                        try:
                            await sentence_task
                        except asyncio.CancelledError:
                            pass
    except httpx.ReadTimeout:
        logging.error("读取后端服务超时")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": "后端服务响应超时"
        }))
    except httpx.ConnectError as e:
        logging.error(f"连接后端服务失败: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"无法连接到后端服务: {str(e)}"
        }))
    except Exception as e:
        logging.exception(f"文本代理失败: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"文本流获取失败: {str(e)}"
        }))
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Stream the local test file './test.mp3' in 1 KiB chunks."""
    def audio_generator():
        path = './test.mp3'
        try:
            with open(path, 'rb') as f:
                # Yield fixed-size chunks until EOF.
                for chunk in iter(lambda: f.read(1024), b''):
                    yield chunk
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/mpeg",
        headers={
            "Cache-Control": "no-store",
            "Accept-Ranges": "bytes"
        }
    )
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Prepend a RIFF/WAV header (mono, 16-bit PCM) to raw sample bytes.

    Args:
        pcm_data: raw little-endian 16-bit mono PCM samples.
        sample_rate: sample rate in Hz to write into the header.

    Returns:
        Complete WAV file contents (44-byte header followed by pcm_data).
    """
    import wave  # fixed: `wave` was never imported at module level (NameError on first call)
    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono (matches module CHANNELS)
            wav_file.setsampwidth(2)   # 16-bit samples (matches SAMPLE_WIDTH)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer.read()
def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return *duration_ms* of silent PCM, used to pre-buffer MP3 streaming."""
    sample_count = int(TTS_SAMPLERATE * duration_ms / 1000)
    # bytes(n) yields n zero bytes == digital silence for 16-bit PCM.
    return bytes(sample_count * SAMPLE_WIDTH * CHANNELS)
# ------------------------ API路由 ------------------------
@tts_router.get("/tts_sentence/{sentence_id}")
async def get_sentence_audio(sentence_id: str):
    """Stream the cached audio for one synthesized sentence.

    Returns 404 when either the audio bytes or the sentence metadata are
    missing from the stream manager.
    """
    from fastapi import HTTPException  # fixed: HTTPException is not imported at module level
    # Fetch the audio bytes.
    audio_data = stream_manager.get_sentence_audio(sentence_id)
    if not audio_data:
        raise HTTPException(status_code=404, detail="Audio not found")
    # Fetch the sentence metadata (needed for the format).
    sentence_info = stream_manager.get_sentence_info(sentence_id)
    if not sentence_info:
        raise HTTPException(status_code=404, detail="Sentence info not found")
    # Pick the MIME type from the stored format.
    format = sentence_info['format']
    media_type = "audio/mpeg" if format == "mp3" else "audio/wav"
    logging.info(f"--http get sentence tts audio stream {sentence_id}")
    # Stream the audio back to the client.
    return StreamingResponse(
        io.BytesIO(audio_data),
        media_type=media_type,
        headers={
            "Content-Disposition": f"attachment; filename=audio.{format}",
            "Cache-Control": "max-age=3600"  # cache for 1 hour
        }
    )
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Validate a TTS request body and register a synthesis task.

    Returns the polling/streaming URLs for the created audio stream; 400 for
    invalid parameters, 500 for unexpected failures.
    """
    from fastapi import HTTPException  # fixed: HTTPException is not imported at module level
    try:
        data = await request.json()
        logging.info(f"Creating TTS request: {data}")
        # Parameter validation.
        text = data.get("text", "").strip()
        if not text:
            raise HTTPException(400, detail="Text cannot be empty")
        format = data.get("tts_stream_format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="Unsupported audio format")
        sample_rate = data.get("tts_sample_rate", 48000)
        if sample_rate not in [8000, 16000, 22050, 44100, 48000]:
            raise HTTPException(400, detail="Unsupported sample rate")
        model_name = data.get("model_name", "cosyvoice-v1/longxiaochun")
        delay_gen_audio = data.get('delay_gen_audio', False)
        # Register the TTS task.
        audio_stream_id = tts_engine.create_tts_task(
            text=text,
            format=format,
            sample_rate=sample_rate,
            model_name=model_name,
            key=ALI_KEY,
            delay_gen_audio=delay_gen_audio
        )
        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}",  # WebSocket URL (added 2025-06-22)
                "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat()
            }
        )
    except HTTPException:
        # fixed: validation 400s were previously caught by the generic handler
        # below and converted into 500s; let FastAPI render them as-is.
        raise
    except Exception as e:
        logging.error(f"Request failed: {str(e)}")
        raise HTTPException(500, detail="Internal server error")
# Module-level thread pool.  NOTE(review): not referenced anywhere in the
# visible code (UnifiedTTSEngine owns its own executor) — possibly unused.
executor = ThreadPoolExecutor()
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
    chat_id: str,
    audio_stream_id: str,
    range: str = Header(None)
):
    """Serve generated TTS audio.

    Completed tasks are served from their in-memory buffer (with HTTP Range
    support for larger files); still-running tasks are streamed chunk by
    chunk as the engine produces audio.
    """
    from fastapi import HTTPException  # fixed: HTTPException is not imported at module level
    try:
        # Look up the task.
        task = tts_engine.tasks.get(audio_stream_id)
        if not task:
            # Friendly 404 payload instead of a bare exception.
            return JSONResponse(
                status_code=404,
                content={
                    "error": "Audio stream not found",
                    "message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired",
                    "suggestion": "Please create a new TTS request and try again"
                }
            )
        # Resolve the media type from the task's format.
        format = task['format']
        media_type = {
            "mp3": "audio/mpeg",
            "wav": "audio/wav",
            "pcm": f"audio/L16; rate={task['sample_rate']}; channels=1"
        }[format]
        logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}")

        def buffer_read(buffer):
            # Chunked generator over the entire completed buffer.
            content_length = buffer.getbuffer().nbytes
            remaining = content_length
            chunk_size = 4096
            buffer.seek(0)
            while remaining > 0:
                read_size = min(remaining, chunk_size)
                data = buffer.read(read_size)
                if not data:
                    break
                yield data
                remaining -= len(data)

        # Completed task with a merged buffer: honour Range requests.
        if task.get('completed') and task.get('buffer') is not None:
            buffer = task['buffer']
            total_size = buffer.getbuffer().nbytes
            # Force plain streaming for small files to avoid problematic 206
            # partial responses.  (fixed: comment previously said 300KB while
            # the code uses 120KB.)
            if total_size < 1024 * 120:  # smaller than 120KB
                range = None
            if range:
                return handle_range_request(range, buffer, total_size, media_type)
            else:
                return StreamingResponse(
                    buffer_read(buffer),
                    media_type=media_type,
                    headers={
                        "Accept-Ranges": "bytes",
                        "Cache-Control": "no-store",
                        "Transfer-Encoding": "chunked"
                    })
        # Task still producing: stream chunks as they arrive.
        logging.info("tts_engine.get_audio_stream--0")
        return StreamingResponse(
            tts_engine.get_audio_stream(audio_stream_id),
            media_type=media_type,
            headers={
                "Accept-Ranges": "bytes",
                "Cache-Control": "no-store",
                "Transfer-Encoding": "chunked"
            }
        )
    except Exception as e:
        logging.error(f"Audio streaming failed: {str(e)}")
        raise HTTPException(500, detail="Audio generation error")
def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str):
    """Serve an HTTP Range request (RFC 7233) from an in-memory buffer.

    Supports "bytes=start-end", "bytes=start-" and suffix ranges "bytes=-N".
    Raises 416 for out-of-bounds ranges and 500 for malformed headers.
    """
    from fastapi import HTTPException  # fixed: HTTPException is not imported at module level
    try:
        # Parse the Range header (e.g. "bytes=0-1023").
        range_type, range_spec = range_header.split('=')
        if range_type != 'bytes':
            raise ValueError("Unsupported range type")
        start_str, end_str = range_spec.split('-')
        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        else:
            # fixed: suffix ranges like "bytes=-500" previously crashed on int('').
            start = max(total_size - int(end_str), 0)
            end = total_size - 1
        logging.info(f"handle_range_request--1 {start_str}-{end_str} {end}")
        # Validate the requested range.
        if start >= total_size or end >= total_size:
            raise HTTPException(status_code=416, headers={
                "Content-Range": f"bytes */{total_size}"
            })
        content_length = end - start + 1
        # 206 Partial Content, or 200 when the whole file was requested.
        status_code = 206
        if start == 0 and end == total_size - 1:
            status_code = 200
        # Position the buffer at the requested offset.
        buffer.seek(start)

        def content_generator():
            remaining = content_length
            chunk_size = 4096
            while remaining > 0:
                read_size = min(remaining, chunk_size)
                data = buffer.read(read_size)
                if not data:
                    break
                yield data
                remaining -= len(data)

        return StreamingResponse(
            content_generator(),
            status_code=status_code,
            media_type=media_type,
            headers={
                "Content-Range": f"bytes {start}-{end}/{total_size}",
                "Content-Length": str(content_length),
                "Accept-Ranges": "bytes",
                "Cache-Control": "public, max-age=3600"
            }
        )
    except HTTPException:
        # fixed: the 416 raised above was previously converted into a 500
        # by the generic handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}")
async def websocket_tts_endpoint(
    websocket: WebSocket,
    chat_id: str,
    audio_stream_id: str
):
    """WebSocket endpoint that routes to one of three audio/text sources based
    on the X-Tts-Type handshake header: "AiChatTts" (TTS stream proxy),
    "AiChatText" (LLM completion proxy), or default (engine audio stream)."""
    # Read routing hints from the handshake headers.
    headers = websocket.headers
    service_type = headers.get("x-tts-type")  # NOTE: header names arrive lowercased
    # audio_url = headers.get("x-audio-url")
    """
    前端示例
    websocketConnection = uni.connectSocket({
    url: url,
    header: {
    'Authorization': token,
    'X-Tts-Type': 'AiChat', //'Ask' // 自定义参数1
    'X-Device-Type': 'mobile', // 自定义参数2
    'X-User-ID': '12345' // 自定义参数3
    },
    success: () => {
    console.log('WebSocket connected');
    },
    fail: (err) => {
    console.error('WebSocket connection failed:', err);
    }
    });
    """
    # Create a unique id for this connection.
    connection_id = str(uuid.uuid4())
    # logging.info(f"---dale-- websocket connection_id = {connection_id} chat_id={chat_id}")
    await manager.connect(websocket, connection_id)
    # NOTE(review): completed_successfully is assigned below but never read.
    completed_successfully = False
    try:
        # Route to a different audio source depending on the tts type.
        if service_type == "AiChatTts":
            # Audio proxy service.
            audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}"
            # await proxy_aichat_audio_stream(connection_id, audio_url)
            sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate')
            audio_data_size = 0
            await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate})
            async for data in stream_manager.get_tts_buffer_data(audio_stream_id):
                if data.get('type') == 'sentence_end':
                    await manager.send_json(connection_id, {"command": "sentence_end"})
                if data.get('type') == 'arraybuffer':
                    audio_data_size += len(data.get('data'))
                    if not await manager.send_bytes(connection_id, data.get('data')):
                        break
            completed_successfully = True
            logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}")
        elif service_type == "AiChatText":
            # Text proxy service.
            # Wait for the client's initial request data — for LLM chat proxying
            # the frontend must send the payload right after connecting.
            payload = await websocket.receive_json()
            completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions"
            await proxy_aichat_text_stream(connection_id, completions_url, payload)
            completed_successfully = True
        else:
            # Stream audio straight from the engine's generator.
            async for data in tts_engine.get_audio_stream(audio_stream_id):
                if not await manager.send_bytes(connection_id, data):
                    logging.warning(f"Send failed, connection closed: {connection_id}")
                    break
            await manager.send_json(connection_id, {"command": "sentence_end"})
            completed_successfully = True
        # Check the connection before sending the completion signal.
        if manager.is_connected(connection_id):
            # Send the completion signal.
            await manager.send_json(connection_id, {"status": "completed"})
            # Brief delay so the message is delivered before closing.
            await asyncio.sleep(0.1)
            # Actively close the WebSocket connection.
            await manager.disconnect(connection_id, code=1000, reason="Audio stream completed")
    except WebSocketDisconnect:
        logging.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logging.error(f"WebSocket TTS error: {str(e)}")
        if manager.is_connected(connection_id):
            await manager.send_json(connection_id, {"error": str(e)})
    finally:
        pass
        # await manager.disconnect(connection_id)
def cleanup_cache():
    """Evict audio/text cache entries older than CACHE_EXPIRE_SECONDS."""
    with cache_lock:
        now = datetime.datetime.now()
        stale_keys = [k for k, v in audio_text_cache.items()
                      if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS]
        for key in stale_keys:
            logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
            del audio_text_cache[key]
# 应用启动时启动清理线程
# start_background_cleaner()