"""FastAPI TTS streaming routes backed by DashScope (Qwen / CosyVoice) synthesis.

Provides:
  * StreamSessionManager — per-session incremental text -> audio pipeline
  * QwenTTS              — thin wrapper around dashscope tts / tts_v2 synthesizers
  * /chats/... routes    — create and fetch cached TTS audio streams
"""
import array
import base64
import binascii
import datetime
import gzip
import io
import json
import logging
import os
import queue
import re
import threading
import time
import uuid
import wave  # was missing: required by add_wav_header()
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import timedelta
from io import BytesIO
from threading import Lock, Thread
from timeit import default_timer as timer
from typing import Any, Dict, Optional

from fastapi import (APIRouter, Body, FastAPI, File, Form, Header,
                     HTTPException,  # was missing: raised throughout this module
                     Query, Request, UploadFile, WebSocket, WebSocketDisconnect)
from fastapi.responses import JSONResponse, StreamingResponse

TTS_SAMPLERATE = 44100  # 22050 # 16000
FORMAT = "mp3"
CHANNELS = 1        # mono
SAMPLE_WIDTH = 2    # 16-bit = 2 bytes

tts_router = APIRouter()
logger = logging.getLogger(__name__)


class StreamSessionManager:
    """Manages incremental TTS sessions: text is appended piecewise, audio
    chunks are produced by a thread pool and buffered per session."""

    def __init__(self):
        # {session_id: {'tts_model': obj, 'buffer': Queue, 'task_queue': Queue, ...}}
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size pool
        self.gc_interval = 300  # reap idle sessions after 5 minutes
        self.gc_tts = 3         # mark stream finished after 3 s of inactivity

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        """Register a new session and start its dedicated task thread.

        Returns the new session id (uuid4 string).
        """
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio chunk queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set when the stream is done
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                # becomes True once the TTS backend has returned its first chunk
                "tts_chunk_data_valid": False,
            }
        # one processing thread per session
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue more text for synthesis (non-blocking; drops on full queue)."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
        try:
            session['task_queue'].put(text, block=False)
        except queue.Full:
            logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Task-processing loop (one thread per session).

        Coalesces queued text fragments, submits them to the pool, and
        enforces the idle/GC timeouts.
        """
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Coalesce up to 5 text fragments, waiting at most 50 ms each.
                texts = []
                while len(texts) < 5:
                    try:
                        text = session['task_queue'].get(timeout=0.05)
                        texts.append(text)
                    except queue.Empty:
                        break
                if texts:
                    # Join fragments to reduce the number of backend requests.
                    future = self.executor.submit(
                        self._generate_audio, session_id, ' '.join(texts)
                    )
                    future.result()  # wait until synthesis of this batch is done
                # Idle-session reaping.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                # Short inactivity window: declare the stream finished.
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        """Run the TTS model for *text* and push chunks into the session buffer.

        Runs on the thread pool.  For WAV output the first chunk gets a short
        fade-in to avoid an audible click.  Errors are delivered in-band as an
        "ERROR:..." string on the buffer.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        try:
            # NOTE(review): QwenTTS.tts() in this file takes only (text); the
            # model stored here apparently accepts (text, sample_rate, format)
            # — confirm which TTS wrapper is actually registered as tts_model.
            for chunk in session['tts_model'].tts(
                    text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav':
                    if first_chunk:
                        chunk_len = len(chunk)
                        if chunk_len > 2048:
                            session['buffer'].put(audio_fade_in(chunk, 1024))
                        else:
                            session['buffer'].put(audio_fade_in(chunk, chunk_len))
                        first_chunk = False
                    else:
                        session['buffer'].put(chunk)
                else:
                    session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # 2025-05-10: backend has responded; frontend may be notified
                    session['tts_chunk_data_valid'] = True
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            # deliver the error in-band so the consumer can report it
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def close_session(self, session_id):
        """Mark a session inactive and schedule its cleanup 1 s later."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                # clean up shortly after, letting in-flight reads drain
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Actually drop the session entry (called from a Timer)."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the raw session dict, or None."""
        return self.sessions.get(session_id)


stream_manager = StreamSessionManager()


def allowed_file(filename):
    """Return True if *filename* has an allowed image extension."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}


# Shared cache of generated audio streams, keyed by audio_stream_id.
# (The original file defined these three twice; deduplicated here.)
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes

# NOTE: a legacy, commented-out Flask implementation of /tts_stream/<session_id>
# (dead code inside a module-level string) was removed here.


def generate_mp3_header(bitrate_kbps=128, padding=0):
    """Build a 4-byte MPEG-1 Layer III frame header (44.1 kHz, mono).

    Raises KeyError for a bitrate not in the MPEG-1 Layer III table.
    """
    sync = 0b11111111111      # sync word (11 bits)
    version = 0b11            # MPEG-1 (2 bits)
    layer = 0b01              # Layer III (2 bits)
    protection = 0b0          # no CRC (1 bit)
    bitrate_index = {         # MPEG-1 Layer III bitrate index table
        32: 0b0001, 40: 0b0010, 48: 0b0011, 56: 0b0100,
        64: 0b0101, 80: 0b0110, 96: 0b0111, 112: 0b1000,
        128: 0b1001, 160: 0b1010, 192: 0b1011, 224: 0b1100,
        256: 0b1101, 320: 0b1110,
    }[bitrate_kbps]
    sampling_rate = 0b00      # 44.1 kHz (2 bits)
    padding_bit = padding     # padding (1 bit)
    private = 0b0             # private bit (1 bit)
    mode = 0b11               # mono (2 bits)
    mode_ext = 0b00           # mode extension (2 bits)
    copyright = 0b0           # no copyright (1 bit)
    original = 0b0            # not original (1 bit)
    emphasis = 0b00           # no emphasis (2 bits)
    # Assemble into a 32-bit big-endian integer.
    header = (
        (sync << 21) | (version << 19) | (layer << 17) | (protection << 16) |
        (bitrate_index << 12) | (sampling_rate << 10) | (padding_bit << 9) |
        (private << 8) | (mode << 6) | (mode_ext << 4) |
        (copyright << 3) | (original << 2) | emphasis
    )
    return header.to_bytes(4, byteorder='big')


# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the first *fade_length* samples.

    Assumes 16-bit mono PCM; *audio_data* must have an even byte length.
    Returns the processed bytes.
    """
    samples = array.array('h', audio_data)
    for i in range(fade_length):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)
    return samples.tobytes()


