import logging
import binascii
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta
import threading, time, queue, uuid, array
from threading import Lock, Thread
from concurrent.futures import ThreadPoolExecutor
import concurrent.futures  # used below as concurrent.futures.TimeoutError
import base64, gzip
import os, io, re, json
import wave  # used by add_wav_header()
from io import BytesIO
from typing import Optional, Dict, Any
import asyncio, httpx
from collections import deque
import websockets
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
from fastapi import FastAPI, UploadFile, File, Form, Header, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse, Response
from dashscope import Generation  # used by test_qwen_chat()

TTS_SAMPLERATE = 44100  # 22050 # 16000
FORMAT = "mp3"
CHANNELS = 1       # mono
SAMPLE_WIDTH = 2   # 16-bit = 2 bytes

tts_router = APIRouter()
# logger = logging.getLogger(__name__)


class MillisecondsFormatter(logging.Formatter):
    """Custom log formatter that appends a millisecond timestamp."""

    def formatTime(self, record, datefmt=None):
        # Convert the record timestamp to a local time tuple
        ct = self.converter(record.created)
        # Format as "hour:minute:second"
        t = time.strftime("%H:%M:%S", ct)
        # Append milliseconds (3 digits)
        return f"{t}.{int(record.msecs):03d}"


# Configure the global logging format
def configure_logging():
    # Create the formatter
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    formatter = MillisecondsFormatter(log_format)
    # Get the root logger and clear any existing handlers
    root_logger = logging.getLogger()
    root_logger.handlers = []
    # Create and configure a handler (console output)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    # Set the log level and attach the handler
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)


# Run the configuration once at program startup
configure_logging()


class StreamSessionManager:
    def __init__(self):
        self.sessions = {}  # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
        self.gc_interval = 300            # clean up every 5 minutes
        self.streaming_call_timeout = 10  # 10 s
        self.gc_tts = 3                   # 3 s
        self.sentence_timeout = 2         # 2000 ms sentence timeout
        self.sentence_endings = set('。?!;.?!;')  # Chinese/English sentence terminators
        # Extended regex: matches Chinese/English sentence terminators (including full-width forms)
        self.sentence_pattern = re.compile(
            r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])'
        )
        self.sentence_audio_store = {}  # {sentence_id: {'data': bytes, 'text': str, 'created_at': float}}
        self.sentence_lock = threading.Lock()
        threading.Thread(target=self._cleanup_expired, daemon=True).start()

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3',
                       session_id=None, streaming_call=False):
        if not session_id:
            session_id = str(uuid.uuid4())
        with self.lock:
            # Use the provided TTS instance and attach the streaming callback
            tts_instance = tts_model

            # Audio data callback
            def on_data(data: bytes):
                if data:
                    try:
                        self.sessions[session_id]['last_active'] = time.time()
                        self.sessions[session_id]['buffer'].put(data)
                        self.sessions[session_id]['audio_chunk_size'] += len(data)
                        #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
                    except queue.Full:
                        logging.warning(f"Audio buffer full for session {session_id}")
                """
                elif data is None:  # end-of-stream signal
                    # Only trigger the completion event for non-streaming engines
                    if not streaming_call:
                        logging.info(f"StreamSessionManager on_data sentence_complete_event set")
                        self.sessions[session_id]['sentence_complete_event'].set()
                        self.sessions[session_id]['current_processing'] = False
                """

            # Completion event
            completion_event = threading.Event()
            # Wire up TTS streaming
            tts_instance.setup_tts(on_data, completion_event)
            # Create the session record
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'audio_chunk_size': 0,
'finished': threading.Event(), # 添加事件对象 'sample_rate': sample_rate, 'stream_format': stream_format, "tts_chunk_data_valid": False, "text_buffer": "", # 新增文本缓冲区 "last_text_time": time.time(), # 最后文本到达时间 "streaming_call": streaming_call, "tts_stream_started": False, # 标记是否已启动流 "current_processing": False, # 标记是否正在处理句子 "sentence_complete_event": completion_event, #threading.Event(), 'sentences': [], # 存储句子ID列表 'current_sentence_index': 0, 'gc_interval':300, # 5分钟清理一次 } # 启动任务处理线程 threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start() return session_id def append_text(self, session_id, text): with self.lock: session = self.sessions.get(session_id) if not session: return # 更新文本缓冲区和时间戳 session['text_buffer'] += text session['last_text_time'] = time.time() # 将文本放入任务队列(非阻塞) try: session['task_queue'].put(text, block=False) except queue.Full: logging.warning(f"Session {session_id} task queue full") def finish_text_input(self, session_id): """标记文本输入结束,通知任务处理线程""" with self.lock: session = self.sessions.get(session_id) if not session: return session['gc_interval'] = 100 # 所有的文本输入已经结束,可以将超时检查时间缩短 def _process_tasks(self, session_id): # 20250718 新更新 """任务处理线程(每个会话独立)- 保留原有处理逻辑""" session = self.sessions.get(session_id) if not session or not session['active']: return # 根据引擎类型选择处理函数 if session.get('streaming_call'): gen_tts_audio_func = self._generate_audio #self._stream_audio else: gen_tts_audio_func = self._generate_audio while session['active']: current_time = time.time() text_to_process = "" # 1. 获取待处理文本 with self.lock: if session['text_buffer']: text_to_process = session['text_buffer'] # 2. 处理文本 if text_to_process and not session['current_processing'] : session['text_buffer'] = "" # 分割完整句子 complete_sentences, remaining_text = self._split_and_extract(text_to_process) # 保存剩余文本 if remaining_text: with self.lock: session['text_buffer'] = remaining_text + session['text_buffer'] # 合并并处理完整句子 if complete_sentences: # 智能合并句子(最长300字符) buffer = [] current_length = 0 # 处理每个句子 for sentence in complete_sentences: sent_length = len(sentence) # 添加到当前缓冲区 if current_length + sent_length <= 300: buffer.append(sentence) current_length += sent_length else: # 处理已缓冲的文本 if buffer: combined_text = "".join(buffer) # 重置完成事件状态 session['sentence_complete_event'].clear() session['current_processing'] = True # 生成音频 gen_tts_audio_func(session_id, combined_text) # 等待完成 if not session['sentence_complete_event'].wait(timeout=120.0): logging.warning(f"Timeout waiting for TTS completion: {combined_text}") # 重置处理状态 time.sleep(5.0) session['current_processing'] = False logging.info(f"StreamSessionManager _process_tasks 转换结束!!!") # 重置缓冲区 buffer = [sentence] current_length = sent_length # 处理剩余的缓冲文本 if buffer: combined_text = "".join(buffer) session['current_processing'] = True # 生成音频 gen_tts_audio_func(session_id, combined_text) #如果调用_stream_audio,则是同步调用,会阻塞,直到音频生成完成 # 重置处理状态 time.sleep(1.0) session['current_processing'] = False # 3. 检查超时未处理的文本 if current_time - session['last_text_time'] > self.sentence_timeout: with self.lock: if session['text_buffer']: # 直接处理剩余文本 session['current_processing'] = True gen_tts_audio_func(session_id, session['text_buffer']) session['text_buffer'] = "" # 重置处理状态 session['current_processing'] = False # 4. 
会话超时检查 if current_time - session['last_active'] > session['gc_interval']: # 处理剩余文本 with self.lock: if session['text_buffer']: session['current_processing'] = True # 处理最后一段文本 gen_tts_audio_func(session_id, session['text_buffer']) session['text_buffer'] = "" # 重置处理状态 session['current_processing'] = False # 关闭会话 logging.info(f"--_process_tasks-- timeout {session['last_active']} {session['gc_interval']}") self.close_session(session_id) break # 5. 休眠避免CPU空转 time.sleep(0.05) # 50ms检查间隔 def _generate_audio(self, session_id, text): # 20250718 新更新 """实际生成音频(顺序执行)- 用于非流式引擎""" session = self.sessions.get(session_id) if not session: return try: #logging.info(f"StreamSessionManager _generate_audio--0 {text}") # 创建内存流 audio_stream = io.BytesIO() # 定义回调函数:直接写入流 def on_data_sentence(data: bytes): if data: audio_stream.write(data) def on_data_whole(data: bytes): if data: try: session['last_active'] = time.time() session['buffer'].put({'type':'arraybuffer','data':data}) #session['buffer'].put(data) session['audio_chunk_size'] += len(data) #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}") except queue.Full: logging.warning(f"Audio buffer full for session {session_id}") # 重置完成事件 session['sentence_complete_event'].clear() session['tts_model'].setup_tts(on_data = on_data_whole,completion_event=session['sentence_complete_event']) # 调用 TTS session['tts_model'].text_tts_call(text) session['last_active'] = time.time() session['audio_chunk_count'] += 1 # 等待句子完成 if not session['sentence_complete_event'].wait(timeout=30): # 30秒超时 logging.warning(f"Timeout generating audio for: {text[:20]}...") logging.info(f"StreamSessionManager _generate_audio 转换结束!!!") session['buffer'].put({'type': 'sentence_end', 'data': ""}) # 获取音频数据 audio_data = audio_stream.getvalue() audio_stream.close() # 保存到句子存储 self.add_sentence_audio(session_id, text, audio_data) if not session['tts_chunk_data_valid']: session['tts_chunk_data_valid'] = True except Exception as e: session['buffer'].put(f"ERROR:{str(e)}".encode()) session['sentence_complete_event'].set() # 确保事件被设置 def _stream_audio(self, session_id, text): """流式传输文本到TTS服务""" session = self.sessions.get(session_id) if not session: return # logging.info(f"Streaming text to TTS: {text}") try: # 重置完成事件状态 session['sentence_complete_event'].clear() # 使用流式调用发送文本 session['tts_model'].streaming_call(text) session['last_active'] = time.time() # 流式引擎不需要等待完成事件 session['sentence_complete_event'].set() except Exception as e: logging.error(f"Error in streaming_call: {str(e)}") session['buffer'].put(f"ERROR:{str(e)}".encode()) session['sentence_complete_event'].set() async def get_tts_buffer_data(self, session_id): """异步流式返回 TTS 音频数据(适配同步 queue.Queue,带 10 秒超时)""" session = self.sessions.get(session_id) if not session: raise ValueError(f"Session {session_id} not found") buffer = session['buffer'] # 这里是 queue.Queue last_data_time = time.time() # 记录最后一次获取数据的时间 while session['active']: try: # 使用 run_in_executor + wait_for 设置 10 秒超时 data = await asyncio.wait_for( asyncio.get_event_loop().run_in_executor(None, buffer.get), timeout=10.0 # 10 秒超时 ) #logging.info(f"get_tts_buffer_data {data}") last_data_time = time.time() # 更新最后数据时间 yield data except asyncio.TimeoutError: # 10 秒内没有新数据,检查是否超时 if time.time() - last_data_time >= 10.0: break else: continue # 未超时,继续等待 except asyncio.CancelledError: logging.info(f"Session {session_id} stream cancelled") break except Exception as e: logging.error(f"Error in get_tts_buffer_data: {e}") break def close_session(self, 
session_id): with self.lock: if session_id in self.sessions: session = self.sessions[session_id] session['active'] = False # 清理关联的音频数据 with self.sentence_lock: # 标记会话相关的句子为过期 expired_sentences = [ sid for sid, data in self.sentence_audio_store.items() if data.get('session_id') == session_id ] # 设置完成事件 session['sentence_complete_event'].set() # 清理TTS资源 try: if session.get('streaming_call'): session['tts_model'].end_streaming_call() except: pass # 延迟清理会话 threading.Timer(1, self._clean_session, args=[session_id]).start() def _clean_session(self, session_id): with self.lock: if session_id in self.sessions: del self.sessions[session_id] def _cleanup_expired(self): """定时清理过期资源""" while True: time.sleep(30) now = time.time() # 清理过期会话 with self.lock: expired_sessions = [ sid for sid, session in self.sessions.items() if now - session['last_active'] > self.gc_interval ] for sid in expired_sessions: self._clean_session(sid) # 清理过期句子音频 with self.sentence_lock: expired_sentences = [ sid for sid, data in self.sentence_audio_store.items() if now - data['created_at'] > self.gc_interval ] for sid in expired_sentences: del self.sentence_audio_store[sid] logging.info(f"清理资源: {len(expired_sessions)}会话, {len(expired_sentences)}句子") def get_session(self, session_id): return self.sessions.get(session_id) def _has_sentence_ending(self, text): """检测文本是否包含句子结束符""" if not text: return False # 检查常见结束符(包含全角字符) if any(char in self.sentence_endings for char in text[-3:]): return True # 检查中文段落结束(换行符前有结束符) if '\n' in text and any(char in self.sentence_endings for char in text.split('\n')[-2:-1]): return True return False def _split_and_extract(self, text): """ 增强型句子分割器 返回: (完整句子列表, 剩余文本) """ # 特殊处理:如果文本以逗号开头,先处理前面的部分 if text.startswith((",", ",")): return [text[0]], text[1:] # 1. 查找所有可能的句子结束位置 matches = list(self.sentence_pattern.finditer(text)) if not matches: return [], text # 没有找到结束符 # 2. 确定最后一个完整句子的结束位置 last_end = 0 complete_sentences = [] for match in matches: end_pos = match.end() sentence = text[last_end:end_pos].strip() # 跳过空句子 if not sentence: last_end = end_pos continue # 检查是否为有效句子(最小长度或包含结束符) if (len(sentence) > 6 or any(char in "。.?!?!" for char in sentence)) and (last_end<24): complete_sentences.append(sentence) last_end = end_pos else: # 短文本但包含结束符,可能是特殊符号 if any(char in "。.?!?!" for char in sentence): complete_sentences.append(sentence) last_end = end_pos # 3. 
提取剩余文本 remaining_text = text[last_end:].strip() return complete_sentences, remaining_text def add_sentence_audio(self, session_id, sentence_text, audio_data: bytes): """添加句子音频到存储""" with self.lock: if session_id not in self.sessions: return None # 生成唯一句子ID sentence_id = str(uuid.uuid4()) # 存储音频数据 with self.sentence_lock: self.sentence_audio_store[sentence_id] = { 'data': audio_data, 'text': sentence_text, 'created_at': time.time(), 'session_id': session_id, 'format': self.sessions[session_id]['stream_format'] } logging.info(f" StreamSessionManager add_sentence_audio") # 添加到会话的句子列表 self.sessions[session_id]['sentences'].append(sentence_id) return sentence_id def get_sentence_audio(self, sentence_id): """获取句子音频数据""" with self.sentence_lock: if sentence_id not in self.sentence_audio_store: return None return self.sentence_audio_store[sentence_id]['data'] def get_sentence_info(self, sentence_id): """获取句子信息""" with self.sentence_lock: if sentence_id not in self.sentence_audio_store: return None return self.sentence_audio_store[sentence_id] def get_next_sentence(self, session_id): """获取下一个句子的信息""" with self.lock: session = self.sessions.get(session_id) if not session or not session['active']: return None if session['current_sentence_index'] < len(session['sentences']): sentence_id = session['sentences'][session['current_sentence_index']] session['current_sentence_index'] += 1 return { 'id': sentence_id, 'url': f"/tts_sentence/{sentence_id}" # 虚拟URL } return None def get_sentence_audio_data(self, session_id): with self.lock: session = self.sessions.get(session_id) if not session or not session['active']: return None return self.sentence_audio_store stream_manager = StreamSessionManager() def allowed_file(filename): return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'} audio_text_cache = {} cache_lock = Lock() CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 # WebSocket 连接管理 class ConnectionManager: def __init__(self): self.active_connections = {} async def connect(self, websocket: WebSocket, connection_id: str): await websocket.accept() self.active_connections[connection_id] = websocket logging.info(f"新连接建立: {connection_id}") async def disconnect(self, connection_id: str, code=1000, reason: str = ""): if connection_id in self.active_connections: try: # 尝试正常关闭连接(非阻塞) await self.active_connections[connection_id].close(code=code, reason=reason) except: pass # 忽略关闭错误 finally: del self.active_connections[connection_id] def is_connected(self, connection_id: str) -> bool: """检查连接是否仍然活跃""" return connection_id in self.active_connections async def _safe_send(self, connection_id: str, send_func, *args): """安全发送的通用方法(核心修改)""" # 1. 检查连接是否存在 if connection_id not in self.active_connections: logging.warning(f"尝试向不存在的连接发送数据: {connection_id}") return False websocket = self.active_connections[connection_id] try: # 2. 检查连接状态(关键修改) if websocket.client_state.name != "CONNECTED": logging.warning(f"连接 {connection_id} 状态为 {websocket.client_state.name}") await self.disconnect(connection_id) return False # 3. 执行发送操作 await send_func(websocket, *args) return True except (WebSocketDisconnect, RuntimeError) as e: # 4. 处理连接断开异常 logging.info(f"发送时检测到断开连接: {connection_id}, {str(e)}") await self.disconnect(connection_id) return False except Exception as e: # 5. 
处理其他异常 logging.error(f"发送数据出错: {connection_id}, {str(e)}") await self.disconnect(connection_id) return False async def send_bytes(self, connection_id: str, data: bytes): """安全发送字节数据""" return await self._safe_send( connection_id, lambda ws, d: ws.send_bytes(d), data ) async def send_text(self, connection_id: str, message: str): """安全发送文本数据""" return await self._safe_send( connection_id, lambda ws, m: ws.send_text(m), message ) async def send_json(self, connection_id: str, data: dict): """安全发送JSON数据""" return await self._safe_send( connection_id, lambda ws, d: ws.send_json(d), data ) manager = ConnectionManager() def generate_mp3_header( sample_rate: int, bitrate_kbps: int, channels: int = 1, layer: str = "III" # 新增参数,支持 "I"/"II"/"III" ) -> bytes: """ 动态生成 MP3 帧头(4字节),支持 Layer I/II/III :param sample_rate: 采样率 (8000, 16000, 22050, 44100) :param bitrate_kbps: 比特率(单位 kbps) :param channels: 声道数 (1: 单声道, 2: 立体声) :param layer: 编码层 ("I", "II", "III") :return: 4字节的帧头数据 """ # ---------------------------------- # 参数校验 # ---------------------------------- valid_sample_rates = {8000, 16000, 22050, 44100, 48000} if sample_rate not in valid_sample_rates: raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}") valid_layers = {"I", "II", "III"} if layer not in valid_layers: raise ValueError(f"不支持的层,可选:{valid_layers}") # ---------------------------------- # 确定 MPEG 版本和采样率索引 # ---------------------------------- if sample_rate == 44100: mpeg_version = 0b11 # MPEG-1 sample_rate_index = 0b00 elif sample_rate == 22050: mpeg_version = 0b10 # MPEG-2 sample_rate_index = 0b00 elif sample_rate == 16000: mpeg_version = 0b10 # MPEG-2 sample_rate_index = 0b10 elif sample_rate == 8000: mpeg_version = 0b00 # MPEG-2.5 sample_rate_index = 0b10 else: raise ValueError("采样率与版本不匹配") # ---------------------------------- # 动态选择比特率表(关键扩展) # ---------------------------------- # Layer 编码映射(I:0b11, II:0b10, III:0b01) layer_code = { "I": 0b11, "II": 0b10, "III": 0b01 }[layer] # 比特率表(覆盖所有 Layer) bitrate_tables = { # ------------------------------- # MPEG-1 (0b11) # ------------------------------- # Layer I (0b11, 0b11): { 32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011, 160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111, 288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011, 416: 0b1100, 448: 0b1101 }, # Layer II (0b11, 0b10): { 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, 160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011, 320: 0b1100, 384: 0b1101 }, # Layer III (0b11, 0b01): { 32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011, 64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111, 128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011, 256: 0b0100, 320: 0b0101 }, # ------------------------------- # MPEG-2 (0b10) / MPEG-2.5 (0b00) # ------------------------------- # Layer I (0b10, 0b11): { 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, 144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011, 224: 0b1100, 256: 0b1101 }, (0b00, 0b11): { 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, 144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011, 224: 0b1100, 256: 0b1101 }, # Layer II (0b10, 0b10): { 8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011, 40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111, 80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011, 144: 0b1100, 160: 0b1101 }, (0b00, 0b10): { 8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011, 40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111, 80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 
0b1011, 144: 0b1100, 160: 0b1101 }, # Layer III (0b10, 0b01): { 8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011, 40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111, 80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011, 144: 0b0100, 160: 0b0101 }, (0b00, 0b01): { 8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011, 40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111 } } # 获取当前版本的比特率表 key = (mpeg_version, layer_code) if key not in bitrate_tables: raise ValueError(f"不支持的版本和层组合: MPEG={mpeg_version}, Layer={layer}") bitrate_table = bitrate_tables[key] if bitrate_kbps not in bitrate_table: raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}") bitrate_index = bitrate_table[bitrate_kbps] # ---------------------------------- # 确定声道模式 # ---------------------------------- if channels == 1: channel_mode = 0b11 # 单声道 elif channels == 2: channel_mode = 0b00 # 立体声 else: raise ValueError("声道数必须为1或2") # ---------------------------------- # 组合帧头字段(修正层编码) # ---------------------------------- sync = 0x7FF << 21 # 同步字 11位 (0x7FF = 0b11111111111) version = mpeg_version << 19 # MPEG 版本 2位 layer_bits = layer_code << 17 # Layer 编码(I:0b11, II:0b10, III:0b01) protection = 0 << 16 # 无 CRC bitrate_bits = bitrate_index << 12 sample_rate_bits = sample_rate_index << 10 padding = 0 << 9 # 无填充 private = 0 << 8 mode = channel_mode << 6 mode_ext = 0 << 4 # 扩展模式(单声道无需设置) copyright = 0 << 3 original = 0 << 2 emphasis = 0b00 # 无强调 frame_header = ( sync | version | layer_bits | protection | bitrate_bits | sample_rate_bits | padding | private | mode | mode_ext | copyright | original | emphasis ) return frame_header.to_bytes(4, byteorder='big') # ------------------------------------------------ def audio_fade_in(audio_data, fade_length): # 假设音频数据是16位单声道PCM # 将二进制数据转换为整数数组 samples = array.array('h', audio_data) # 对前fade_length个样本进行淡入处理 for i in range(fade_length): fade_factor = i / fade_length samples[i] = int(samples[i] * fade_factor) # 将整数数组转换回二进制数据 return samples.tobytes() def parse_markdown_json(json_string): # 使用正则表达式匹配Markdown中的JSON代码块 match = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL) if match: try: # 尝试解析JSON字符串 data = json.loads(match[1]) return {'success': True, 'data': data} except json.JSONDecodeError as e: # 如果解析失败,返回错误信息 return {'success': False, 'data': str(e)} else: return {'success': False, 'data': 'not a valid markdown json string'} audio_text_cache = {} cache_lock = Lock() CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 # 全角字符到半角字符的映射 def fullwidth_to_halfwidth(s): full_to_half_map = { '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'", '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.', '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?', '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`', '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「', '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」', '、': '、', '・': '・', ':': ':' } return ''.join(full_to_half_map.get(char, char) for char in s) def split_text_at_punctuation(text, chunk_size=100): # 使用正则表达式找到所有的标点符号和特殊字符 punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+' tokens = re.split(punctuation_pattern, text) # 移除空字符串 tokens = [token for token in tokens if token] # 存储最终的文本块 chunks = [] current_chunk = '' for token in tokens: if len(current_chunk) + len(token) <= chunk_size: # 如果添加当前token后长度不超过chunk_size,则添加到当前块 current_chunk += (token + ' ') else: # 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块 chunks.append(current_chunk.strip()) current_chunk = token + ' ' # 
添加最后一个块(如果有剩余) if current_chunk: chunks.append(current_chunk.strip()) return chunks def extract_text_from_markdown(markdown_text): # 移除Markdown标题 text = re.sub(r'#\s*[^#]+', '', markdown_text) # 移除内联代码块 text = re.sub(r'`[^`]+`', '', text) # 移除代码块 text = re.sub(r'```[\s\S]*?```', '', text) # 移除加粗和斜体 text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text) # 移除链接 text = re.sub(r'\[.*?\]\(.*?\)', '', text) # 移除图片 text = re.sub(r'!\[.*?\]\(.*?\)', '', text) # 移除HTML标签 text = re.sub(r'<[^>]+>', '', text) # 转换标点符号 # text = re.sub(r'[^\w\s]', '', text) text = fullwidth_to_halfwidth(text) # 移除多余的空格 text = re.sub(r'\s+', ' ', text).strip() return text def encode_gzip_base64(original_data: bytes) -> str: """核心编码过程:二进制数据 → Gzip压缩 → Base64编码""" # Step 1: Gzip 压缩 with BytesIO() as buf: with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file: gz_file.write(original_data) compressed_bytes = buf.getvalue() # Step 2: Base64 编码(配置与Android端匹配) return base64.b64encode(compressed_bytes).decode('utf-8') # 默认不带换行符(等同于Android的Base64.NO_WRAP) def clean_audio_cache(): """定时清理过期缓存""" with cache_lock: now = time.time() expired_keys = [ k for k, v in audio_text_cache.items() if now - v['created_at'] > CACHE_EXPIRE_SECONDS ] for k in expired_keys: entry = audio_text_cache.pop(k, None) if entry and entry.get('audio_stream'): entry['audio_stream'].close() def start_background_cleaner(): """启动后台清理线程""" def cleaner_loop(): while True: time.sleep(180) # 每3分钟清理一次 clean_audio_cache() cleaner_thread = Thread(target=cleaner_loop, daemon=True) cleaner_thread.start() def test_qwen_chat(): messages = [ {'role': 'system', 'content': 'You are a helpful assistant.'}, {'role': 'user', 'content': '你是谁?'} ] response = Generation.call( # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx", api_key=ALI_KEY, model="qwen-plus", # 模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models messages=messages, result_format="message" ) if response.status_code == 200: print(response.output.choices[0].message.content) else: print(f"HTTP返回码:{response.status_code}") print(f"错误码:{response.code}") print(f"错误信息:{response.message}") print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code") ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75" from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse from dashscope.audio.tts import ( ResultCallback as TTSResultCallback, SpeechSynthesizer as TTSSpeechSynthesizer, SpeechSynthesisResult as TTSSpeechSynthesisResult, ) # cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本 from dashscope.audio.tts_v2 import ( ResultCallback as CosyResultCallback, SpeechSynthesizer as CosySpeechSynthesizer, AudioFormat, ) class QwenTTS: def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", special_characters: Optional[Dict[str, str]] = None): import dashscope import ssl logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx self.model_name = model_name.split('@')[0] dashscope.api_key = key ssl._create_default_https_context = ssl._create_unverified_context # 禁用验证 self.synthesizer = None self.callback = None self.is_cosyvoice = False self.voice = "" self.format = format self.sample_rate = sample_rate self.first_chunk = True if '/' in self.model_name: parts = self.model_name.split('/', 1) # 返回分离后的两个字符串parts[0], parts[1] if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2': self.is_cosyvoice = True self.voice = parts[1] self.completion_event = None # 新增:用于通知任务完成 # 特殊字符及其拼音映射 
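        # The mapping below pairs rare CJK characters that the engines cannot
        # read on their own with pinyin-plus-tone strings (e.g. "chuang3" means
        # pinyin "chuang", third tone). It is consumed by apply_phoneme_tags(),
        # which annotates these characters with their pronunciation before the
        # text is sent for synthesis; callers can supply their own mapping via
        # the special_characters constructor argument.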
self.special_characters = special_characters or { "㼽": "chuang3", "䡇": "yue4" # 可以添加更多特殊字符的映射 } class Callback(TTSResultCallback): def __init__(self,data_callback=None,completion_event=None) -> None: self.dque = deque() self.data_callback = data_callback self.completion_event = completion_event # 新增完成事件引用 def _run(self): while True: if not self.dque: time.sleep(0) continue val = self.dque.popleft() if val: yield val else: break def on_open(self): pass def on_complete(self): self.dque.append(None) if self.data_callback: self.data_callback(None) # 发送结束信号 # 通知任务完成 if self.completion_event: self.completion_event.set() def on_error(self, response: SpeechSynthesisResponse): print("Qwen tts error", str(response)) raise RuntimeError(str(response)) def on_close(self): pass def on_event(self, result: TTSSpeechSynthesisResult): data =result.get_audio_frame() if data is not None: if len(data) > 0: if self.data_callback: self.data_callback(data) else: self.dque.append(data) #self.dque.append(result.get_audio_frame()) # -------------------------- class Callback_Cosy(CosyResultCallback): def __init__(self, data_callback=None,completion_event=None) -> None: self.dque = deque() self.data_callback = data_callback self.completion_event = completion_event # 新增完成事件引用 def _run(self): while True: if not self.dque: time.sleep(0) continue val = self.dque.popleft() if val: yield val else: break def on_open(self): logging.info("Qwen CosyVoice tts open ") pass def on_complete(self): self.dque.append(None) if self.data_callback: self.data_callback(None) # 发送结束信号 # 通知任务完成 if self.completion_event: self.completion_event.set() def on_error(self, response: SpeechSynthesisResponse): print("Qwen tts error", str(response)) if self.data_callback: self.data_callback(f"ERROR:{str(response)}".encode()) raise RuntimeError(str(response)) def on_close(self): # print("---Qwen call back close") # cyx logging.info("Qwen CosyVoice tts close") pass """ canceled for test 语音大模型CosyVoice def on_event(self, result: SpeechSynthesisResult): if result.get_audio_frame() is not None: self.dque.append(result.get_audio_frame()) """ def on_event(self, message): # logging.info(f"recv speech synthsis message {message}") pass # 以下适合语音大模型CosyVoice def on_data(self, data: bytes) -> None: if len(data) > 0: if self.data_callback: self.data_callback(data) else: self.dque.append(data) # -------------------------- def tts(self, text, on_data = None,completion_event=None): # logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx # text = self.normalize_text(text) print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx # text = self.normalize_text(text) try: # if self.model_name != 'cosyvoice-v1': if self.is_cosyvoice is False: self.callback = self.Callback( data_callback=on_data, completion_event=completion_event ) TTSSpeechSynthesizer.call(model=self.model_name, text=text, callback=self.callback, format=self.format) # format="mp3") else: self.callback = self.Callback_Cosy() format = self.get_audio_format(self.format, self.sample_rate) self.synthesizer = CosySpeechSynthesizer( model='cosyvoice-v2', # voice="longyuan", #"longfei", voice=self.voice, callback=self.callback, format=format ) self.synthesizer.call(text) except Exception as e: print(f"---dale---20 error {e}") # cyx # ----------------------------------- try: for data in self.callback._run(): # logging.info(f"dashcope return data {len(data)}") yield data # print(f"---Qwen return data {num_tokens_from_string(text)}") # yield 
num_tokens_from_string(text) except Exception as e: raise RuntimeError(f"**ERROR**: {e}") def setup_tts(self, on_data,completion_event=None): """设置 TTS 回调,返回配置好的 synthesizer""" #if not self.is_cosyvoice: # raise NotImplementedError("Only CosyVoice supported") if self.is_cosyvoice: # 创建 CosyVoice 回调 self.callback = self.Callback_Cosy( data_callback=on_data, completion_event=completion_event) else: self.callback = self.Callback( data_callback=on_data, completion_event=completion_event) if self.is_cosyvoice: format_val = self.get_audio_format(self.format, self.sample_rate) # logging.info(f"Qwen setup_tts {self.voice} {format_val}") self.synthesizer = CosySpeechSynthesizer( model='cosyvoice-v1', voice=self.voice, # voice="longyuan", #"longfei", callback=self.callback, format=format_val ) return self.synthesizer def apply_phoneme_tags(self, text: str) -> str: """ 在文本中查找特殊字符并用标签包裹它们 """ # 如果文本已经是SSML格式,直接返回 if text.strip().startswith("") and text.strip().endswith(""): return text # 为特殊字符添加SSML标签 for char, pinyin in self.special_characters.items(): # 使用正则表达式确保只替换整个字符(避免部分匹配) pattern = r'([^<]|^)' + re.escape(char) + r'([^>]|$)' replacement = r'\1' + char + r'\2' text = re.sub(pattern, replacement, text) # 如果文本中已有SSML标签,直接返回 if "" in text: return text # 否则包裹在标签中 return f"{text}" def text_tts_call(self, text): if self.special_characters and self.is_cosyvoice is False: text = self.apply_phoneme_tags(text) #logging.info(f"Applied SSML phoneme tags to text: {text}") if self.synthesizer and self.is_cosyvoice: logging.info(f"Qwen text_tts_call {text} {self.is_cosyvoice}") format_val = self.get_audio_format(self.format, self.sample_rate) self.synthesizer = CosySpeechSynthesizer( model='cosyvoice-v1', voice=self.voice, # voice="longyuan", #"longfei", callback=self.callback, format=format_val ) self.synthesizer.call(text) if self.is_cosyvoice is False: logging.info(f"Qwen text_tts_call {text}") TTSSpeechSynthesizer.call(model=self.model_name, text=text, callback=self.callback, format=self.format) def streaming_call(self, text): if self.synthesizer: self.synthesizer.streaming_call(text) def end_streaming_call(self): if self.synthesizer: # logging.info(f"---dale end_streaming_call") self.synthesizer.streaming_complete() def get_audio_format(self, format: str, sample_rate: int): """动态获取音频格式""" from dashscope.audio.tts_v2 import AudioFormat logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}") format_map = { (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS, (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT, (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT, (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT, (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS, (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT, (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT, (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS, (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT, (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT, (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS, (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT, (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT } return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS) class DoubaoTTS: def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"): logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}") # 解析豆包认证信息 (appid, token, cluster, voice_type) try: self.appid = "7282190702" self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9" self.cluster = "volcano_tts" 
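            # Wire-protocol note, judging from _async_tts() and _parse_response()
            # below: every WebSocket message is a 4-byte binary header followed by
            # a gzip-compressed JSON payload. default_header = b'\x11\x10\x11\x00'
            # encodes protocol version 1, header size 1 (x4 bytes), message type
            # "full client request", JSON serialization and gzip compression.
            # Server frames use message type 0xb for audio chunks and 0xf for
            # errors, with a negative sequence number marking the final chunk.
            # The appid/token/cluster values above are hardcoded credentials and
            # would normally be injected from configuration.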
self.voice_type ="zh_female_qingxinnvsheng_mars_bigtts" # "zh_male_jieshuonansheng_mars_bigtts" #"zh_male_ruyaqingnian_mars_bigtts" #"zh_male_jieshuonansheng_mars_bigtts" except Exception as e: raise ValueError(f"Invalid Doubao key format: {str(e)}") self.format = format self.sample_rate = sample_rate self.model_name = model_name self.callback = None self.ws = None self.loop = None self.task = None self.event = threading.Event() self.data_queue = deque() self.host = "openspeech.bytedance.com" self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary" self.default_header = bytearray(b'\x11\x10\x11\x00') self.total_data_size = 0 self.completion_event = None # 新增:用于通知任务完成 class Callback: def __init__(self, data_callback=None,completion_event=None): self.data_callback = data_callback self.data_queue = deque() self.completion_event = completion_event # 完成事件引用 def on_data(self, data): if self.data_callback: self.data_callback(data) else: self.data_queue.append(data) # 通知任务完成 if self.completion_event: self.completion_event.set() def on_complete(self): if self.data_callback: self.data_callback(None) def on_error(self, error): if self.data_callback: self.data_callback(f"ERROR:{error}".encode()) def setup_tts(self, on_data,completion_event): """设置回调,返回自身(因为豆包需要异步启动)""" self.callback = self.Callback( data_callback=on_data, completion_event=completion_event ) return self def text_tts_call(self, text): """同步调用,启动异步任务并等待完成""" self.total_data_size = 0 self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) self.task = self.loop.create_task(self._async_tts(text)) try: self.loop.run_until_complete(self.task) except Exception as e: logging.error(f"DoubaoTTS--0 call error: {e}") self.callback.on_error(str(e)) async def _async_tts(self, text): """异步执行TTS请求""" header = {"Authorization": f"Bearer; {self.token}"} request_json = { "app": { "appid": self.appid, "token": "access_token", # 固定值 "cluster": self.cluster }, "user": { "uid": str(uuid.uuid4()) # 随机用户ID }, "audio": { "voice_type": self.voice_type, "encoding": self.format, "speed_ratio": 1.0, "volume_ratio": 1.0, "pitch_ratio": 1.0, }, "request": { "reqid": str(uuid.uuid4()), "text": text, "text_type": "plain", "operation": "submit" # 使用submit模式支持流式 } } # 构建请求数据 payload_bytes = str.encode(json.dumps(request_json)) payload_bytes = gzip.compress(payload_bytes) full_client_request = bytearray(self.default_header) full_client_request.extend(len(payload_bytes).to_bytes(4, 'big')) full_client_request.extend(payload_bytes) try: async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws: self.ws = ws await ws.send(full_client_request) # 接收音频数据 while True: res = await ws.recv() done = self._parse_response(res) if done: self.callback.on_complete() break except Exception as e: logging.error(f"DoubaoTTS--1 WebSocket error: {e}") self.callback.on_error(str(e)) finally: # 通知任务完成 if self.completion_event: self.completion_event.set() def _parse_response(self, res): """解析豆包返回的二进制响应""" # 协议头解析 (4字节) header_size = res[0] & 0x0f message_type = res[1] >> 4 payload = res[header_size * 4:] # 音频数据响应 if message_type == 0xb: # audio-only server response message_flags = res[1] & 0x0f # ACK消息,忽略 if message_flags == 0: return False # 音频数据消息 sequence_number = int.from_bytes(payload[:4], "big", signed=True) payload_size = int.from_bytes(payload[4:8], "big", signed=False) audio_data = payload[8:8 + payload_size] if audio_data: self.total_data_size = self.total_data_size + len(audio_data) self.callback.on_data(audio_data) #logging.info(f"doubao 
_parse_response: {sequence_number} {len(audio_data)} {self.total_data_size}") # 序列号为负表示结束 return sequence_number < 0 # 错误响应 elif message_type == 0xf: code = int.from_bytes(payload[:4], "big", signed=False) msg_size = int.from_bytes(payload[4:8], "big", signed=False) error_msg = payload[8:8 + msg_size] try: # 尝试解压错误消息 error_msg = gzip.decompress(error_msg).decode() except: error_msg = error_msg.decode(errors='ignore') logging.error(f"DoubaoTTS error: {error_msg}") self.callback.on_error(error_msg) return False return False class UnifiedTTSEngine: def __init__(self): self.lock = threading.Lock() self.tasks = {} self.executor = ThreadPoolExecutor(max_workers=10) self.cache_expire = 300 # 5分钟缓存 # 启动清理过期任务的定时器 self.cleanup_timer = None self.start_cleanup_timer() def _cleanup_old_tasks(self): """清理过期任务""" now = time.time() with self.lock: expired_ids = [task_id for task_id, task in self.tasks.items() if now - task['created_at'] > self.cache_expire] for task_id in expired_ids: self._remove_task(task_id) def _remove_task(self, task_id): """移除任务""" if task_id in self.tasks: task = self.tasks.pop(task_id) # 取消可能的后台任务 if 'future' in task and not task['future'].done(): task['future'].cancel() # 其他资源在任务被移除后会被垃圾回收 # 资源释放机制总结: # 移除任务引用:self.tasks.pop() 解除任务对象引用,触发垃圾回收。 # 取消后台线程:future.cancel() 终止未完成线程,释放线程资源。 # 自动内存回收:Python GC 回收任务对象及其队列、缓冲区占用的内存。 # 线程池管理:执行器自动回收线程至池中,避免资源泄漏。 def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False): """创建TTS任务(同步方法)""" self._cleanup_old_tasks() audio_stream_id = str(uuid.uuid4()) # 创建任务数据结构 task_data = { 'id': audio_stream_id, 'text': text, 'format': format, 'sample_rate': sample_rate, 'model_name': model_name, 'key': key, 'delay_gen_audio': delay_gen_audio, 'created_at': time.time(), 'status': 'pending', 'data_queue': deque(), 'event': threading.Event(), 'completed': False, 'error': None } with self.lock: self.tasks[audio_stream_id] = task_data # 如果不是延迟模式,立即启动任务 if not delay_gen_audio: self._start_tts_task(audio_stream_id) return audio_stream_id def _start_tts_task(self, audio_stream_id): # 启动TTS任务(后台线程) task = self.tasks.get(audio_stream_id) if not task or task['status'] != 'pending': return logging.info("已经启动 start tts task {audio_stream_id}") task['status'] = 'processing' # 在后台线程中执行TTS future = self.executor.submit(self._run_tts_sync, audio_stream_id) task['future'] = future # 如果需要等待任务完成 if not task.get('delay_gen_audio', True): try: # 等待任务完成(最多5分钟) future.result(timeout=300) logging.info(f"TTS任务 {audio_stream_id} 已完成") self._merge_audio_data(audio_stream_id) except concurrent.futures.TimeoutError: task['error'] = "TTS生成超时" task['completed'] = True logging.error(f"TTS任务 {audio_stream_id} 超时") except Exception as e: task['error'] = f"ERROR:{str(e)}" task['completed'] = True logging.exception(f"TTS任务执行异常: {str(e)}") def _run_tts_sync(self, audio_stream_id): # 同步执行TTS生成 在后台线程中执行 task = self.tasks.get(audio_stream_id) if not task: return try: # 创建完成事件 completion_event = threading.Event() # 创建TTS实例 # 根据model_name选择TTS引擎 # 前端传入 cosyvoice-v1/longhua@Tongyi-Qianwen model_name_wo_brand = task['model_name'].split('@')[0] model_name_version = model_name_wo_brand.split('/')[0] if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']: # 豆包TTS tts = DoubaoTTS( key=task['key'], format=task['format'], sample_rate=task['sample_rate'], model_name=task['model_name'] ) else: # 通义千问TTS tts = QwenTTS( key=task['key'], format=task['format'], sample_rate=task['sample_rate'], model_name=task['model_name'] ) # 
定义同步数据处理函数 def data_handler(data): if data is None: # 结束信号 task['completed'] = True task['event'].set() logging.info(f"--data_handler on_complete") elif data.startswith(b"ERROR"): # 错误信号 task['error'] = data.decode() task['completed'] = True task['event'].set() else: # 音频数据 task['data_queue'].append(data) # 设置并执行TTS synthesizer = tts.setup_tts(data_handler,completion_event) #synthesizer.call(task['text']) tts.text_tts_call(task['text']) # 等待完成或超时 # 等待完成或超时 if not completion_event.wait(timeout=300): # 5分钟超时 task['error'] = "TTS generation timeout" task['completed'] = True logging.info(f"--tts task event set error = {task['error']}") except Exception as e: logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}") task['error'] = f"ERROR:{str(e)}" task['completed'] = True finally: # 确保清理TTS资源 logging.info("UnifiedTTSEngine _run_tts_sync finally") if hasattr(tts, 'loop') and tts.loop: tts.loop.close() def _merge_audio_data(self, audio_stream_id): """将任务的所有音频数据合并到ByteIO缓冲区""" task = self.tasks.get(audio_stream_id) if not task or not task.get('completed'): return try: logging.info(f"开始合并音频数据: {audio_stream_id}") # 创建内存缓冲区 buffer = io.BytesIO() # 合并所有数据块 for data_chunk in task['data_queue']: buffer.write(data_chunk) # 重置指针位置以便读取 buffer.seek(0) # 保存到任务对象 task['buffer'] = buffer logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节") # 可选:清理原始数据队列以节省内存 task['data_queue'].clear() except Exception as e: logging.error(f"合并音频数据失败: {str(e)}") task['error'] = f"合并错误: {str(e)}" async def get_audio_stream(self, audio_stream_id): """获取音频流(异步生成器)""" task = self.tasks.get(audio_stream_id) if not task: raise RuntimeError("Audio stream not found") # 如果是延迟任务且未启动,现在启动 status 为 pending if task['delay_gen_audio'] and task['status'] == 'pending': self._start_tts_task(audio_stream_id) total_audio_data_size = 0 # 等待任务启动 while task['status'] == 'pending': await asyncio.sleep(0.1) # 流式返回数据 while not task['completed'] or task['data_queue']: while task['data_queue']: data = task['data_queue'].popleft() total_audio_data_size += len(data) #logging.info(f"yield audio data {len(data)} {total_audio_data_size}") yield data # 短暂等待新数据 await asyncio.sleep(0.05) # 检查错误 if task['error']: raise RuntimeError(task['error']) def start_cleanup_timer(self): """启动定时清理任务""" if self.cleanup_timer: self.cleanup_timer.cancel() self.cleanup_timer = threading.Timer(30.0, self.cleanup_task) # 每30秒清理一次 self.cleanup_timer.daemon = True # 设置为守护线程 self.cleanup_timer.start() def cleanup_task(self): """执行清理任务""" try: self._cleanup_old_tasks() except Exception as e: logging.error(f"清理任务时出错: {str(e)}") finally: self.start_cleanup_timer() # 重新启动定时器 # 全局 TTS 引擎实例 tts_engine = UnifiedTTSEngine() def replace_domain(url: str) -> str: """替换URL中的域名为本地地址,不使用urllib.parse""" # 定义需要替换的域名列表 domains_to_replace = [ "http://1.13.185.116:9380", "https://ragflow.szzysztech.com", "1.13.185.116:9380", "ragflow.szzysztech.com" ] # 尝试替换每个可能的域名 for domain in domains_to_replace: if domain in url: # 直接替换域名部分 return url.replace(domain, "http://localhost:9380", 1) # 如果未匹配到特定域名,尝试智能替换 if "://" in url: # 分割协议和路径 protocol, path = url.split("://", 1) # 查找第一个斜杠位置来确定域名结束位置 slash_pos = path.find("/") if slash_pos > 0: # 替换域名部分 return f"http://localhost:9380{path[slash_pos:]}" else: # 没有路径部分,直接返回本地地址 return "http://localhost:9380" else: # 没有协议部分,直接添加本地地址 return f"http://localhost:9380/{url}" async def proxy_aichat_audio_stream(client_id: str, audio_url: str): """代理外部音频流请求""" try: # 替换域名为本地地址 local_url = audio_url logging.info(f"代理音频流: {audio_url} -> {local_url}") async with 
httpx.AsyncClient(timeout=60.0) as client: async with client.stream("GET", local_url) as response: # 流式转发音频数据 async for chunk in response.aiter_bytes(): if not await manager.send_bytes(client_id, chunk): logging.warning(f"Audio proxy interrupted for {client_id}") return except Exception as e: logging.error(f"Audio proxy failed: {str(e)}") await manager.send_text(client_id, json.dumps({ "type": "error", "message": f"音频流获取失败: {str(e)}" })) # 代理函数 - 文本流 # 在微信小程序中,原来APK使用的SSE机制不能正常工作,需要使用WebSocket async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict): """代理大模型文本流请求 - 兼容现有Flask实现""" try: logging.info(f"代理文本流: completions_url={completions_url} {payload}") logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}") headers = { "Content-Type": "application/json", 'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm' } tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen') #if 'longyuan' in tts_model_name: # tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen" # 创建TTS实例 tts_model = QwenTTS( key=ALI_KEY, format=payload.get('tts_stream_format', 'mp3'), sample_rate=payload.get('tts_sample_rate', 48000), model_name=tts_model_name ) streaming_call = False if tts_model.is_cosyvoice: streaming_call = True # 创建流会话 tts_stream_session_id = stream_manager.create_session( tts_model=tts_model, sample_rate=payload.get('tts_sample_rate', 48000), stream_format=payload.get('tts_stream_format', 'mp3'), session_id=None, streaming_call= streaming_call ) # logging.info(f"---tts_stream_session_id = {tts_stream_session_id}") tts_stream_session_id_sent = False send_sentence_tts_url = False # 添加一个事件来标记所有句子已发送 all_sentences_sent = asyncio.Event() # 任务:监听并发送新生成的句子 async def send_new_sentences(): """监听并发送新生成的句子""" try: while True: # 获取下一个句子 sentence_info = stream_manager.get_next_sentence(tts_stream_session_id) if sentence_info: logging.info(f"--proxy_aichat_text_stream 发送sentence_info\r\n") # 发送句子信息 await manager.send_json(client_id, { "type": "tts_sentence", "id": sentence_info['id'], "text": stream_manager.get_sentence_info(sentence_info['id'])['text'], "url": sentence_info['url'] }) else: # 检查会话是否结束且没有更多句子 session = stream_manager.get_session(tts_stream_session_id) if not session or (not session['active']): #and session['current_sentence_index'] >= len(session['sentences'])): # 标记所有句子已发送 all_sentences_sent.set() break # 等待新句子生成 await asyncio.sleep(0.1) except asyncio.CancelledError: logging.info("句子监听任务被取消") except Exception as e: logging.error(f"句子监听任务出错: {str(e)}") all_sentences_sent.set() if send_sentence_tts_url: # 启动句子监听任务 sentence_task = asyncio.create_task(send_new_sentences()) # 使用更长的超时时间 (5分钟) timeout = httpx.Timeout(300.0, connect=60.0) async with httpx.AsyncClient(timeout=timeout) as client: # 关键修改:使用流式请求模式 async with client.stream( # <-- 使用stream方法 "POST", completions_url, json=payload, headers=headers ) as response: logging.info(f"响应状态: HTTP {response.status_code}") if response.status_code != 200: # 读取错误信息(非流式) error_content = await response.aread() error_msg = f"后端错误: HTTP {response.status_code}" error_msg += f" - {error_content[:200].decode()}" if error_content else "" await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg})) return # 验证SSE流 content_type = response.headers.get("content-type", "").lower() if "text/event-stream" not in content_type: logging.warning("非流式响应,转发完整内容") full_content = await response.aread() await manager.send_text(client_id, json.dumps({ "type": "text", "data": 
full_content.decode('utf-8') })) return logging.info("开始处理SSE流") event_count = 0 # 使用异步迭代器逐行处理 async for line in response.aiter_lines(): # 跳过空行和注释行 if not line or line.startswith(':'): continue # 处理SSE事件 if line.startswith("data:"): data_str = line[5:].strip() if data_str: # 过滤空数据 try: # 解析并提取增量文本 data_obj = json.loads(data_str) delta_text = None if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict): delta_text = data_obj.get('data', None).get('delta_ans', "") if tts_stream_session_id_sent is False: logging.info(f"--proxy_aichat_text_stream 发送audio_stream_url") data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}" data_str = json.dumps(data_obj) tts_stream_session_id_sent = True # 直接转发原始数据 await manager.send_text(client_id, json.dumps({ "type": "text", "data": data_str })) # 这里构建{"type":"text",'data':"data_str"}) 是为了前端websocket进行数据解析 if delta_text: # 追加到会话管理器 stream_manager.append_text(tts_stream_session_id, delta_text) # logging.info(f"文本代理转发: {data_str}") event_count += 1 except Exception as e: logging.error(f"事件发送失败: {str(e)}") # 保持连接活性 await asyncio.sleep(0.001) # 避免CPU空转 logging.info(f"SSE流处理完成,事件数: {event_count}") # 发送文本流结束信号 await manager.send_text(client_id, json.dumps({"type": "end"})) # 标记文本输入结束 if stream_manager.finish_text_input: stream_manager.finish_text_input(tts_stream_session_id) if send_sentence_tts_url: # 等待所有句子生成并发送(最多等待300秒) try: await asyncio.wait_for(all_sentences_sent.wait(), timeout=300.0) logging.info(f"所有TTS句子已发送") except asyncio.TimeoutError: logging.warning("等待TTS句子发送超时") # 取消句子监听任务(如果仍在运行) if not sentence_task.done(): sentence_task.cancel() try: await sentence_task except asyncio.CancelledError: pass except httpx.ReadTimeout: logging.error("读取后端服务超时") await manager.send_text(client_id, json.dumps({ "type": "error", "message": "后端服务响应超时" })) except httpx.ConnectError as e: logging.error(f"连接后端服务失败: {str(e)}") await manager.send_text(client_id, json.dumps({ "type": "error", "message": f"无法连接到后端服务: {str(e)}" })) except Exception as e: logging.exception(f"文本代理失败: {str(e)}") await manager.send_text(client_id, json.dumps({ "type": "error", "message": f"文本流获取失败: {str(e)}" })) @tts_router.get("/audio/pcm_mp3") async def stream_mp3(): def audio_generator(): path = './test.mp3' try: with open(path, 'rb') as f: while True: chunk = f.read(1024) if not chunk: break yield chunk except Exception as e: logging.error(f"MP3 streaming error: {str(e)}") return StreamingResponse( audio_generator(), media_type="audio/mpeg", headers={ "Cache-Control": "no-store", "Accept-Ranges": "bytes" } ) def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes: """动态生成WAV头(严格保持原有逻辑结构)""" with BytesIO() as wav_buffer: with wave.open(wav_buffer, 'wb') as wav_file: wav_file.setnchannels(1) # 保持原单声道设置 wav_file.setsampwidth(2) # 保持原16-bit设置 wav_file.setframerate(sample_rate) wav_file.writeframes(pcm_data) wav_buffer.seek(0) return wav_buffer.read() def generate_silence_header(duration_ms: int = 500) -> bytes: """生成静音数据(用于MP3流式传输预缓冲)""" num_samples = int(TTS_SAMPLERATE * duration_ms / 1000) return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS # ------------------------ API路由 ------------------------ @tts_router.get("/tts_sentence/{sentence_id}") async def get_sentence_audio(sentence_id: str): # 获取音频数据 audio_data = stream_manager.get_sentence_audio(sentence_id) if not audio_data: raise HTTPException(status_code=404, detail="Audio not found") # 获取音频格式 sentence_info = stream_manager.get_sentence_info(sentence_id) if not sentence_info: raise 
HTTPException(status_code=404, detail="Sentence info not found") # 确定MIME类型 format = sentence_info['format'] media_type = "audio/mpeg" if format == "mp3" else "audio/wav" logging.info(f"--http get sentence tts audio stream {sentence_id}") # 返回流式响应 return StreamingResponse( io.BytesIO(audio_data), media_type=media_type, headers={ "Content-Disposition": f"attachment; filename=audio.{format}", "Cache-Control": "max-age=3600" # 缓存1小时 } ) @tts_router.post("/chats/{chat_id}/tts") async def create_tts_request(chat_id: str, request: Request): try: data = await request.json() logging.info(f"Creating TTS request: {data}") # 参数校验 text = data.get("text", "").strip() if not text: raise HTTPException(400, detail="Text cannot be empty") format = data.get("tts_stream_format", "mp3") if format not in ["mp3", "wav", "pcm"]: raise HTTPException(400, detail="Unsupported audio format") sample_rate = data.get("tts_sample_rate", 48000) if sample_rate not in [8000, 16000, 22050, 44100, 48000]: raise HTTPException(400, detail="Unsupported sample rate") model_name = data.get("model_name", "cosyvoice-v1/longxiaochun") delay_gen_audio = data.get('delay_gen_audio', False) # 创建TTS任务 audio_stream_id = tts_engine.create_tts_task( text=text, format=format, sample_rate=sample_rate, model_name=model_name, key=ALI_KEY, delay_gen_audio=delay_gen_audio ) return JSONResponse( status_code=200, content={ "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}", "url": f"/chats/{chat_id}/tts/{audio_stream_id}", "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}", # WebSocket URL 2025 0622新增 "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat() } ) except Exception as e: logging.error(f"Request failed: {str(e)}") raise HTTPException(500, detail="Internal server error") executor = ThreadPoolExecutor() @tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}") async def get_tts_audio( chat_id: str, audio_stream_id: str, range: str = Header(None) ): try: # 获取任务信息 task = tts_engine.tasks.get(audio_stream_id) if not task: # 返回友好的错误信息而不是抛出异常 return JSONResponse( status_code=404, content={ "error": "Audio stream not found", "message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired", "suggestion": "Please create a new TTS request and try again" } ) # 获取媒体类型 format = task['format'] media_type = { "mp3": "audio/mpeg", "wav": "audio/wav", "pcm": f"audio/L16; rate={task['sample_rate']}; channels=1" }[format] # 如果任务已完成且有完整缓冲区,处理Range请求 logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}") # 创建响应内容生成器 def buffer_read(buffer): content_length = buffer.getbuffer().nbytes remaining = content_length chunk_size = 4096 buffer.seek(0) while remaining > 0: read_size = min(remaining, chunk_size) data = buffer.read(read_size) if not data: break yield data remaining -= len(data) if task.get('completed') and task.get('buffer') is not None: buffer = task['buffer'] total_size = buffer.getbuffer().nbytes # 强制小文件使用流式传输(避免206响应问题) if total_size < 1024 * 120: # 小于300KB range = None if range: # 处理范围请求 return handle_range_request(range, buffer, total_size, media_type) else: return StreamingResponse( buffer_read(buffer), media_type=media_type, headers={ "Accept-Ranges": "bytes", "Cache-Control": "no-store", "Transfer-Encoding": "chunked" }) # 创建流式响应 logging.info("tts_engine.get_audio_stream--0") return StreamingResponse( tts_engine.get_audio_stream(audio_stream_id), media_type=media_type, headers={ "Accept-Ranges": "bytes", "Cache-Control": "no-store", 
"Transfer-Encoding": "chunked" } ) except Exception as e: logging.error(f"Audio streaming failed: {str(e)}") raise HTTPException(500, detail="Audio generation error") def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str): """处理 HTTP Range 请求""" try: # 解析 Range 头部 (示例: "bytes=0-1023") range_type, range_spec = range_header.split('=') if range_type != 'bytes': raise ValueError("Unsupported range type") start_str, end_str = range_spec.split('-') start = int(start_str) end = int(end_str) if end_str else total_size - 1 logging.info(f"handle_range_request--1 {start_str}-{end_str} {end}") # 验证范围有效性 if start >= total_size or end >= total_size: raise HTTPException(status_code=416, headers={ "Content-Range": f"bytes */{total_size}" }) # 计算内容长度 content_length = end - start + 1 # 设置状态码 status_code = 206 # Partial Content if start == 0 and end == total_size - 1: status_code = 200 # Full Content # 设置流读取位置 buffer.seek(start) # 创建响应内容生成器 def content_generator(): remaining = content_length chunk_size = 4096 while remaining > 0: read_size = min(remaining, chunk_size) data = buffer.read(read_size) if not data: break yield data remaining -= len(data) # 返回分块响应 return StreamingResponse( content_generator(), status_code=status_code, media_type=media_type, headers={ "Content-Range": f"bytes {start}-{end}/{total_size}", "Content-Length": str(content_length), "Accept-Ranges": "bytes", "Cache-Control": "public, max-age=3600" } ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}") async def websocket_tts_endpoint( websocket: WebSocket, chat_id: str, audio_stream_id: str ): # 接收 header 参数 headers = websocket.headers service_type = headers.get("x-tts-type") # 注意:header 名称转为小写 # audio_url = headers.get("x-audio-url") """ 前端示例 websocketConnection = uni.connectSocket({ url: url, header: { 'Authorization': token, 'X-Tts-Type': 'AiChat', //'Ask' // 自定义参数1 'X-Device-Type': 'mobile', // 自定义参数2 'X-User-ID': '12345' // 自定义参数3 }, success: () => { console.log('WebSocket connected'); }, fail: (err) => { console.error('WebSocket connection failed:', err); } }); """ # 创建唯一连接 ID connection_id = str(uuid.uuid4()) # logging.info(f"---dale-- websocket connection_id = {connection_id} chat_id={chat_id}") await manager.connect(websocket, connection_id) completed_successfully = False try: # 根据tts_type路由到不同的音频源 if service_type == "AiChatTts": # 音频代理服务 audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}" # await proxy_aichat_audio_stream(connection_id, audio_url) sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate') audio_data_size =0 await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate}) async for data in stream_manager.get_tts_buffer_data(audio_stream_id): if data.get('type') == 'sentence_end': await manager.send_json(connection_id, {"command": "sentence_end"}) if data.get('type') == 'arraybuffer': audio_data_size += len(data.get('data')) if not await manager.send_bytes(connection_id, data.get('data')): break completed_successfully = True logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}") elif service_type == "AiChatText": # 文本代理服务 # 等待客户端发送初始请求数据 进行大模型对话代理时,需要前端连接后发送payload payload = await websocket.receive_json() completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions" await proxy_aichat_text_stream(connection_id, completions_url, payload) completed_successfully = True else: # 使用引擎的生成器直接获取音频流 
            async for data in tts_engine.get_audio_stream(audio_stream_id):
                if not await manager.send_bytes(connection_id, data):
                    logging.warning(f"Send failed, connection closed: {connection_id}")
                    break
            await manager.send_json(connection_id, {"command": "sentence_end"})
            completed_successfully = True

        # Check the connection state before sending the completion signal
        if manager.is_connected(connection_id):
            # Send the completion signal
            await manager.send_json(connection_id, {"status": "completed"})
            # Short delay to make sure the message is delivered
            await asyncio.sleep(0.1)
            # Actively close the WebSocket connection
            await manager.disconnect(connection_id, code=1000, reason="Audio stream completed")

    except WebSocketDisconnect:
        logging.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logging.error(f"WebSocket TTS error: {str(e)}")
        if manager.is_connected(connection_id):
            await manager.send_json(connection_id, {"error": str(e)})
    finally:
        pass  # await manager.disconnect(connection_id)


def cleanup_cache():
    """Clean up expired cache entries."""
    with cache_lock:
        now = datetime.datetime.now()
        expired = [
            k for k, v in audio_text_cache.items()
            if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS
        ]
        for key in expired:
            logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
            del audio_text_cache[key]


# Start the background cleanup thread at application startup
# start_background_cleaner()
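
# ---------------------------------------------------------------------------
# Usage sketch (illustration only): a minimal way to mount tts_router on a
# standalone FastAPI app and drive the HTTP TTS flow. The host/port, the
# chat_id and the uvicorn dependency are assumptions, not part of the original
# service, and real synthesis still requires a valid DashScope key in ALI_KEY.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    demo_app = FastAPI()
    demo_app.include_router(tts_router)

    # Client side (run in a separate process once the server is listening):
    #
    #   import httpx
    #   base = "http://127.0.0.1:9381"
    #   resp = httpx.post(
    #       f"{base}/chats/demo-chat/tts",
    #       json={
    #           "text": "你好,世界。",
    #           "tts_stream_format": "mp3",
    #           "tts_sample_rate": 44100,
    #           "model_name": "cosyvoice-v1/longxiaochun",
    #       },
    #   )
    #   tts_url = resp.json()["tts_url"]          # /chats/demo-chat/tts/<id>
    #   with httpx.stream("GET", base + tts_url) as audio, open("out.mp3", "wb") as f:
    #       for chunk in audio.iter_bytes():
    #           f.write(chunk)

    uvicorn.run(demo_app, host="127.0.0.1", port=9381)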