import logging import binascii from copy import deepcopy from timeit import default_timer as timer import datetime from datetime import timedelta import threading, time, queue, uuid, time, array from threading import Lock, Thread from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from multiprocessing import Manager import base64, gzip import os, io, re, json from io import BytesIO from typing import Optional, Dict, Any import asyncio, httpx from collections import deque import websockets import uuid from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query from fastapi import FastAPI, UploadFile, File, Form, Header from fastapi.responses import StreamingResponse, JSONResponse, Response TTS_SAMPLERATE = 44100 # 22050 # 16000 FORMAT = "mp3" CHANNELS = 1 # 单声道 SAMPLE_WIDTH = 2 # 16-bit = 2字节 tts_router = APIRouter() # 路由器专属的生命周期管理器 @asynccontextmanager async def tts_lifespan(app: FastAPI): """tts_service路由模块的生命周期管理器""" print("tts_service路由器正在启动...") try: yield finally: print("tts_service路由器正在关闭...") # logger = logging.getLogger(__name__) class MillisecondsFormatter(logging.Formatter): """自定义日志格式器,添加毫秒时间戳""" def formatTime(self, record, datefmt=None): # 将时间戳转换为本地时间元组 ct = self.converter(record.created) # 格式化为 "小时:分钟:秒" t = time.strftime("%H:%M:%S", ct) # 添加毫秒(3位) return f"{t}.{int(record.msecs):03d}" # 配置全局日志格式 def configure_logging(): # 创建 Formatter log_format = "%(asctime)s - %(levelname)s - %(message)s" formatter = MillisecondsFormatter(log_format) # 获取根 Logger 并清除已有配置 root_logger = logging.getLogger() root_logger.handlers = [] # 创建并配置 Handler(输出到控制台) console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) # 设置日志级别并添加 Handler root_logger.setLevel(logging.INFO) root_logger.addHandler(console_handler) # 调用配置函数(程序启动时运行一次) configure_logging() class StreamSessionManager: def __init__(self,manager = None): #self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}} #self.lock = threading.Lock() self.sessions = {} self.lock = threading.Lock() self.executor = ThreadPoolExecutor(max_workers=30) # 固定大小线程池 self.gc_interval = 300 # 5分钟清理一次 self.streaming_call_timeout = 10 # 10s self.gc_tts = 3 # 3s self.sentence_timeout = 2 # 2000ms句子超时 self.sentence_endings = set('。?!;.?!;') # 中英文结束符 # 增强版正则表达式:匹配中英文句子结束符(包含全角) self.sentence_pattern = re.compile( #r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])' r'((?:(? self.sentence_timeout: text_to_process = "" with self.lock: if session['text_buffer']: text_to_process = session['text_buffer'] if text_to_process: # 直接处理剩余文本 session['current_processing'] = True gen_tts_audio_func(session_id, session['text_buffer']) session['text_buffer'] = "" # 重置处理状态 session['current_processing'] = False """ # 4. 会话超时检查 if current_time - session['last_active'] > session['gc_interval']: # 处理剩余文本 with self.lock: if session['text_buffer']: session['current_processing'] = True # 处理最后一段文本 gen_tts_audio_func(session_id, session['text_buffer']) session['text_buffer'] = "" # 重置处理状态 session['current_processing'] = False # 关闭会话 logging.info(f"--_process_tasks-- timeout {current_time} {session['last_active']} {session['gc_interval']}") self.close_session(session_id) break """ # 5. 休眠避免CPU空转 time.sleep(0.05) # 50ms检查间隔 def _generate_audio(self, session_id, text): # 20250718 新更新 """实际生成音频(顺序执行)- 用于非流式引擎""" session = self.sessions.get(session_id) if not session or session.get('should_stop', False): return try: #logging.info(f"StreamSessionManager _generate_audio--0 {text}") def on_data_whole(data: bytes): if data: try: session['last_active'] = time.time() session['buffer'].put({'type':'arraybuffer','data':data}) #session['buffer'].put(data) session['audio_chunk_size'] += len(data) #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}") except queue.Full: logging.warning(f"Audio buffer full for session {session_id}") # 重置完成事件 session['sentence_complete_event'].clear() session['tts_model'].setup_tts(on_data = on_data_whole,completion_event=session['sentence_complete_event']) # 调用 TTS session['tts_model'].text_tts_call(text) session['last_active'] = time.time() session['audio_chunk_count'] += 1 # 等待句子完成,但会检查停止标志 start_time = time.time() timeout_or_stopped = True while not session['sentence_complete_event'].wait(timeout=0.5): # 每0.5秒检查一次 # 检查是否超时或收到停止信号 if time.time() - start_time > 30 : timeout_or_stopped = True break if session.get('should_stop',False): timeout_or_stopped = False break logging.info(f"StreamSessionManager _generate_audio 转换结束!!!" f"{session['audio_chunk_size']} {session_id} " f"收到停止信号: {not timeout_or_stopped}") session['buffer'].put({'type': 'sentence_end', 'data': ""}) if not session['tts_chunk_data_valid']: session['tts_chunk_data_valid'] = True except Exception as e: session['buffer'].put(f"ERROR:{str(e)}".encode()) session['sentence_complete_event'].set() # 确保事件被设置 def _stream_audio(self, session_id, text): """流式传输文本到TTS服务""" session = self.sessions.get(session_id) if not session: return # logging.info(f"Streaming text to TTS: {text}") try: # 重置完成事件状态 session['sentence_complete_event'].clear() # 使用流式调用发送文本 session['tts_model'].streaming_call(text) session['last_active'] = time.time() # 流式引擎不需要等待完成事件 session['sentence_complete_event'].set() except Exception as e: logging.error(f"Error in streaming_call: {str(e)}") session['buffer'].put(f"ERROR:{str(e)}".encode()) session['sentence_complete_event'].set() async def get_tts_buffer_data(self, session_id): """异步流式返回 TTS 音频数据(适配同步 queue.Queue,带 10 秒超时)""" session = self.sessions.get(session_id) if not session: raise ValueError(f"Session {session_id} not found") buffer = session['buffer'] # 这里是 queue.Queue last_data_time = time.time() # 记录最后一次获取数据的时间 get_tts_audio_size = 0 get_tts_audio_return = 0 while session['active']: try: # 检查会话是否被标记停止 if session.get('should_stop', False): logging.info(f"会话被标记停止: {session_id}") break # 使用 run_in_executor + wait_for 设置 10 秒超时 data = await asyncio.wait_for( asyncio.get_event_loop().run_in_executor(self.executor, buffer.get), timeout=10.0 # 10 秒超时 ) get_tts_audio_return += 1 # 检查停止信号 if isinstance(data, dict) and data.get('type') == 'stop_signal': logging.info(f"StreamSessionManager get_tts_buffer_data收到停止信号: {session_id}") break if isinstance(data, dict) and data.get('data'): get_tts_audio_size += len(data['data']) #logging.info(f"get_tts_buffer_data {data}") last_data_time = time.time() # 更新最后数据时间 yield data except asyncio.TimeoutError: # 10 秒内没有新数据,检查是否超时 if time.time() - last_data_time >= 10.0: logging.info(f"get_tts_buffer_data {session_id} Timeout after 10 seconds " f"data_size={get_tts_audio_size} qsize={buffer.qsize()} {get_tts_audio_return}") if buffer.qsize() >= 10: # 获取线程池状态 - 使用自定义的 executor active_threads = len(threading.enumerate()) # 当前系统线程数 # 获取线程池特定信息 pool_threads = [ t for t in threading.enumerate() if t.name.startswith("ThreadPoolExecutor") ] # 安全地获取等待任务数 pending_tasks = 0 if hasattr(self.executor, '_work_queue'): pending_tasks = self.executor._work_queue.qsize() logging.warning( f"[{threading.current_thread().name}] Timeout: " f"System threads={active_threads}, " f"ThreadPool threads={len(pool_threads)}, " f"Pending tasks={pending_tasks}" ) # 202507 调试发现,偶尔存在队列中已经生成了TTS音频数据,但是上述从队列中获取数据时,不能成功 # 所以做如下的复位关键资源的操作 dale_yxc self.reset_manager() # 出现取不到队列数据的异常,重置关键资源 break else: continue # 未超时,继续等待 except asyncio.CancelledError: logging.info(f"Session {session_id} stream cancelled") break except Exception as e: logging.error(f"Error in get_tts_buffer_data: {e}") break def close_session(self, session_id): with self.lock: if session_id in self.sessions: session = self.sessions[session_id] session['active'] = False # 设置完成事件 session['sentence_complete_event'].set() # 清理TTS资源 try: if session.get('streaming_call'): session['tts_model'].end_streaming_call() except: pass # 延迟清理会话 threading.Timer(1, self._clean_session, args=[session_id]).start() def _clean_session(self, session_id): with self.lock: if session_id in self.sessions: del self.sessions[session_id] def _cleanup_expired(self): """优化后的清理方法""" while True: time.sleep(30) now = time.time() to_remove = [] # 快速收集需要清理的会话ID(减少锁持有时间) with self.lock: for sid, session in list(self.sessions.items()): if now - session['last_active'] > self.gc_interval: to_remove.append(sid) # 在锁外执行实际清理 for sid in to_remove: self._clean_session(sid) logging.info(f"清理资源: {len(to_remove)}会话") def stop_session(self, session_id: str): """停止指定会话的音频生成""" if session_id in self.sessions: session = self.sessions[session_id] # 设置停止标志 session['should_stop'] = True # 如果使用队列,放入停止标记 if 'buffer' in session: try: # 放入特殊值通知生成循环退出 session['buffer'].put({"type": "stop_signal"}) except: pass # 设置完成事件,确保任务处理线程能退出 session['sentence_complete_event'].set() logging.info(f"StreamSessionManager stop_session: {session_id}") def get_session(self, session_id): return self.sessions.get(session_id) def _has_sentence_ending(self, text): """检测文本是否包含句子结束符""" if not text: return False # 检查常见结束符(包含全角字符) if any(char in self.sentence_endings for char in text[-3:]): return True # 检查中文段落结束(换行符前有结束符) if '\n' in text and any(char in self.sentence_endings for char in text.split('\n')[-2:-1]): return True return False def _split_and_extract(self, text): """ 增强型句子分割器 返回: (完整句子列表, 剩余文本) 限定返回长度 """ # 特殊处理:如果文本以逗号开头,先处理前面的部分 if text.startswith((",", ",")): return [text[0]], text[1:] # 1. 查找所有可能的句子结束位置 matches = list(self.sentence_pattern.finditer(text)) if not matches: return [], text # 没有找到结束符 # 2. 确定最后一个完整句子的结束位置 last_end = 0 complete_sentences = [] for match in matches: end_pos = match.end() sentence = text[last_end:end_pos].strip() # 跳过空句子 if not sentence: last_end = end_pos continue # 检查是否为有效句子(最小长度或包含结束符) if (len(sentence) > 6 or any(char in "。.?!?!" for char in sentence)) and (last_end<24): complete_sentences.append(sentence) last_end = end_pos else: # 短文本但包含结束符,可能是特殊符号 if any(char in "。.?!?!" for char in sentence): complete_sentences.append(sentence) last_end = end_pos # 3. 提取剩余文本 remaining_text = text[last_end:].strip() return complete_sentences, remaining_text def reset_manager(self): """完全重置管理器 - 极简版本""" logging.critical("Resetting StreamSessionManager...") # 步骤1: 关闭所有会话 for session_id in list(self.sessions.keys()): try: # 直接清理会话而不调用close_session if session_id in self.sessions: # 尝试释放TTS资源 try: if self.sessions[session_id].get('tts_model'): self.sessions[session_id]['tts_model'].cleanup() except Exception: pass del self.sessions[session_id] except Exception: pass # 步骤2: 重置关键资源 # 重建线程池 try: self.executor.shutdown(wait=False) except Exception: pass self.executor = ThreadPoolExecutor(max_workers=30) # 清空会话字典 self.sessions = {} # 重建锁对象 self.lock = threading.Lock() logging.critical("Reset completed") def get_self_thread_pool_status(self): # 获取线程池状态 - 使用自定义的 executor active_threads = len(threading.enumerate()) # 当前系统线程数 # 获取线程池特定信息 pool_threads = [ t for t in threading.enumerate() if t.name.startswith("ThreadPoolExecutor") ] # 安全地获取等待任务数 pending_tasks = 0 if hasattr(self.executor, '_work_queue'): pending_tasks = self.executor._work_queue.qsize() return ( f"[{threading.current_thread().name}] Timeout: " f"System threads={active_threads}, " f"ThreadPool threads={len(pool_threads)}, " f"Pending tasks={pending_tasks}" ) stream_manager = StreamSessionManager() def allowed_file(filename): return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'} # WebSocket 连接管理 class ConnectionManager: def __init__(self): self.active_connections = {} self.aichat_audio_sessions = {} # 用于存储aichat音频会话,关联StreamSessionManager的session_id async def connect(self, websocket: WebSocket, connection_id: str): await websocket.accept() self.active_connections[connection_id] = websocket logging.info(f"新连接建立: {connection_id}") # 注册音频会话,关联StreamSessionManager的session_id def register_audio_session(self, connection_id: str, audio_session_id: str): self.aichat_audio_sessions[connection_id] = audio_session_id async def disconnect(self, connection_id: str, code=1000, reason: str = ""): # 通知StreamSessionManager管理器停止生音频 if connection_id in self.aichat_audio_sessions: audio_session_id = self.aichat_audio_sessions[connection_id] stream_manager.stop_session(audio_session_id) del self.aichat_audio_sessions[connection_id] if connection_id in self.active_connections: try: # 尝试正常关闭连接(非阻塞) await self.active_connections[connection_id].close(code=code, reason=reason) except: pass # 忽略关闭错误 finally: del self.active_connections[connection_id] def is_connected(self, connection_id: str) -> bool: """检查连接是否仍然活跃""" return connection_id in self.active_connections async def _safe_send(self, connection_id: str, send_func, *args): """安全发送的通用方法(核心修改)""" # 1. 检查连接是否存在 if connection_id not in self.active_connections: #logging.warning(f"尝试向不存在的连接发送数据: {connection_id}") return False websocket = self.active_connections[connection_id] try: # 2. 检查连接状态(关键修改) if websocket.client_state.name != "CONNECTED": logging.warning(f"连接 {connection_id} 状态为 {websocket.client_state.name}") await self.disconnect(connection_id) return False # 3. 执行发送操作 await send_func(websocket, *args) return True except (WebSocketDisconnect, RuntimeError) as e: # 4. 处理连接断开异常 logging.info(f"发送时检测到断开连接: {connection_id}, {str(e)}") await self.disconnect(connection_id) return False except Exception as e: # 5. 处理其他异常 logging.error(f"发送数据出错: {connection_id}, {str(e)}") await self.disconnect(connection_id) return False async def send_bytes(self, connection_id: str, data: bytes): """安全发送字节数据""" return await self._safe_send( connection_id, lambda ws, d: ws.send_bytes(d), data ) async def send_text(self, connection_id: str, message: str): """安全发送文本数据""" return await self._safe_send( connection_id, lambda ws, m: ws.send_text(m), message ) async def send_json(self, connection_id: str, data: dict): """安全发送JSON数据""" return await self._safe_send( connection_id, lambda ws, d: ws.send_json(d), data ) manager = ConnectionManager() def generate_mp3_header( sample_rate: int, bitrate_kbps: int, channels: int = 1, layer: str = "III" # 新增参数,支持 "I"/"II"/"III" ) -> bytes: """ 动态生成 MP3 帧头(4字节),支持 Layer I/II/III :param sample_rate: 采样率 (8000, 16000, 22050, 44100) :param bitrate_kbps: 比特率(单位 kbps) :param channels: 声道数 (1: 单声道, 2: 立体声) :param layer: 编码层 ("I", "II", "III") :return: 4字节的帧头数据 """ # ---------------------------------- # 参数校验 # ---------------------------------- valid_sample_rates = {8000, 16000, 22050, 44100, 48000} if sample_rate not in valid_sample_rates: raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}") valid_layers = {"I", "II", "III"} if layer not in valid_layers: raise ValueError(f"不支持的层,可选:{valid_layers}") # ---------------------------------- # 确定 MPEG 版本和采样率索引 # ---------------------------------- if sample_rate == 44100: mpeg_version = 0b11 # MPEG-1 sample_rate_index = 0b00 elif sample_rate == 22050: mpeg_version = 0b10 # MPEG-2 sample_rate_index = 0b00 elif sample_rate == 16000: mpeg_version = 0b10 # MPEG-2 sample_rate_index = 0b10 elif sample_rate == 8000: mpeg_version = 0b00 # MPEG-2.5 sample_rate_index = 0b10 else: raise ValueError("采样率与版本不匹配") # ---------------------------------- # 动态选择比特率表(关键扩展) # ---------------------------------- # Layer 编码映射(I:0b11, II:0b10, III:0b01) layer_code = { "I": 0b11, "II": 0b10, "III": 0b01 }[layer] # 比特率表(覆盖所有 Layer) bitrate_tables = { # ------------------------------- # MPEG-1 (0b11) # ------------------------------- # Layer I (0b11, 0b11): { 32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011, 160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111, 288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011, 416: 0b1100, 448: 0b1101 }, # Layer II (0b11, 0b10): { 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, 160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011, 320: 0b1100, 384: 0b1101 }, # Layer III (0b11, 0b01): { 32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011, 64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111, 128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011, 256: 0b0100, 320: 0b0101 }, # ------------------------------- # MPEG-2 (0b10) / MPEG-2.5 (0b00) # ------------------------------- # Layer I (0b10, 0b11): { 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, 144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011, 224: 0b1100, 256: 0b1101 }, (0b00, 0b11): { 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, 144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011, 224: 0b1100, 256: 0b1101 }, # Layer II (0b10, 0b10): { 8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011, 40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111, 80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011, 144: 0b1100, 160: 0b1101 }, (0b00, 0b10): { 8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011, 40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111, 80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011, 144: 0b1100, 160: 0b1101 }, # Layer III (0b10, 0b01): { 8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011, 40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111, 80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011, 144: 0b0100, 160: 0b0101 }, (0b00, 0b01): { 8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011, 40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111 } } # 获取当前版本的比特率表 key = (mpeg_version, layer_code) if key not in bitrate_tables: raise ValueError(f"不支持的版本和层组合: MPEG={mpeg_version}, Layer={layer}") bitrate_table = bitrate_tables[key] if bitrate_kbps not in bitrate_table: raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}") bitrate_index = bitrate_table[bitrate_kbps] # ---------------------------------- # 确定声道模式 # ---------------------------------- if channels == 1: channel_mode = 0b11 # 单声道 elif channels == 2: channel_mode = 0b00 # 立体声 else: raise ValueError("声道数必须为1或2") # ---------------------------------- # 组合帧头字段(修正层编码) # ---------------------------------- sync = 0x7FF << 21 # 同步字 11位 (0x7FF = 0b11111111111) version = mpeg_version << 19 # MPEG 版本 2位 layer_bits = layer_code << 17 # Layer 编码(I:0b11, II:0b10, III:0b01) protection = 0 << 16 # 无 CRC bitrate_bits = bitrate_index << 12 sample_rate_bits = sample_rate_index << 10 padding = 0 << 9 # 无填充 private = 0 << 8 mode = channel_mode << 6 mode_ext = 0 << 4 # 扩展模式(单声道无需设置) copyright = 0 << 3 original = 0 << 2 emphasis = 0b00 # 无强调 frame_header = ( sync | version | layer_bits | protection | bitrate_bits | sample_rate_bits | padding | private | mode | mode_ext | copyright | original | emphasis ) return frame_header.to_bytes(4, byteorder='big') # ------------------------------------------------ def audio_fade_in(audio_data, fade_length): # 假设音频数据是16位单声道PCM # 将二进制数据转换为整数数组 samples = array.array('h', audio_data) # 对前fade_length个样本进行淡入处理 for i in range(fade_length): fade_factor = i / fade_length samples[i] = int(samples[i] * fade_factor) # 将整数数组转换回二进制数据 return samples.tobytes() def parse_markdown_json(json_string): # 使用正则表达式匹配Markdown中的JSON代码块 match = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL) if match: try: # 尝试解析JSON字符串 data = json.loads(match[1]) return {'success': True, 'data': data} except json.JSONDecodeError as e: # 如果解析失败,返回错误信息 return {'success': False, 'data': str(e)} else: return {'success': False, 'data': 'not a valid markdown json string'} # 全角字符到半角字符的映射 def fullwidth_to_halfwidth(s): full_to_half_map = { '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'", '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.', '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?', '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`', '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「', '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」', '、': '、', '・': '・', ':': ':' } return ''.join(full_to_half_map.get(char, char) for char in s) def split_text_at_punctuation(text, chunk_size=100): # 使用正则表达式找到所有的标点符号和特殊字符 punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+' tokens = re.split(punctuation_pattern, text) # 移除空字符串 tokens = [token for token in tokens if token] # 存储最终的文本块 chunks = [] current_chunk = '' for token in tokens: if len(current_chunk) + len(token) <= chunk_size: # 如果添加当前token后长度不超过chunk_size,则添加到当前块 current_chunk += (token + ' ') else: # 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块 chunks.append(current_chunk.strip()) current_chunk = token + ' ' # 添加最后一个块(如果有剩余) if current_chunk: chunks.append(current_chunk.strip()) return chunks def extract_text_from_markdown(markdown_text): # 移除Markdown标题 text = re.sub(r'#\s*[^#]+', '', markdown_text) # 移除内联代码块 text = re.sub(r'`[^`]+`', '', text) # 移除代码块 text = re.sub(r'```[\s\S]*?```', '', text) # 移除加粗和斜体 text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text) # 移除链接 text = re.sub(r'\[.*?\]\(.*?\)', '', text) # 移除图片 text = re.sub(r'!\[.*?\]\(.*?\)', '', text) # 移除HTML标签 text = re.sub(r'<[^>]+>', '', text) # 转换标点符号 # text = re.sub(r'[^\w\s]', '', text) text = fullwidth_to_halfwidth(text) # 移除多余的空格 text = re.sub(r'\s+', ' ', text).strip() return text def encode_gzip_base64(original_data: bytes) -> str: """核心编码过程:二进制数据 → Gzip压缩 → Base64编码""" # Step 1: Gzip 压缩 with BytesIO() as buf: with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file: gz_file.write(original_data) compressed_bytes = buf.getvalue() # Step 2: Base64 编码(配置与Android端匹配) return base64.b64encode(compressed_bytes).decode('utf-8') # 默认不带换行符(等同于Android的Base64.NO_WRAP) def test_qwen_chat(): messages = [ {'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': '你是谁?'} ] response = Generation.call( # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx", api_key=ALI_KEY, model="qwen-plus", # 模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models messages=messages, result_format="message" ) if response.status_code == 200: print(response.output.choices[0].message.content) else: print(f"HTTP返回码:{response.status_code}") print(f"错误码:{response.code}") print(f"错误信息:{response.message}") print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code") ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75" from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse from dashscope.audio.tts import ( ResultCallback as TTSResultCallback, SpeechSynthesizer as TTSSpeechSynthesizer, SpeechSynthesisResult as TTSSpeechSynthesisResult, ) # cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本 from dashscope.audio.tts_v2 import ( ResultCallback as CosyResultCallback, SpeechSynthesizer as CosySpeechSynthesizer, AudioFormat, ) class QwenTTS: def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", special_characters: Optional[Dict[str, str]] = None): import dashscope import ssl self.model_name = model_name.split('@')[0] dashscope.api_key = key ssl._create_default_https_context = ssl._create_unverified_context # 禁用验证 self.synthesizer = None self.callback = None self.is_cosyvoice = False self.is_cosyvoice_v2 = False self.cosyvoice = "" self.voice = "" self.format = format self.sample_rate = sample_rate self.first_chunk = True if '/' in self.model_name: parts = self.model_name.split('/', 1) # 返回分离后的两个字符串parts[0], parts[1] if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2': self.is_cosyvoice = True self.cosyvoice = parts[0] self.voice = parts[1] if parts[0] == 'cosyvoice-v2': self.is_cosyvoice_v2 = True logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {self.cosyvoice} {self.voice}") # cyx self.completion_event = None # 新增:用于通知任务完成 # 特殊字符及其拼音映射 self.special_characters = special_characters or { "㼽": "chuang3", "䡇": "yue4" # 可以添加更多特殊字符的映射 } class Callback(TTSResultCallback): def __init__(self,data_callback=None,completion_event=None) -> None: self.dque = deque() self.data_callback = data_callback self.completion_event = completion_event # 新增完成事件引用 def _run(self): while True: if not self.dque: time.sleep(0) continue val = self.dque.popleft() if val: yield val else: break def on_open(self): pass def on_complete(self): self.dque.append(None) if self.data_callback: self.data_callback(None) # 发送结束信号 # 通知任务完成 if self.completion_event: self.completion_event.set() def on_error(self, response: SpeechSynthesisResponse): print("Qwen tts error", str(response)) raise RuntimeError(str(response)) def on_close(self): pass def on_event(self, result: TTSSpeechSynthesisResult): data =result.get_audio_frame() if data is not None: if len(data) > 0: if self.data_callback: self.data_callback(data) else: self.dque.append(data) #self.dque.append(result.get_audio_frame()) # -------------------------- class Callback_Cosy(CosyResultCallback): def __init__(self, data_callback=None,completion_event=None) -> None: self.dque = deque() self.data_callback = data_callback self.completion_event = completion_event # 新增完成事件引用 def _run(self): while True: if not self.dque: time.sleep(0) continue val = self.dque.popleft() if val: yield val else: break def on_open(self): #logging.info("Qwen CosyVoice tts open ") pass def on_complete(self): self.dque.append(None) if self.data_callback: self.data_callback(None) # 发送结束信号 # 通知任务完成 if self.completion_event: self.completion_event.set() def on_error(self, response: SpeechSynthesisResponse): print("Qwen tts error", str(response)) if self.data_callback: self.data_callback(f"ERROR:{str(response)}".encode()) raise RuntimeError(str(response)) def on_close(self): # print("---Qwen call back close") # cyx #logging.info("Qwen CosyVoice tts close") pass """ canceled for test 语音大模型CosyVoice def on_event(self, result: SpeechSynthesisResult): if result.get_audio_frame() is not None: self.dque.append(result.get_audio_frame()) """ def on_event(self, message): # logging.info(f"recv speech synthsis message {message}") pass # 以下适合语音大模型CosyVoice def on_data(self, data: bytes) -> None: if len(data) > 0: if self.data_callback: self.data_callback(data) else: self.dque.append(data) # -------------------------- def tts(self, text, on_data = None,completion_event=None): # logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx # text = self.normalize_text(text) print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx # text = self.normalize_text(text) try: # if self.model_name != 'cosyvoice-v1': if self.is_cosyvoice is False: self.callback = self.Callback( data_callback=on_data, completion_event=completion_event ) TTSSpeechSynthesizer.call(model=self.model_name, text=text, callback=self.callback, format=self.format) # format="mp3") else: self.callback = self.Callback_Cosy() format = self.get_audio_format(self.format, self.sample_rate) self.synthesizer = CosySpeechSynthesizer( model='cosyvoice-v2', # voice="longyuan", #"longfei", voice=self.voice, callback=self.callback, format=format ) self.synthesizer.call(text) except Exception as e: print(f"---dale---20 error {e}") # cyx # ----------------------------------- try: for data in self.callback._run(): # logging.info(f"dashcope return data {len(data)}") yield data # print(f"---Qwen return data {num_tokens_from_string(text)}") # yield num_tokens_from_string(text) except Exception as e: raise RuntimeError(f"**ERROR**: {e}") def setup_tts(self, on_data,completion_event=None): """设置 TTS 回调,返回配置好的 synthesizer""" #if not self.is_cosyvoice: # raise NotImplementedError("Only CosyVoice supported") if self.is_cosyvoice: # 创建 CosyVoice 回调 self.callback = self.Callback_Cosy( data_callback=on_data, completion_event=completion_event) else: self.callback = self.Callback( data_callback=on_data, completion_event=completion_event) if self.is_cosyvoice: format_val = self.get_audio_format(self.format, self.sample_rate) # logging.info(f"Qwen setup_tts {self.voice} {format_val}") self.synthesizer = CosySpeechSynthesizer( model=self.cosyvoice, #'cosyvoice-v1', voice=self.voice, # voice="longyuan", #"longfei", "longyuan_2" callback=self.callback, format=format_val ) return self.synthesizer def apply_phoneme_tags(self, text: str) -> str: """ 在文本中查找特殊字符并用标签包裹它们 """ """ SSML(Speech Synthesis Markup Language) 是一种基于 XML 的语音合成标记语言。 它不仅能让语音合成大模型读出更丰富的文本内容,还支持对语速、语调、停顿、音量等语音特征进行精细控制, 甚至可以添加背景音乐,带来更具表现力的语音效果。本文介绍CosyVoice的SSML功能及使用。 仅限cosyvoice-v2模型 """ # 如果文本已经是SSML格式,直接返回 if text.strip().startswith("") and text.strip().endswith(""): return text # 为特殊字符添加SSML标签 for char, pinyin in self.special_characters.items(): # 使用正则表达式确保只替换整个字符(避免部分匹配) pattern = r'([^<]|^)' + re.escape(char) + r'([^>]|$)' replacement = r'\1' + char + r'\2' text = re.sub(pattern, replacement, text) # 如果文本中已有SSML标签,直接返回 if "" in text: return text # 否则包裹在标签中 return f"{text}" def text_tts_call(self, text): # SSML(Speech Synthesis Markup Language) 是一种基于 XML 的语音合成标记语言。 # 仅限cosyvoice-v2模型 if self.special_characters and self.is_cosyvoice_v2 : text = self.apply_phoneme_tags(text) #logging.info(f"Applied SSML phoneme tags to text: {text}") volume = 50 if self.sample_rate < 10000: volume = 70 if self.synthesizer and self.is_cosyvoice: #logging.info(f"Qwen text_tts_call {text} {self.cosyvoice} {self.voice}") format_val = self.get_audio_format(self.format, self.sample_rate) self.synthesizer = CosySpeechSynthesizer( model = self.cosyvoice, #'cosyvoice-v1', voice = self.voice, # voice="longyuan", #"longfei", callback=self.callback, format=format_val ) self.synthesizer.call(text) if self.is_cosyvoice is False: logging.info(f"Qwen text_tts_call {text}") TTSSpeechSynthesizer.call(model=self.model_name, text=text, callback=self.callback, format=self.format) def streaming_call(self, text): if self.synthesizer: self.synthesizer.streaming_call(text) def end_streaming_call(self): if self.synthesizer: # logging.info(f"---dale end_streaming_call") self.synthesizer.streaming_complete() def cleanup(self): pass def get_audio_format(self, format: str, sample_rate: int): """动态获取音频格式""" from dashscope.audio.tts_v2 import AudioFormat logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}") format_map = { (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS, (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT, (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT, (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT, (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS, (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT, (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT, (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS, (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT, (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT, (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS, (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT, (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT } return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS) class DoubaoTTS: def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"): logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}") # 解析豆包认证信息 (appid, token, cluster, voice_type) try: self.appid = "7282190702" self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9" self.cluster = "volcano_tts" self.voice_type ="zh_female_qingxinnvsheng_mars_bigtts" # "zh_male_jieshuonansheng_mars_bigtts" #"zh_male_ruyaqingnian_mars_bigtts" #"zh_male_jieshuonansheng_mars_bigtts" except Exception as e: raise ValueError(f"Invalid Doubao key format: {str(e)}") self.format = format self.sample_rate = sample_rate self.model_name = model_name self.callback = None self.ws = None self.loop = None self.task = None self.event = threading.Event() self.data_queue = deque() self.host = "openspeech.bytedance.com" self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary" self.default_header = bytearray(b'\x11\x10\x11\x00') self.total_data_size = 0 self.completion_event = None # 新增:用于通知任务完成 class Callback: def __init__(self, data_callback=None,completion_event=None): self.data_callback = data_callback self.data_queue = deque() self.completion_event = completion_event # 完成事件引用 def on_data(self, data): if self.data_callback: self.data_callback(data) else: self.data_queue.append(data) # 通知任务完成 if self.completion_event: self.completion_event.set() def on_complete(self): if self.data_callback: self.data_callback(None) def on_error(self, error): if self.data_callback: self.data_callback(f"ERROR:{error}".encode()) def setup_tts(self, on_data,completion_event): """设置回调,返回自身(因为豆包需要异步启动)""" self.callback = self.Callback( data_callback=on_data, completion_event=completion_event ) return self def text_tts_call(self, text): """同步调用,启动异步任务并等待完成""" self.total_data_size = 0 self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.task = self.loop.create_task(self._async_tts(text)) try: self.loop.run_until_complete(self.task) except Exception as e: logging.error(f"DoubaoTTS--0 call error: {e}") self.callback.on_error(str(e)) async def _async_tts(self, text): """异步执行TTS请求""" header = {"Authorization": f"Bearer; {self.token}"} request_json = { "app": { "appid": self.appid, "token": "access_token", # 固定值 "cluster": self.cluster }, "user": { "uid": str(uuid.uuid4()) # 随机用户ID }, "audio": { "voice_type": self.voice_type, "encoding": self.format, "speed_ratio": 1.0, "volume_ratio": 1.0, "pitch_ratio": 1.0, }, "request": { "reqid": str(uuid.uuid4()), "text": text, "text_type": "plain", "operation": "submit" # 使用submit模式支持流式 } } # 构建请求数据 payload_bytes = str.encode(json.dumps(request_json)) payload_bytes = gzip.compress(payload_bytes) full_client_request = bytearray(self.default_header) full_client_request.extend(len(payload_bytes).to_bytes(4, 'big')) full_client_request.extend(payload_bytes) try: async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws: self.ws = ws await ws.send(full_client_request) # 接收音频数据 while True: res = await ws.recv() done = self._parse_response(res) if done: self.callback.on_complete() break except Exception as e: logging.error(f"DoubaoTTS--1 WebSocket error: {e}") self.callback.on_error(str(e)) finally: # 通知任务完成 if self.completion_event: self.completion_event.set() def _parse_response(self, res): """解析豆包返回的二进制响应""" # 协议头解析 (4字节) header_size = res[0] & 0x0f message_type = res[1] >> 4 payload = res[header_size * 4:] # 音频数据响应 if message_type == 0xb: # audio-only server response message_flags = res[1] & 0x0f # ACK消息,忽略 if message_flags == 0: return False # 音频数据消息 sequence_number = int.from_bytes(payload[:4], "big", signed=True) payload_size = int.from_bytes(payload[4:8], "big", signed=False) audio_data = payload[8:8 + payload_size] if audio_data: self.total_data_size = self.total_data_size + len(audio_data) self.callback.on_data(audio_data) #logging.info(f"doubao _parse_response: {sequence_number} {len(audio_data)} {self.total_data_size}") # 序列号为负表示结束 return sequence_number < 0 # 错误响应 elif message_type == 0xf: code = int.from_bytes(payload[:4], "big", signed=False) msg_size = int.from_bytes(payload[4:8], "big", signed=False) error_msg = payload[8:8 + msg_size] try: # 尝试解压错误消息 error_msg = gzip.decompress(error_msg).decode() except: error_msg = error_msg.decode(errors='ignore') logging.error(f"DoubaoTTS error: {error_msg}") self.callback.on_error(error_msg) return False return False class UnifiedTTSEngine: def __init__(self): self.lock = threading.Lock() self.tasks = {} self.executor = ThreadPoolExecutor(max_workers=10) self.cache_expire = 300 # 5分钟缓存 # 启动清理过期任务的定时器 self.cleanup_timer = None self.start_cleanup_timer() def _cleanup_old_tasks(self): """清理过期任务""" now = time.time() with self.lock: expired_ids = [task_id for task_id, task in self.tasks.items() if now - task['created_at'] > self.cache_expire] for task_id in expired_ids: self._remove_task(task_id) def _remove_task(self, task_id): """移除任务""" if task_id in self.tasks: task = self.tasks.pop(task_id) # 取消可能的后台任务 if 'future' in task and not task['future'].done(): task['future'].cancel() # 其他资源在任务被移除后会被垃圾回收 # 资源释放机制总结: # 移除任务引用:self.tasks.pop() 解除任务对象引用,触发垃圾回收。 # 取消后台线程:future.cancel() 终止未完成线程,释放线程资源。 # 自动内存回收:Python GC 回收任务对象及其队列、缓冲区占用的内存。 # 线程池管理:执行器自动回收线程至池中,避免资源泄漏。 def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False): """创建TTS任务(同步方法)""" self._cleanup_old_tasks() audio_stream_id = str(uuid.uuid4()) # 创建任务数据结构 task_data = { 'id': audio_stream_id, 'text': text, 'format': format, 'sample_rate': sample_rate, 'model_name': model_name, 'key': key, 'delay_gen_audio': delay_gen_audio, 'created_at': time.time(), 'status': 'pending', 'data_queue': deque(), 'event': threading.Event(), 'completed': False, 'error': None } with self.lock: self.tasks[audio_stream_id] = task_data # 如果不是延迟模式,立即启动任务 if not delay_gen_audio: self._start_tts_task(audio_stream_id) return audio_stream_id def _start_tts_task(self, audio_stream_id): # 启动TTS任务(后台线程) task = self.tasks.get(audio_stream_id) if not task or task['status'] != 'pending': return logging.info("已经启动 start tts task {audio_stream_id}") task['status'] = 'processing' # 在后台线程中执行TTS future = self.executor.submit(self._run_tts_sync, audio_stream_id) task['future'] = future # 如果需要等待任务完成 if not task.get('delay_gen_audio', True): try: # 等待任务完成(最多5分钟) future.result(timeout=300) logging.info(f"TTS任务 {audio_stream_id} 已完成") self._merge_audio_data(audio_stream_id) except concurrent.futures.TimeoutError: task['error'] = "TTS生成超时" task['completed'] = True logging.error(f"TTS任务 {audio_stream_id} 超时") except Exception as e: task['error'] = f"ERROR:{str(e)}" task['completed'] = True logging.exception(f"TTS任务执行异常: {str(e)}") def _run_tts_sync(self, audio_stream_id): # 同步执行TTS生成 在后台线程中执行 task = self.tasks.get(audio_stream_id) if not task: return try: # 创建完成事件 completion_event = threading.Event() # 创建TTS实例 # 根据model_name选择TTS引擎 # 前端传入 cosyvoice-v1/longhua@Tongyi-Qianwen model_name_wo_brand = task['model_name'].split('@')[0] model_name_version = model_name_wo_brand.split('/')[0] if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']: # 豆包TTS tts = DoubaoTTS( key=task['key'], format=task['format'], sample_rate=task['sample_rate'], model_name=task['model_name'] ) else: # 通义千问TTS tts = QwenTTS( key=task['key'], format=task['format'], sample_rate=task['sample_rate'], model_name=task['model_name'] ) # 定义同步数据处理函数 def data_handler(data): if data is None: # 结束信号 task['completed'] = True task['event'].set() logging.info(f"--data_handler on_complete") elif data.startswith(b"ERROR"): # 错误信号 task['error'] = data.decode() task['completed'] = True task['event'].set() else: # 音频数据 task['data_queue'].append(data) # 设置并执行TTS synthesizer = tts.setup_tts(data_handler,completion_event) #synthesizer.call(task['text']) tts.text_tts_call(task['text']) # 等待完成或超时 # 等待完成或超时 if not completion_event.wait(timeout=300): # 5分钟超时 task['error'] = "TTS generation timeout" task['completed'] = True logging.info(f"--tts task event set error = {task['error']}") except Exception as e: logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}") task['error'] = f"ERROR:{str(e)}" task['completed'] = True finally: # 确保清理TTS资源 logging.info("UnifiedTTSEngine _run_tts_sync finally") if hasattr(tts, 'loop') and tts.loop: tts.loop.close() def _merge_audio_data(self, audio_stream_id): """将任务的所有音频数据合并到ByteIO缓冲区""" task = self.tasks.get(audio_stream_id) if not task or not task.get('completed'): return try: logging.info(f"开始合并音频数据: {audio_stream_id}") # 创建内存缓冲区 buffer = io.BytesIO() # 合并所有数据块 for data_chunk in task['data_queue']: buffer.write(data_chunk) # 重置指针位置以便读取 buffer.seek(0) # 保存到任务对象 task['buffer'] = buffer logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节") # 可选:清理原始数据队列以节省内存 task['data_queue'].clear() except Exception as e: logging.error(f"合并音频数据失败: {str(e)}") task['error'] = f"合并错误: {str(e)}" async def get_audio_stream(self, audio_stream_id): """获取音频流(异步生成器)""" task = self.tasks.get(audio_stream_id) if not task: raise RuntimeError("Audio stream not found") # 如果是延迟任务且未启动,现在启动 status 为 pending if task['delay_gen_audio'] and task['status'] == 'pending': self._start_tts_task(audio_stream_id) total_audio_data_size = 0 # 等待任务启动 while task['status'] == 'pending': await asyncio.sleep(0.1) # 流式返回数据 while not task['completed'] or task['data_queue']: while task['data_queue']: data = task['data_queue'].popleft() total_audio_data_size += len(data) #logging.info(f"yield audio data {len(data)} {total_audio_data_size}") yield data # 短暂等待新数据 await asyncio.sleep(0.05) # 检查错误 if task['error']: raise RuntimeError(task['error']) def start_cleanup_timer(self): """启动定时清理任务""" if self.cleanup_timer: self.cleanup_timer.cancel() self.cleanup_timer = threading.Timer(30.0, self.cleanup_task) # 每30秒清理一次 self.cleanup_timer.daemon = True # 设置为守护线程 self.cleanup_timer.start() def cleanup_task(self): """执行清理任务""" try: self._cleanup_old_tasks() except Exception as e: logging.error(f"清理任务时出错: {str(e)}") finally: self.start_cleanup_timer() # 重新启动定时器 # 全局 TTS 引擎实例 tts_engine = UnifiedTTSEngine() def replace_domain(url: str) -> str: """替换URL中的域名为本地地址,不使用urllib.parse""" # 定义需要替换的域名列表 domains_to_replace = [ "http://1.13.185.116:9380", "https://ragflow.szzysztech.com", "1.13.185.116:9380", "ragflow.szzysztech.com" ] # 尝试替换每个可能的域名 for domain in domains_to_replace: if domain in url: # 直接替换域名部分 return url.replace(domain, "http://localhost:9380", 1) # 如果未匹配到特定域名,尝试智能替换 if "://" in url: # 分割协议和路径 protocol, path = url.split("://", 1) # 查找第一个斜杠位置来确定域名结束位置 slash_pos = path.find("/") if slash_pos > 0: # 替换域名部分 return f"http://localhost:9380{path[slash_pos:]}" else: # 没有路径部分,直接返回本地地址 return "http://localhost:9380" else: # 没有协议部分,直接添加本地地址 return f"http://localhost:9380/{url}" async def proxy_aichat_audio_stream( client_id: str, audio_stream_id: str, combined_state: dict = None # 新增组合状态参数 ): try: stream_session = stream_manager.get_session(audio_stream_id) if not stream_session: logging.warning(f"Audio session not found: {audio_stream_id}") return # 注册当前连接的音频会话 manager.register_audio_session(client_id, audio_stream_id) sample_rate = stream_session.get('sample_rate') audio_data_size = 0 # 发送采样率 await manager.send_json(client_id, {"command": "sample_rate", "params": sample_rate}) # 处理音频流 async for data in stream_manager.get_tts_buffer_data(audio_stream_id): # 检查websocket 连接是否仍然活跃 if not manager.is_connected(client_id): logging.info(f"audio websocket 连接已断开,停止音频生成: {client_id}") # 通知会话管理器停止生成 stream_manager.stop_session(audio_stream_id) break if isinstance(data, dict): if data.get('type') == 'sentence_end': await manager.send_json(client_id, {"command": "sentence_end"}) elif data.get('type') == 'arraybuffer': audio_data = data.get('data') audio_data_size += len(audio_data) if not await manager.send_bytes(client_id, audio_data): break logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}") # 组合模式通知音频流结束 if combined_state: combined_state["audio_completed"].set() except Exception as e: logging.error(f"音频流处理失败: {str(e)}") # 仅在连接活跃时发送错误 if manager.is_connected(client_id): await manager.send_text(client_id, json.dumps({ "type": "error", "message": f"音频流错误: {str(e)}" })) finally: # 确保取消注册 if client_id in manager.aichat_audio_sessions: del manager.aichat_audio_sessions[client_id] # 代理函数 - 文本流 # 在微信小程序中,原来APK使用的SSE机制不能正常工作,需要使用WebSocket async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict, combined_state: dict = None): """代理大模型文本流请求 - 兼容现有Flask实现""" try: logging.info(f"代理文本流: completions_url={completions_url} {payload}") logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}") headers = { "Content-Type": "application/json", 'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm' } tts_model = None tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen') #if 'longyuan' in tts_model_name: # tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen" #logging.info(f"---create tts_stream_session_id = {tts_stream_session_id}") tts_stream_session_id_sent = False # 使用更长的超时时间 (5分钟) timeout = httpx.Timeout(30.0, connect=20.0,read=20, write=20) async with httpx.AsyncClient(timeout=timeout) as client: # 关键修改:使用流式请求模式 async with client.stream( # <-- 使用stream方法 "POST", completions_url, json=payload, headers=headers ) as response: logging.info(f"响应状态: HTTP {response.status_code}") if response.status_code != 200: # 读取错误信息(非流式) error_content = await response.aread() error_msg = f"后端错误: HTTP {response.status_code}" error_msg += f" - {error_content[:200].decode()}" if error_content else "" await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg})) return # 验证SSE流 content_type = response.headers.get("content-type", "").lower() if "text/event-stream" not in content_type: logging.warning("非流式响应,转发完整内容") full_content = await response.aread() await manager.send_text(client_id, json.dumps({ "type": "text", "data": full_content.decode('utf-8') })) return if tts_model is None: # 创建TTS实例 tts_model = QwenTTS( key=ALI_KEY, format=payload.get('tts_stream_format', 'mp3'), sample_rate=payload.get('tts_sample_rate', 48000), model_name=tts_model_name ) streaming_call = False if tts_model.is_cosyvoice: streaming_call = True # 创建流会话 tts_stream_session_id = stream_manager.create_session( tts_model=tts_model, sample_rate=payload.get('tts_sample_rate', 48000), stream_format=payload.get('tts_stream_format', 'mp3'), session_id=None, streaming_call=streaming_call ) # 关键修改:设置TTS会话ID并触发就绪事件 if combined_state is not None and tts_stream_session_id: combined_state["tts_session_id"] = tts_stream_session_id combined_state["tts_ready_event"].set() # 触发事件通知主流程 logging.info(f"开始处理SSE流 {tts_stream_session_id}") event_count = 0 # 使用异步迭代器逐行处理 async for line in response.aiter_lines(): # 跳过空行和注释行 if not line or line.startswith(':'): continue # 处理SSE事件 if line.startswith("data:"): data_str = line[5:].strip() if data_str: # 过滤空数据 try: # 解析并提取增量文本 data_obj = json.loads(data_str) delta_text = None if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict): delta_text = data_obj.get('data', None).get('delta_ans', "") if tts_stream_session_id_sent is False: if tts_stream_session_id and combined_state is None: logging.info(f"--proxy_aichat_text_stream 发送audio_stream_url") data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}" data_str = json.dumps(data_obj) tts_stream_session_id_sent = True # 直接转发原始数据 await manager.send_text(client_id, json.dumps({ "type": "text", "data": data_str })) # 这里构建{"type":"text",'data':"data_str"}) 是为了前端websocket进行数据解析 if delta_text and tts_stream_session_id: # 追加到会话管理器 stream_manager.append_text(tts_stream_session_id, delta_text) # logging.info(f"文本代理转发: {data_str}") event_count += 1 except Exception as e: logging.error(f"事件发送失败: {str(e)}") # 保持连接活性 await asyncio.sleep(0.001) # 避免CPU空转 #logging.info(f"SSE流处理完成,事件数: {event_count}") # 发送文本流结束信号 await manager.send_text(client_id, json.dumps({"type": "AiChatTextEnd"})) # 标记文本输入结束 if tts_stream_session_id and stream_manager.finish_text_input: stream_manager.finish_text_input(tts_stream_session_id) except httpx.ConnectTimeout: logging.error("后端服务超时") await manager.send_text(client_id, json.dumps({ "type": "error", "message": "后端服务响应超时" })) except httpx.ConnectError as e: logging.error(f"连接后端服务失败: {str(e)}") await manager.send_text(client_id, json.dumps({ "type": "error", "message": f"无法连接到后端服务: {str(e)}" })) except Exception as e: logging.exception(f"文本代理失败: {str(e)}") await manager.send_text(client_id, json.dumps({ "type": "error", "message": f"文本流获取失败: {str(e)}" })) finally: pass # 记录连接池状态 #pool_status = await HTTPXConnectionPool.get_pool_status() #logging.debug(f"连接池状态: {json.dumps(pool_status)}") @tts_router.get("/audio/pcm_mp3") async def stream_mp3(): def audio_generator(): path = './test.mp3' try: with open(path, 'rb') as f: while True: chunk = f.read(1024) if not chunk: break yield chunk except Exception as e: logging.error(f"MP3 streaming error: {str(e)}") return StreamingResponse( audio_generator(), media_type="audio/mpeg", headers={ "Cache-Control": "no-store", "Accept-Ranges": "bytes" } ) def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes: """动态生成WAV头(严格保持原有逻辑结构)""" with BytesIO() as wav_buffer: with wave.open(wav_buffer, 'wb') as wav_file: wav_file.setnchannels(1) # 保持原单声道设置 wav_file.setsampwidth(2) # 保持原16-bit设置 wav_file.setframerate(sample_rate) wav_file.writeframes(pcm_data) wav_buffer.seek(0) return wav_buffer.read() def generate_silence_header(duration_ms: int = 500) -> bytes: """生成静音数据(用于MP3流式传输预缓冲)""" num_samples = int(TTS_SAMPLERATE * duration_ms / 1000) return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS # ------------------------ API路由 ------------------------ @tts_router.post("/chats/{chat_id}/tts") async def create_tts_request(chat_id: str, request: Request): try: data = await request.json() logging.info(f"Creating TTS request: {data}") # 参数校验 text = data.get("text", "").strip() if not text: raise HTTPException(400, detail="Text cannot be empty") format = data.get("tts_stream_format", "mp3") if format not in ["mp3", "wav", "pcm"]: raise HTTPException(400, detail="Unsupported audio format") sample_rate = data.get("tts_sample_rate", 48000) if sample_rate not in [8000, 16000, 22050, 44100, 48000]: raise HTTPException(400, detail="Unsupported sample rate") model_name = data.get("model_name", "cosyvoice-v1/longxiaochun") delay_gen_audio = data.get('delay_gen_audio', False) # 创建TTS任务 audio_stream_id = tts_engine.create_tts_task( text=text, format=format, sample_rate=sample_rate, model_name=model_name, key=ALI_KEY, delay_gen_audio=delay_gen_audio ) return JSONResponse( status_code=200, content={ "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}", "url": f"/chats/{chat_id}/tts/{audio_stream_id}", "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}", # WebSocket URL 2025 0622新增 "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat() } ) except Exception as e: logging.error(f"Request failed: {str(e)}") raise HTTPException(500, detail="Internal server error") @tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}") async def get_tts_audio( chat_id: str, audio_stream_id: str, range: str = Header(None) ): try: # 获取任务信息 task = tts_engine.tasks.get(audio_stream_id) if not task: # 返回友好的错误信息而不是抛出异常 return JSONResponse( status_code=404, content={ "error": "Audio stream not found", "message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired", "suggestion": "Please create a new TTS request and try again" } ) # 获取媒体类型 format = task['format'] media_type = { "mp3": "audio/mpeg", "wav": "audio/wav", "pcm": f"audio/L16; rate={task['sample_rate']}; channels=1" }[format] # 如果任务已完成且有完整缓冲区,处理Range请求 logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}") # 创建响应内容生成器 def buffer_read(buffer): content_length = buffer.getbuffer().nbytes remaining = content_length chunk_size = 4096 buffer.seek(0) while remaining > 0: read_size = min(remaining, chunk_size) data = buffer.read(read_size) if not data: break yield data remaining -= len(data) if task.get('completed') and task.get('buffer') is not None: buffer = task['buffer'] total_size = buffer.getbuffer().nbytes # 强制小文件使用流式传输(避免206响应问题) if total_size < 1024 * 120: # 小于300KB range = None if range: # 处理范围请求 return handle_range_request(range, buffer, total_size, media_type) else: return StreamingResponse( buffer_read(buffer), media_type=media_type, headers={ "Accept-Ranges": "bytes", "Cache-Control": "no-store", "Transfer-Encoding": "chunked" }) # 创建流式响应 logging.info("tts_engine.get_audio_stream--0") return StreamingResponse( tts_engine.get_audio_stream(audio_stream_id), media_type=media_type, headers={ "Accept-Ranges": "bytes", "Cache-Control": "no-store", "Transfer-Encoding": "chunked" } ) except Exception as e: logging.error(f"Audio streaming failed: {str(e)}") raise HTTPException(500, detail="Audio generation error") def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str): """处理 HTTP Range 请求""" try: # 解析 Range 头部 (示例: "bytes=0-1023") range_type, range_spec = range_header.split('=') if range_type != 'bytes': raise ValueError("Unsupported range type") start_str, end_str = range_spec.split('-') start = int(start_str) end = int(end_str) if end_str else total_size - 1 logging.info(f"handle_range_request--1 {start_str}-{end_str} {end}") # 验证范围有效性 if start >= total_size or end >= total_size: raise HTTPException(status_code=416, headers={ "Content-Range": f"bytes */{total_size}" }) # 计算内容长度 content_length = end - start + 1 # 设置状态码 status_code = 206 # Partial Content if start == 0 and end == total_size - 1: status_code = 200 # Full Content # 设置流读取位置 buffer.seek(start) # 创建响应内容生成器 def content_generator(): remaining = content_length chunk_size = 4096 while remaining > 0: read_size = min(remaining, chunk_size) data = buffer.read(read_size) if not data: break yield data remaining -= len(data) # 返回分块响应 return StreamingResponse( content_generator(), status_code=status_code, media_type=media_type, headers={ "Content-Range": f"bytes {start}-{end}/{total_size}", "Content-Length": str(content_length), "Accept-Ranges": "bytes", "Cache-Control": "public, max-age=3600" } ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}") async def websocket_tts_endpoint( websocket: WebSocket, chat_id: str, audio_stream_id: str ): # 接收 header 参数 headers = websocket.headers service_type = headers.get("x-tts-type") # 注意:header 名称转为小写 #给H5代码特殊处理,H5代码中,x-tts-type的header不能工作 # 浏览器内置WebSocket 的连接时不能附加额外的header传递参数 if audio_stream_id == 'x-tts-type-is-TextToTts': service_type = 'TextToTts' # audio_url = headers.get("x-audio-url") """ 前端示例 websocketConnection = uni.connectSocket({ url: url, header: { 'Authorization': token, 'X-Tts-Type': 'AiChat', //'Ask' // 自定义参数1 'X-Device-Type': 'mobile', // 自定义参数2 'X-User-ID': '12345' // 自定义参数3 }, success: () => { console.log('WebSocket connected'); }, fail: (err) => { console.error('WebSocket connection failed:', err); } }); """ # 创建唯一连接 ID connection_id = str(uuid.uuid4()) # logging.info(f"---dale-- websocket connection_id = {connection_id} chat_id={chat_id}") await manager.connect(websocket, connection_id) completed_successfully = False try: # 根据tts_type路由到不同的音频源 if service_type == "AiChatTts": # 音频代理服务 await proxy_aichat_audio_stream(connection_id, audio_stream_id, combined_state = None) completed_successfully = True elif service_type == "AiChatText": # 文本代理服务 # 等待客户端发送初始请求数据 进行大模型对话代理时,需要前端连接后发送payload payload = await websocket.receive_json() # 在代理前检查连接池状态 completions_url = f"http://127.0.0.1:9380/api/v1/chats/{chat_id}/completions" await proxy_aichat_text_stream(connection_id, completions_url, payload, combined_state = None) completed_successfully = True elif service_type == "AiChatCombined": # 接收初始请求数据 payload = await websocket.receive_json() # 创建共享状态和同步事件 combined_state = { "tts_session_id": None, "tts_ready_event": asyncio.Event(), # TTS准备就绪事件 "audio_task": None, "text_completed": asyncio.Event(), "audio_completed": asyncio.Event() } # 启动文本流任务 text_task = asyncio.create_task( proxy_aichat_text_stream( client_id=connection_id, completions_url=f"http://127.0.0.1:9380/api/v1/chats/{chat_id}/completions", payload=payload, combined_state=combined_state ) ) try: # 等待TTS会话ID准备就绪(最多等待8秒) await asyncio.wait_for(combined_state["tts_ready_event"].wait(), timeout=8.0) if combined_state["tts_session_id"]: # 启动音频流任务 combined_state["audio_task"] = asyncio.create_task( proxy_aichat_audio_stream( client_id=connection_id, audio_stream_id=combined_state["tts_session_id"], combined_state=combined_state ) ) else: logging.warning("TTS会话ID未生成,跳过音频流任务") except asyncio.TimeoutError: logging.warning("等待TTS会话ID超时,跳过音频流任务") # 等待两个任务完成(如果音频任务未启动,text_task会正常完成) tasks = [text_task] if combined_state["audio_task"]: tasks.append(combined_state["audio_task"]) done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) # 检查任务状态并处理异常 for task in done: if task.exception(): logging.error(f"任务异常: {task.exception()}") # 发送错误消息给客户端 await manager.send_text(connection_id, json.dumps({ "type": "error", "message": str(task.exception()) })) # 取消任何未完成的任务 for task in pending: task.cancel() # 发送完成信号 if manager.is_connected(connection_id): await manager.send_text(connection_id, json.dumps({"type": "end"})) logging.info(" websocket_tts_endpoint AiChatCombined completed successfully") elif service_type == "TextToTts": # 前端将文本发送到后端,后端调用TTS引擎生成音频流 ,并且将生成音频的文本、生成音频的参数 # 返回音频,在1个websocket调用中完成 params_valid = True payload = await websocket.receive_json() # 参数校验 text = payload.get("text", "").strip() if not text: params_valid = False data = payload.get("params", {}) logging.info(f"websocket_tts_endpoint TextToTts:{text} {data}") format = data.get("tts_stream_format", "mp3") if format not in ["mp3", "wav", "pcm"]: params_valid = False sample_rate = data.get("tts_sample_rate", 48000) if sample_rate not in [8000, 16000, 22050, 44100, 48000]: params_valid = False model_name = data.get("model_name", "cosyvoice-v1/longxiaochun") delay_gen_audio = data.get('delay_gen_audio', False) if params_valid: # 创建TTS任务 audio_stream_id = tts_engine.create_tts_task( text=text, format=format, sample_rate=sample_rate, model_name=model_name, key=ALI_KEY, delay_gen_audio=delay_gen_audio ) # 使用引擎的生成器直接获取音频流 audio_data_size = 0 async for data in tts_engine.get_audio_stream(audio_stream_id): audio_data_size += len(data) if not await manager.send_bytes(connection_id, data): logging.warning(f"Send failed, connection closed: {connection_id}") break await manager.send_json(connection_id, {"command": "sentence_end"}) logging.info(f"websocket_tts_endpoint TextToTts completed successfully {audio_data_size} bytes") else: # 使用引擎的生成器直接获取音频流 async for data in tts_engine.get_audio_stream(audio_stream_id): if not await manager.send_bytes(connection_id, data): logging.warning(f"Send failed, connection closed: {connection_id}") break await manager.send_json(connection_id, {"command": "sentence_end"}) completed_successfully = True # 发送完成信号前检查连接状态 if manager.is_connected(connection_id): # 发送完成信号 await manager.send_json(connection_id, {"status": "completed"}) # 添加短暂延迟确保消息送达 await asyncio.sleep(0.1) # 主动关闭WebSocket连接 await manager.disconnect(connection_id, code=1000, reason="Audio stream completed") except WebSocketDisconnect: logging.info(f"websocket_tts_endpoint WebSocket disconnected: {connection_id}") except Exception as e: logging.error(f"websocket_tts_endpoint WebSocket TTS error: {str(e)}") if manager.is_connected(connection_id): await manager.send_json(connection_id, {"error": str(e)}) finally: pass # await manager.disconnect(connection_id) @tts_router.get("/debug/get_threadpool") async def get_threadpool(request: Request): params = dict(request.query_params) if params.get('reset'): stream_manager.reset_manager() return JSONResponse( status_code=200, content={ "status":stream_manager.get_self_thread_pool_status() } )