def parse_markdown_json(json_string):
    """Extract and parse a ```json fenced block from *json_string*.

    Returns {'success': bool, 'data': parsed-object-or-error-string}.
    """
    match = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if match:
        try:
            data = json.loads(match[1])
            return {'success': True, 'data': data}
        except json.JSONDecodeError as e:
            return {'success': False, 'data': str(e)}
    else:
        return {'success': False, 'data': 'not a valid markdown json string'}


def fullwidth_to_halfwidth(s):
    """Map fullwidth punctuation in *s* to its halfwidth equivalent.

    NOTE(review): the original dict contained duplicate keys; the later
    (identity) entries for '、', '・', '「', '」' won, so those characters are
    effectively left unchanged — confirm whether '、'→',' was intended.
    """
    full_to_half_map = {
        '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&',
        '＇': "'", '(': '(', ')': ')', '*': '*', '+': '+', ',': ',',
        '-': '-', '.': '.', '/': '/', ':': ':', ';': ';', '<': '<',
        '=': '=', '>': '>', '?': '?', '@': '@', '[': '[', '＼': '\\',
        ']': ']', '^': '^', '_': '_', '`': '`', '{': '{', '|': '|',
        '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆',
        '「': '「', '」': '」', 'ー': '-', '。': '.',
        '、': '、', '・': '・',
    }
    return ''.join(full_to_half_map.get(char, char) for char in s)


def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* on punctuation/whitespace and regroup tokens into chunks
    of at most *chunk_size* characters (tokens joined with single spaces)."""
    punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
    tokens = re.split(punctuation_pattern, text)
    tokens = [token for token in tokens if token]  # drop empties
    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += (token + ' ')
        else:
            chunks.append(current_chunk.strip())
            current_chunk = token + ' '
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks


def extract_text_from_markdown(markdown_text):
    """Strip Markdown syntax and return plain, whitespace-normalized text."""
    # headings
    text = re.sub(r'#\s*[^#]+', '', markdown_text)
    # inline code
    text = re.sub(r'`[^`]+`', '', text)
    # fenced code blocks
    text = re.sub(r'```[\s\S]*?```', '', text)
    # bold / italic
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # links
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # images
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # normalize punctuation
    text = fullwidth_to_halfwidth(text)
    # collapse whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def encode_gzip_base64(original_data: bytes) -> str:
    """Encode: raw bytes -> gzip -> base64 (no line wrapping, matching
    Android's Base64.NO_WRAP)."""
    with BytesIO() as buf:
        with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file:
            gz_file.write(original_data)
        compressed_bytes = buf.getvalue()
    return base64.b64encode(compressed_bytes).decode('utf-8')


def clean_audio_cache():
    """Drop expired cache entries (time.time()-based 'created_at').

    NOTE(review): entries written by create_tts_request1 store a datetime in
    'created_at'; the subtraction below would fail on those — confirm the two
    caches are never mixed.
    """
    with cache_lock:
        now = time.time()
        expired_keys = [
            k for k, v in audio_text_cache.items()
            if now - v['created_at'] > CACHE_EXPIRE_SECONDS
        ]
        for k in expired_keys:
            entry = audio_text_cache.pop(k, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()


def start_background_cleaner():
    """Start a daemon thread that runs clean_audio_cache every 3 minutes."""
    def cleaner_loop():
        while True:
            time.sleep(180)
            clean_audio_cache()
    cleaner_thread = Thread(target=cleaner_loop, daemon=True)
    cleaner_thread.start()


def test_qwen_chat():
    """Smoke-test helper for the DashScope chat API (not a route)."""
    # local import: fixes the NameError on Generation in the original, and
    # keeps dashscope optional at module import time
    from dashscope import Generation
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'},
    ]
    response = Generation.call(
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message",
    )
    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码:{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")


# SECURITY: an API key was hard-coded here.  It is now read from the
# environment; the literal remains only as a backward-compatible fallback and
# should be rotated and removed.
ALI_KEY = os.environ.get("ALI_KEY", "sk-a47a3fb5f4a94f66bbaf713779101c75")


class QwenTTS:
    """Wrapper around DashScope speech synthesis.

    Model names of the form "cosyvoice-v1/<voice>" select the tts_v2
    (CosyVoice) pipeline; anything else uses the legacy tts pipeline.
    """

    def __init__(self, key, model_name="cosyvoice-v1/longxiaochun", base_url=""):
        import dashscope
        import ssl
        print("---begin--init QwenTTS--")  # cyx
        self.model_name = model_name
        dashscope.api_key = key
        # SECURITY: disables TLS certificate verification process-wide.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        if '/' in model_name:
            parts = model_name.split('/', 1)  # -> [family, voice]
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]

    def tts(self, text):
        """Synthesize *text*; yields raw audio chunks (bytes) as they arrive.

        Raises RuntimeError if the backend reports an error.
        """
        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        if self.is_cosyvoice is False:
            from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
            from collections import deque
        else:
            # cyx 2025-01-19: CosyVoice uses the tts_v2 API
            from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat  # , SpeechSynthesisResult
            from dashscope.audio.tts import SpeechSynthesisResult
            from collections import deque
        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")  # cyx

        class Callback(ResultCallback):
            """Legacy-pipeline callback: buffers frames in a deque; None marks EOF."""

            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                # busy-wait generator draining the deque until the None sentinel
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break

            def on_open(self):
                pass

            def on_complete(self):
                self.dque.append(None)  # EOF sentinel

            def on_error(self, response: SpeechSynthesisResponse):
                print("Qwen tts error", str(response))
                raise RuntimeError(str(response))

            def on_close(self):
                pass

            def on_event(self, result: SpeechSynthesisResult):
                if result.get_audio_frame() is not None:
                    self.dque.append(result.get_audio_frame())

        # --------------------------
        class Callback_v2(ResultCallback):
            """tts_v2 (CosyVoice) callback: audio arrives via on_data."""

            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break

            def on_open(self):
                logging.info("Qwen tts open")

            def on_complete(self):
                self.dque.append(None)  # EOF sentinel

            def on_error(self, response: SpeechSynthesisResponse):
                print("Qwen tts error", str(response))
                raise RuntimeError(str(response))

            def on_close(self):
                logging.info("Qwen tts close")

            def on_event(self, message):
                # tts_v2 delivers audio through on_data, not on_event
                pass

            def on_data(self, data: bytes) -> None:
                if len(data) > 0:
                    self.dque.append(data)
        # --------------------------

        try:
            if self.is_cosyvoice is False:
                self.callback = Callback()
                SpeechSynthesizer.call(model=self.model_name,
                                       text=text,
                                       callback=self.callback,
                                       format="wav")  # format="mp3")
            else:
                self.callback = Callback_v2()
                format = self.get_audio_format(FORMAT, TTS_SAMPLERATE)
                self.synthesizer = SpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,  # e.g. "longyuan", "longfei"
                    callback=self.callback,
                    format=format,
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}")  # cyx
        # -----------------------------------
        try:
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def get_audio_format(self, format: str, sample_rate: int):
        """Map (sample_rate, format) to a dashscope AudioFormat enum value.

        Unknown combinations fall back to MP3 16 kHz mono 128 kbps.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)

    def end_tts(self):
        """Signal the streaming synthesizer that no more text will be sent."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()


@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Stream a local ./test.mp3 file in 1 KiB chunks (debug endpoint)."""
    def audio_generator():
        path = './test.mp3'
        try:
            with open(path, 'rb') as f:
                while True:
                    chunk = f.read(1024)
                    if not chunk:
                        break
                    yield chunk
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/mpeg",
        headers={"Cache-Control": "no-store"},
    )


def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Wrap raw 16-bit mono PCM in a WAV container and return the bytes."""
    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono
            wav_file.setsampwidth(2)   # 16-bit
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer.read()


def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return *duration_ms* of PCM silence (used to pre-buffer MP3 streams)."""
    num_samples = int(TTS_SAMPLERATE * duration_ms / 1000)
    return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS


@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
    chat_id: str,
    audio_stream_id: str,
    request: Request,
    range: str = Header(None),  # Range header, if any
):
    """Serve a previously created TTS stream.

    If the audio was pre-generated (buffered BytesIO) it is replayed;
    otherwise it is synthesized on the fly.  Raises 404 when the cache entry
    has expired.
    """
    with cache_lock:
        tts_info = audio_text_cache.get(audio_stream_id, None)
    if not tts_info:
        raise HTTPException(404, detail="音频流已过期")

    audio_stream_len = tts_info.get('audio_stream_len', 0)
    audio_stream = tts_info['audio_stream']
    text = tts_info['text']
    # 'model_name' may be present but None (create_tts_request stores the raw
    # request value), so dict.get's default would not apply — use `or`.
    model_name = tts_info.get('model_name') or "cosyvoice-v1/longxiaochun"
    format = tts_info['format']
    sample_rate = tts_info['sample_rate']

    def read_buffer():
        """Replay the buffered stream from the beginning in 1 KiB chunks."""
        print(f"get audio_stream {audio_stream} {audio_stream.closed}")
        audio_stream.seek(0)  # critical: rewind before every replay
        data = audio_stream.read(1024)
        while data:
            yield data
            data = audio_stream.read(1024)

    def generate_audio():
        """Synthesize on the fly, temporarily overriding the module globals.

        NOTE(review): mutating FORMAT/TTS_SAMPLERATE is not safe under
        concurrent requests — kept for backward compatibility.
        """
        tts = QwenTTS(ALI_KEY, model_name)
        global FORMAT, TTS_SAMPLERATE
        original_format = FORMAT
        original_sample_rate = TTS_SAMPLERATE
        FORMAT = format
        TTS_SAMPLERATE = sample_rate
        try:
            for chunk in tts.tts(text):
                if format == 'wav':
                    yield add_wav_header(chunk, sample_rate)
                else:
                    yield chunk
        finally:
            FORMAT = original_format
            TTS_SAMPLERATE = original_sample_rate

    media_type_map = {
        'mp3': 'audio/mpeg',
        'wav': 'audio/wav',
        'pcm': f'audio/L16; rate={sample_rate}; channels=1',
    }

    # Range request: always replays the whole buffer (start is fixed at 0).
    if range:
        start = 0
        end = audio_stream_len
        total_length = audio_stream_len
        return StreamingResponse(
            read_buffer(),
            status_code=206,
            headers={
                "Accept-Ranges": "bytes",
                "Content-Range": f"bytes {start}-{end}/{total_length}",
                "Cache-Control": "no-store",
            },
        )
    else:
        try:
            if audio_stream is None:
                # constructor side effect: sets dashscope.api_key
                tts_mdl = QwenTTS(ALI_KEY, "cosyvoice-v1/longxiaochun")
            if audio_stream:
                logging.info("return audio stream buffer")
                return StreamingResponse(
                    read_buffer(),
                    media_type=media_type_map[format],
                    headers={
                        "Transfer-Encoding": "chunked",
                        "Cache-Control": "no-store",
                        "Content-Disposition": "inline",
                        "Access-Control-Allow-Origin": "*",
                    },
                )
            else:
                return StreamingResponse(
                    generate_audio(),
                    media_type=media_type_map[format],
                    headers={
                        "Transfer-Encoding": "chunked",
                        "Cache-Control": "no-store",
                        "Content-Disposition": "inline",
                    },
                )
        except Exception as e:
            logging.error(f"音频流错误: {str(e)}")
            raise HTTPException(500, detail="音频生成失败")


@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Create a TTS job; unless delay_gen_audio is set, synthesize eagerly and
    cache the resulting buffer.  Returns the fetch URL and stream id."""
    try:
        request_data = await request.json()
        # validate optional parameters
        format = request_data.get('format', FORMAT)
        if format not in ['mp3', 'wav', 'pcm']:
            raise HTTPException(400, detail="不支持的音频格式")
        sample_rate = request_data.get('sample_rate', TTS_SAMPLERATE)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")
        text = request_data.get("text")
        if not text or not text.strip():
            raise HTTPException(400, detail="文本内容不能为空")

        audio_stream_id = str(uuid.uuid4())
        tts_info = {
            'text': text,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': None,
            'model_name': request_data.get('model_name'),
            'format': format,
            'sample_rate': sample_rate,
        }
        audio_stream_len = 0
        # Eager generation unless the caller asks for delayed synthesis.
        if request_data.get('delay_gen_audio', False) is False:
            # `or`-fallback: 'model_name' exists in tts_info but may be None
            tts = QwenTTS(ALI_KEY, tts_info.get('model_name') or "cosyvoice-v1/longxiaochun")
            buffer = io.BytesIO()
            for chunk in tts.tts(text):
                audio_stream_len = audio_stream_len + len(chunk)
                buffer.write(chunk)
            buffer.seek(0)
            tts_info['audio_stream'] = buffer
            tts_info['audio_stream_len'] = audio_stream_len

        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info
        logging.info(f"create tts stream return {audio_stream_id} {audio_text_cache[audio_stream_id]}")
        return {
            "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
            "audio_stream_id": audio_stream_id,
            "sample_rate": sample_rate,
        }
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="内部服务器错误")


# ------------------------ API routes (v1 variants) ------------------------
@tts_router.post("/chats1/{chat_id}/tts")
async def create_tts_request1(chat_id: str, request: Request):
    """Create a TTS audio stream (eager, fully buffered variant)."""
    try:
        data = await request.json()
        logging.info(f"API--create_tts_request1-- {data}")
        # parameter validation
        text = data.get("text", "").strip()
        if not text:
            raise HTTPException(400, detail="文本内容不能为空")
        format = data.get("format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="不支持的音频格式")
        sample_rate = data.get("sample_rate", 16000)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")
        # NOTE(review): requested format/sample_rate are deliberately
        # overridden here in the original code — confirm this is intended.
        format = "mp3"
        sample_rate = 44100

        # synthesize the full audio into a buffer
        audio_stream_id = str(uuid.uuid4())
        buffer = io.BytesIO()
        tts = QwenTTS(ALI_KEY)
        try:
            for chunk in tts.tts(text):
                buffer.write(chunk)
            buffer.seek(0)
        except Exception as e:
            logging.error(f"TTS生成失败: {str(e)}")
            raise HTTPException(500, detail="音频生成失败")

        with cache_lock:
            audio_text_cache[audio_stream_id] = {
                "buffer": buffer,
                "format": format,
                "sample_rate": sample_rate,
                "created_at": datetime.datetime.now(),
                "size": buffer.getbuffer().nbytes,
            }
        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
                "expires_at": (datetime.datetime.now()
                               + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat(),
            },
        )
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="服务器内部错误")


@tts_router.get("/chats1/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio1(
    chat_id: str,
    audio_stream_id: str,
    range: str = Header(None),
):
    """Fetch a buffered audio stream; supports HTTP Range requests."""
    cleanup_cache()  # evict expired entries first
    with cache_lock:
        item = audio_text_cache.get(audio_stream_id)
        if not item:
            raise HTTPException(404, detail="音频流不存在或已过期")

    buffer = item["buffer"]
    format = item["format"]
    total_size = item["size"]
    media_type = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "pcm": f"audio/L16; rate={item['sample_rate']}; channels=1",
    }[format]

    if range:
        return handle_range_request(range, buffer, total_size, media_type)

    # full-file response; rewind so repeated GETs don't return empty bodies
    buffer.seek(0)
    return StreamingResponse(
        iter(lambda: buffer.read(4096), b""),
        media_type=media_type,
        headers={
            "Accept-Ranges": "bytes",
            "Content-Length": str(total_size),
            "Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}",
        },
    )


def handle_range_request(range: str, buffer: BytesIO, total_size: int, media_type: str):
    """Serve an HTTP Range request against *buffer*.

    Raises 416 for an out-of-bounds range and 400 for a malformed header.
    """
    try:
        unit, ranges = range.split("=")
        if unit != "bytes":
            raise ValueError
        start_str, end_str = ranges.split("-")
        start = int(start_str) if start_str else 0
        end = int(end_str) if end_str else total_size - 1
        # validate bounds
        if start >= total_size or end >= total_size:
            raise HTTPException(416, detail="请求范围无效", headers={
                "Content-Range": f"bytes */{total_size}"
            })
        content_length = end - start + 1
        buffer.seek(start)

        def chunk_generator():
            remaining = content_length
            while remaining > 0:
                chunk = buffer.read(min(4096, remaining))
                if not chunk:
                    break
                yield chunk
                remaining -= len(chunk)

        return StreamingResponse(
            chunk_generator(),
            status_code=206,
            headers={
                "Content-Range": f"bytes {start}-{end}/{total_size}",
                "Content-Length": str(content_length),
                "Content-Type": media_type,
                "Accept-Ranges": "bytes",
            },
        )
    except ValueError:
        raise HTTPException(400, detail="无效的Range头")


def cleanup_cache():
    """Drop expired cache entries (datetime-based 'created_at').

    NOTE(review): only valid for entries written by create_tts_request1;
    entries from create_tts_request store a float timestamp — see
    clean_audio_cache above.
    """
    with cache_lock:
        now = datetime.datetime.now()
        expired = [k for k, v in audio_text_cache.items()
                   if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS]
        for key in expired:
            del audio_text_cache[key]


# start the cleanup thread at application startup
# start_background_cleaner()