diff --git a/api/apps/sdk/dale_extra.py b/api/apps/sdk/dale_extra.py index 56766672..6491357d 100644 --- a/api/apps/sdk/dale_extra.py +++ b/api/apps/sdk/dale_extra.py @@ -31,7 +31,7 @@ from api.utils.api_utils import get_result from api.utils.file_utils import get_project_base_directory from rag.utils.minio_conn import RAGFlowMinio import logging -import base64, gzip +import base64, gzip,json import io, re, json from io import BytesIO import queue,time,uuid,os,array @@ -135,13 +135,13 @@ def upload_file(tenant_id,mesum_id): mesum_id_str = str(mesum_id) labels_with_id = get_labels_with_id(mesum_id) - antique_labels = ','.join([item['label'] for item in labels_with_id]) + antique_labels = ';'.join([item['label'] for item in labels_with_id]) # 使用分号分隔 joined_string = antique_labels antiques_selected = f"{joined_string}" #logging.info(f"mesumid={mesum_id} {joined_string}") - prompt1 = (f"你是一名图片识别和理解助手" + prompt1= (f"你是一名图片识别和理解助手" f"任务是先识别图片中文字,然后理解文字中包含的内容,分析哪一项可以作为识别出文字的标题," f"你的回答有3个结果,第一个结果匹配出的结果,JSON键值为antique" f"从下面的候选项:{antiques_selected}进行匹配,每一个候选项中间以';'分割,如果没有任何匹配则结果为'',以免误触发讲解,匹配成功则输出匹配出的内容" @@ -151,25 +151,29 @@ def upload_file(tenant_id,mesum_id): f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据") prompt = ( - f"作为图片识别和理解助手,您的任务是:" - f"\n1. 图片基本上就是展品标题、历史人物或者历史事件" - f"\n2. 精确识别图片中的文字内容,理解文字语义,重点分析字体较大的文字" - f"\n3. 识别出的文字包含标题或者接近于标题的文字" - f"\n4. 从以下候选标题中选择最佳匹配项:" - f"\n {antiques_selected}" - f"\n\n### 输出要求:" - f"\n- 以严格JSON格式输出,包含3个字段:" - f"\n • `antique`: 匹配的标题(多个用英文分号';'分割,最多匹配3个,无匹配则空字符串)" - f"\n • `text`: 识别出的完整文字" - f"\n • `match_score`: 整体匹配度(0-1的浮点数),1=完全匹配" - f"\n\n### 匹配规则:" - f"\n1. 语义匹配优先于字面匹配" - f"\n2. 考虑同义词、近义词和描述性匹配" - f"\n3. 允许部分匹配(如'青铜酒器'匹配'青铜器')" - f"\n4. 若无明确匹配项,`antique`返回空字符串" - f"\n\n### 重要:" - f"\n- 输出必须是可直接解析的JSON,无任何前置/后置文本" - f"\n- 匹配度评分需客观反映文本与候选标题的相似度" + f"作为博物馆展品识别专家,您的任务是:" + f"\n1. 识别图片中的文字内容,重点关注展品标题(通常是最大/最显眼的文字)" + f"\n2. 从以下候选标题中匹配最佳项:{antiques_selected}" + f"\n3. 匹配规则:" + f"\n - 优先匹配完整标题(如'铜踵饰残片'匹配'铜踵饰残片')" + f"\n - 其次匹配关键词(如'刻辞卜骨'可匹配'刻辞卜骨')" + f"\n - 允许部分匹配(如'铜器'匹配'青铜器')" + f"\n - 忽略拼音、英文和次要描述文字" + f"\n - 如果近似,不好区分,则输出数组供前端选择,如:青铜车䡇匹配青铜车䡇;青铜车軛一对" + f"\n4. 输出要求:" + f"\n - 匹配结果最多不超过5个" + f"\n - 用英文分号';'分隔多个匹配项" + f"\n\n输出严格JSON格式:" + f"\n{'{'}" + f"\n \"antique\": \"匹配结果(多个用分号分隔)\"," + f"\n \"text\": \"识别出的完整文字\"," + f"\n \"match_score\": 整体匹配度(0-1)" + f"\n{'}'}" + f"\n\n示例:" + f"\n候选标题:青铜车䡇;玉虎;甲骨文;刻辞卜骨" + f"\n识别文字:『青铜车䡇』商代..." + f"\n正确输出:" + f"\n{'{'}\"antique\": \"青铜车轼\", \"text\": \"青铜车轼 (yue)...\", \"match_score\": 0.95{'}'}" ) file = request.files['file'] @@ -212,7 +216,7 @@ def upload_file(tenant_id,mesum_id): ] ) - """ + client = OpenAI( api_key="sk-a47a3fb5f4a94f66bbaf713779101c75", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", @@ -238,20 +242,58 @@ def upload_file(tenant_id,mesum_id): }, ], ) + """ + vl_model = "doubao-1-5-thinking-vision-pro-250428" + client = OpenAI( + api_key="1e04d30a-0c56-4dbd-b873-53f26649c64f", + base_url="https://ark.cn-beijing.volces.com/api/v3", + ) + response = client.chat.completions.create( + # 指定您创建的方舟推理接入点 ID,此处已帮您修改为您的推理接入点 ID + model=vl_model, + messages=[ + { + "role": "system", + "content": [{"type": "text", "text": prompt}], + }, + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{img_base}" + }, + } + ], + } + ], + ) message = response.choices[0].message - parsed_json_res = parse_markdown_json(message.content) parsed_json_data = {"antique": "", "text": "", "match_score": 0} matchedArray = [] + try: + if isinstance(message.content, str): + parsed_json_res = parse_markdown_json(message.content) # 优先识别带有markdown格式 + if parsed_json_res.get('success') is False: # 如果识别失败,再识别普通json格式(字符串) + parsed_json_data = json.loads(message.content) + parsed_json_res['success'] = True + parsed_json_res['data'] = parsed_json_data + except Exception as e: + pass + #logging.info(f"识别完成 {message.content} {parsed_json_data} ") + if parsed_json_res.get('success') is True: parsed_json_data = parsed_json_res.get('data') matchedAntiqueArray = parsed_json_data.get('antique').split(';') # 识别出的文物的数组,中间以';'分割,可能有多个 if len(matchedAntiqueArray) ==1: # 只有一个匹配项,直接返回 + logging.info(f"识别完成 得到1个,{parsed_json_data.get('antique')} {labels_with_id} ") for item in labels_with_id: if item['label'] == parsed_json_data.get('antique'): parsed_json_data['id'] = item.get('id') else: # 有多个匹配项,需要进行多个匹配 - for label in matchedAntiqueArray: + for label in matchedAntiqueArray[:5]: antique= {'label':label} for item in labels_with_id: if item['label'] == label: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index dfcac843..b42a60c4 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -115,6 +115,7 @@ def update(tenant_id, chat_id, session_id): @token_required def completion(tenant_id, chat_id): # chat_id 和 别的文件中的dialog_id 应该是一个意思? cyx 2025-01-25 req = request.json + logging.info(f"/chats/{chat_id}/completions--0 req={req}") # cyx if not req.get("session_id"): # session_id 和 别的文件中的conversation_id 应该是一个意思? cyx 2025-01-25 conv = { "id": get_uuid(), diff --git a/api/db/db_models.py b/api/db/db_models.py index 8edf3c7b..f7b4a06b 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1042,6 +1042,7 @@ class MesumOverview(DataBaseModel): index=False) free = IntegerField(default=0, index=False) + hotspot_rank = IntegerField(default=50, index=False) def __str__(self): return self.name diff --git a/asr-monitor-test/Dockerfile b/asr-monitor-test/Dockerfile index 27967302..490b3442 100644 --- a/asr-monitor-test/Dockerfile +++ b/asr-monitor-test/Dockerfile @@ -34,5 +34,5 @@ COPY app ./app EXPOSE 9580 # 启动命令 -CMD ["python3", "-m", "app.main"] -# CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "9480"] +#CMD ["python3", "-m", "app.main"] +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "9580"] diff --git a/asr-monitor-test/app.log b/asr-monitor-test/app.log index dc8857b9..52a2d065 100644 --- a/asr-monitor-test/app.log +++ b/asr-monitor-test/app.log @@ -1,15 +1,142 @@ -INFO: Started server process [2877053] +INFO: Started server process [175422] INFO: Waiting for application startup. -16:01:53.789 - INFO - 监控服务已启动 +14:09:53.459 - INFO - 监控服务已启动 INFO: Application startup complete. INFO: Uvicorn running on http://0.0.0.0:9580 (Press CTRL+C to quit) -WARNING: Invalid HTTP request received. +14:10:04.372 - INFO - verify_token user={'user_id': '76538cf0-a6cf-4aa8-8440-382dd2330384', 'openid': 'obKSz7V6a-avAF-vtQrnk_rnuSGE', 'phone': '18676776176', 'email': None, 'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI3NjUzOGNmMC1hNmNmLTRhYTgtODQ0MC0zODJkZDIzMzAzODQiLCJleHAiOjE3NTM2NzkxMzB9.k-lALo9ulLGnu5O9qZALEp45F2loDnfdBZ09C9vglIw', 'balance': 0, 'status': 1, 'last_login_time': 1753074330, 'create_time': 1748960538, 'create_date': datetime.datetime(2025, 6, 3, 22, 22, 18), 'update_time': 1753074330, 'update_date': datetime.datetime(2025, 7, 21, 13, 5, 30)} ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ ASR & Monitor Service Start ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ -INFO: 91.196.152.109:33535 - "GET / HTTP/1.1" 404 Not Found -INFO: 220.196.160.51:32110 - "GET / HTTP/1.1" 404 Not Found -INFO: 180.101.245.250:54556 - "GET / HTTP/1.1" 404 Not Found -INFO: 220.196.160.75:44896 - "GET / HTTP/1.1" 404 Not Found -INFO: 91.196.152.210:40043 - "GET /favicon.ico HTTP/1.1" 404 Not Found -WARNING: Invalid HTTP request received. +tts_service路由器正在启动... +INFO: 43.140.60.33:0 - "GET /auth/verify HTTP/1.1" 200 OK +INFO: 43.144.107.28:0 - "POST /payment/get_user_museum_subscriptions HTTP/1.1" 200 OK +14:10:08.416 - INFO - Creating TTS request: {'text': '提梁三足铜盉,5.2749,商,通梁高22.5厘米,高19.5厘米,口径4.8厘米,腹径11.0厘米,平谷区南独乐河公社刘家河大队出土,首都博物馆藏。铜盉为长颈,圆鼓腹,三足,有一绦状提梁,肩部附一圆柱形流。提梁盉在郑州和湖北盘龙城的二里岗期铜器中均未见。安阳殷墟妇好墓中出有一件, 其 制与此盉有所不同,花纹亦比较精美。 此盉可能是与殷墟二期提梁盉的过渡形态。', 'session_id': '0b4cdbaeaf9111efa53df171065841e8', 'delay_gen_audio': True, 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'model_name': 'sambert-zhichu-v1@Tongyi-Qianwen', 'sample_rate': 8000, 'stream_format': 'mp3'} +INFO: 43.140.60.33:0 - "POST /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts?device_id=17528308107741998517 HTTP/1.1" 200 OK +INFO: ('1.13.185.116', 40190) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/9dd9cff1-093c-4b9f-a3d6-a701fdcf7ad3" [accepted] +14:10:09.160 - INFO - 新连接建立: b9980686-a10d-4082-9d13-bf1260ea95b7 +14:10:09.160 - INFO - 已经启动 start tts task {audio_stream_id} +14:10:09.161 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +14:10:09.161 - INFO - Qwen text_tts_call 提梁三足铜盉,5.2749,商,通梁高22.5厘米,高19.5厘米,口径4.8厘米,腹径11.0厘米,平谷区南独乐河公社刘家河大队出土,首都博物馆藏。铜盉为长颈,圆鼓腹,三足,有一绦状提梁,肩部附一圆柱形流。提梁盉在郑州和湖北盘龙城的二里岗期铜器中均未见。安阳殷墟妇好墓中出有一件, 其 制与此盉有所不同,花纹亦比较精美。 此盉可能是与殷墟二期提梁盉的过渡形态。 +INFO: connection open +INFO: connection closed +14:10:11.978 - INFO - 发送时检测到断开连接: b9980686-a10d-4082-9d13-bf1260ea95b7, +14:10:11.978 - WARNING - Send failed, connection closed: b9980686-a10d-4082-9d13-bf1260ea95b7 +14:10:11.978 - WARNING - 尝试向不存在的连接发送数据: b9980686-a10d-4082-9d13-bf1260ea95b7 +14:10:12.091 - INFO - --data_handler on_complete +14:10:12.091 - INFO - --tts task event set error = None +14:10:12.091 - INFO - UnifiedTTSEngine _run_tts_sync finally +INFO: ('1.13.185.116', 40194) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/0770c61f-3a83-4c68-9ac8-e304405c0eef" [accepted] +14:10:12.371 - INFO - 新连接建立: eb86ddd7-224b-45a9-a0d0-081c52b5b17e +INFO: connection open +14:10:12.493 - INFO - 代理文本流: completions_url=http://127.0.0.1:9380/api/v1/chats/39e9a2ba5a4711f0865bbb55c66f9471/completions {'question': '请介绍提梁三足铜盉', 'stream': True, 'tts_model': 'sambert-zhichu-v1@Tongyi-Qianwen', 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'tts_disable': True} +14:10:15.632 - INFO - HTTP Request: POST http://127.0.0.1:9380/api/v1/chats/39e9a2ba5a4711f0865bbb55c66f9471/completions "HTTP/1.1 200 OK" +14:10:15.632 - INFO - 响应状态: HTTP 200 +14:10:15.632 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +14:10:15.632 - INFO - --StreamSessionManager create_session last_active=1753078215.632958 +14:10:15.633 - INFO - 开始处理SSE流 +14:10:15.633 - INFO - --proxy_aichat_text_stream 发送audio_stream_url +14:10:15.683 - INFO - Qwen text_tts_call 提梁三足铜盉是商代文物, +INFO: ('1.13.185.116', 54072) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/06af3ad0-024e-4e6d-a3ea-71373ff05280" [accepted] +14:10:16.294 - INFO - 新连接建立: be24c0be-8a30-460c-99ff-747745d2c522 +INFO: connection open +14:10:17.144 - INFO - Qwen text_tts_call 通梁高22.5厘米,高19.5厘米,口径4.8厘米,腹径11.0厘米。出土于平谷区南独乐河公社刘家河大队,现藏于首都博物馆。 +14:10:19.318 - INFO - Qwen text_tts_call 其特征为长颈圆鼓腹三足,有一绦状提梁,肩部附一圆柱形流。 +14:10:20.791 - INFO - SSE流处理完成,事件数: 12 +INFO: connection closed +14:10:21.059 - INFO - Qwen text_tts_call 此盉在郑州和湖北盘龙城的二里岗期铜器中未见,与安阳殷墟妇好墓出土的提梁盉有所不同,花纹较精美,可能是殷墟二期提梁盉的过渡形态。 +INFO: connection closed +14:10:21.394 - INFO - 发送时检测到断开连接: be24c0be-8a30-460c-99ff-747745d2c522, +14:10:21.394 - INFO - --- proxy AiChatTts audio_data_size=94212 +14:10:23.002 - INFO - 清理资源: 0会话 +14:10:24.462 - INFO - Creating TTS request: {'text': '玉牛首,商,通长3.7厘米,宽2.7厘米,河南省信阳市罗山县天湖墓地出土,信阳博物馆藏。\n', 'session_id': '0b4cdbaeaf9111efa53df171065841e8', 'delay_gen_audio': True, 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'model_name': 'sambert-zhichu-v1@Tongyi-Qianwen', 'sample_rate': 8000, 'stream_format': 'mp3'} +INFO: 43.140.60.44:0 - "POST /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts?device_id=17528308107741998517 HTTP/1.1" 200 OK +INFO: ('1.13.185.116', 33942) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/5286022d-3c15-4aa0-a11d-6af6823c89ae" [accepted] +14:10:24.737 - INFO - 新连接建立: 0f8be756-1592-4fe2-958e-074819796722 +14:10:24.737 - INFO - 已经启动 start tts task {audio_stream_id} +14:10:24.737 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +14:10:24.737 - INFO - Qwen text_tts_call 玉牛首,商,通长3.7厘米,宽2.7厘米,河南省信阳市罗山县天湖墓地出土,信阳博物馆藏。 +INFO: connection open +14:10:25.734 - INFO - --data_handler on_complete +14:10:25.734 - INFO - --tts task event set error = None +14:10:25.734 - INFO - UnifiedTTSEngine _run_tts_sync finally +INFO: connection closed +INFO: ('1.13.185.116', 33958) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/d3fe9c0c-45ab-4821-98b4-ef28ad6b7264" [accepted] +14:10:30.629 - INFO - 新连接建立: 64ac60ed-59ee-4d39-ab67-c92a583aacd6 +INFO: connection open +14:10:30.690 - INFO - 代理文本流: completions_url=http://127.0.0.1:9380/api/v1/chats/39e9a2ba5a4711f0865bbb55c66f9471/completions {'question': '请介绍玉牛首', 'stream': True, 'tts_model': 'sambert-zhichu-v1@Tongyi-Qianwen', 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'tts_disable': True} +14:10:33.444 - INFO - HTTP Request: POST http://127.0.0.1:9380/api/v1/chats/39e9a2ba5a4711f0865bbb55c66f9471/completions "HTTP/1.1 200 OK" +14:10:33.444 - INFO - 响应状态: HTTP 200 +14:10:33.444 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +14:10:33.445 - INFO - --StreamSessionManager create_session last_active=1753078233.4450512 +14:10:33.445 - INFO - 开始处理SSE流 +14:10:33.446 - INFO - --proxy_aichat_text_stream 发送audio_stream_url +14:10:33.495 - INFO - Qwen text_tts_call 玉牛首是商代的一件玉器,通长3.7厘米, +INFO: ('1.13.185.116', 33966) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/dfa471ae-b608-48f8-aa38-f3da27aa9da9" [accepted] +14:10:33.919 - INFO - 新连接建立: 2cb47c8b-842c-47a8-b221-42caf74e089a +INFO: connection open +14:10:35.077 - INFO - Qwen text_tts_call 宽2.7厘米,出土于河南省信阳市罗山县天湖墓地,现藏于信阳博物馆。 +14:10:36.370 - INFO - SSE流处理完成,事件数: 9 +INFO: connection closed +14:10:36.911 - INFO - Qwen text_tts_call 这件玉器以牛首为题材,雕刻精细,形象生动,反映了商代玉器工艺的高超水平。玉牛首不仅是珍贵的文物,也是研究商代文化和艺术的重要资料。 +INFO: connection closed +14:10:37.301 - INFO - 发送时检测到断开连接: 2cb47c8b-842c-47a8-b221-42caf74e089a, +14:10:37.301 - INFO - --- proxy AiChatTts audio_data_size=49671 +14:10:38.255 - INFO - verify_token user={'user_id': '76538cf0-a6cf-4aa8-8440-382dd2330384', 'openid': 'obKSz7V6a-avAF-vtQrnk_rnuSGE', 'phone': '18676776176', 'email': None, 'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI3NjUzOGNmMC1hNmNmLTRhYTgtODQ0MC0zODJkZDIzMzAzODQiLCJleHAiOjE3NTM2NzkxMzB9.k-lALo9ulLGnu5O9qZALEp45F2loDnfdBZ09C9vglIw', 'balance': 0, 'status': 1, 'last_login_time': 1753074330, 'create_time': 1748960538, 'create_date': datetime.datetime(2025, 6, 3, 22, 22, 18), 'update_time': 1753074330, 'update_date': datetime.datetime(2025, 7, 21, 13, 5, 30)} +INFO: 43.144.107.210:0 - "GET /auth/verify HTTP/1.1" 200 OK +INFO: 43.144.107.28:0 - "POST /payment/get_user_museum_subscriptions HTTP/1.1" 200 OK +14:10:53.032 - INFO - 清理资源: 0会话 +14:11:23.062 - INFO - 清理资源: 0会话 +14:11:53.090 - INFO - 清理资源: 0会话 +14:12:23.118 - INFO - 清理资源: 0会话 +14:12:53.136 - INFO - 清理资源: 0会话 +14:13:02.462 - INFO - verify_token user={'user_id': '9d0ba2ee-6821-4245-aadc-ea270313d92b', 'openid': 'obKSz7csdQWD7l4tqm7w_aoySrkM', 'phone': '18146525018', 'email': None, 'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI5ZDBiYTJlZS02ODIxLTQyNDUtYWFkYy1lYTI3MDMxM2Q5MmIiLCJleHAiOjE3NTMwODcwMTd9.6z1_LMFK0vebxkel1H27Nx9OCyeHaOzASpNhL27Gjkk', 'balance': 0, 'status': 1, 'last_login_time': 1752482217, 'create_time': 1751012368, 'create_date': datetime.datetime(2025, 6, 27, 16, 19, 28), 'update_time': 1752482217, 'update_date': datetime.datetime(2025, 7, 14, 16, 36, 57)} +INFO: 43.140.60.33:0 - "GET /auth/verify HTTP/1.1" 200 OK +INFO: 43.144.107.28:0 - "POST /payment/get_user_museum_subscriptions HTTP/1.1" 200 OK +14:13:23.162 - INFO - 清理资源: 0会话 +14:13:53.168 - INFO - 清理资源: 0会话 +14:14:23.194 - INFO - 清理资源: 0会话 +14:14:43.262 - INFO - wechat login data={'code': '0f1tbl1w35mEh530lf2w37qmFg4tbl1w', 'encryptedData': 'gBWrIB+qeddiOPRlaawf6ChAd2LqpA/4RxSwdeAr8JhDbHz2csnNl6QY0sWegHddLh4gUr5b3EZmWlTigEZIQa7PNqRVJviQczmGymKSU/X+iL7msmSbpPAcO7RZc6tzd8LdZYbYNcACW0qeCqmv7iyXx4FJzTFrKMF2L821N/F/xFe7yR2Sjf/w29/R9Rodqa4NnZNpUZ5QNRIOEwA8OQ==', 'iv': '5xb8I5Lx6PGh4zK2lwanpg=='} +14:14:43.896 - INFO - get_wx_session return {'session_key': 'Zx4n4H7AUgo18OKGcxTZ9w==', 'openid': 'obKSz7cpajH1Z6wECKSOLKXJ_XS8'} +14:14:43.900 - INFO - 解密数据: {'phoneNumber': '15901055018', 'purePhoneNumber': '15901055018', 'countryCode': '86', 'watermark': {'timestamp': 1753078481, 'appid': 'wx446813bfb3a6985a'}} +14:14:43.900 - INFO - decrypt_data return 15901055018 +14:14:43.909 - INFO - login return {'user_id': '6dd06c74-902e-4f42-b226-d93b9ee5c1df', 'openid': 'obKSz7cpajH1Z6wECKSOLKXJ_XS8', 'phone': '15901055018', 'email': None, 'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI2ZGQwNmM3NC05MDJlLTRmNDItYjIyNi1kOTNiOWVlNWMxZGYiLCJleHAiOjE3NTM0MzU4Mjh9.mZSLMkE5eDpJeIpurl3B2PFwxPPbwEkS-q3zPyhvqfo', 'balance': 0, 'status': 1, 'last_login_time': 1752831028, 'create_time': 1748941778, 'create_date': datetime.datetime(2025, 6, 3, 17, 9, 38), 'update_time': 1752831028, 'update_date': datetime.datetime(2025, 7, 18, 17, 30, 28)} +INFO: 43.144.107.210:0 - "POST /auth/login HTTP/1.1" 200 OK +14:14:44.535 - INFO - verify_token user={'user_id': '6dd06c74-902e-4f42-b226-d93b9ee5c1df', 'openid': 'obKSz7cpajH1Z6wECKSOLKXJ_XS8', 'phone': '15901055018', 'email': None, 'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiI2ZGQwNmM3NC05MDJlLTRmNDItYjIyNi1kOTNiOWVlNWMxZGYiLCJleHAiOjE3NTM2ODMyODN9.xBIu5A1efp5xcF10mH-GrJeSi_56BAMzfP3qVfugfhU', 'balance': 0, 'status': 1, 'last_login_time': 1753078483, 'create_time': 1748941778, 'create_date': datetime.datetime(2025, 6, 3, 17, 9, 38), 'update_time': 1753078483, 'update_date': datetime.datetime(2025, 7, 21, 14, 14, 43)} +INFO: 43.144.107.28:0 - "GET /auth/verify HTTP/1.1" 200 OK +INFO: 106.55.206.109:0 - "POST /payment/get_user_museum_subscriptions HTTP/1.1" 200 OK +14:14:53.223 - INFO - 清理资源: 0会话 +14:14:58.786 - INFO - Creating TTS request: {'text': '提梁卣,商,口径13.6厘米,最大腹径22.2厘米,通高26.8厘米,2008年安阳刘家庄北地H326出土,中国社会科学院考古研究所藏。器身扁圆形,直口微敛,长颈、鼓腹下垂,矮圈足,绹索状提梁两端饰小兽首。盖面饰两组独立兽面纹,相背排列;颈部饰八条勾喙夔纹,腹部两组兽面纹以倒夔纹补空,圈足饰八条夔纹。H326位于刘家庄北地两座商代夯土建筑间的通道中,打破路土层,推测为房主仓促撤离时埋藏。此卣填补了商末周初青铜器形制演变的空白。卣的器型特征为:敛口,颈部两侧有提梁,鼓腹,上有盖,盖上有钮,下有圈足。它的命名来自宋人,器物上并不见自名。在不同时期的西周铭文,涉及到有大量赏赐时多有“秬鬯一卣”的记录,在这里“卣”是用作度量单位。鬯是酒的一种,是黍与香草酿制,用以降神。', 'session_id': '0b4cdbaeaf9111efa53df171065841e8', 'delay_gen_audio': True, 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'model_name': 'sambert-zhichu-v1@Tongyi-Qianwen', 'sample_rate': 8000, 'stream_format': 'mp3'} +INFO: 43.144.107.210:0 - "POST /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts?device_id=17530784708251252320 HTTP/1.1" 200 OK +INFO: ('1.13.185.116', 37746) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/0f7bb773-8933-462a-805f-42f1adbc00fc" [accepted] +14:14:58.967 - INFO - 新连接建立: 1c2ac13f-9a2c-4bec-9241-13b5357c56f6 +14:14:58.967 - INFO - 已经启动 start tts task {audio_stream_id} +14:14:58.967 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +INFO: connection open +14:14:58.967 - INFO - Qwen text_tts_call 提梁卣,商,口径13.6厘米,最大腹径22.2厘米,通高26.8厘米,2008年安阳刘家庄北地H326出土,中国社会科学院考古研究所藏。器身扁圆形,直口微敛,长颈、鼓腹下垂,矮圈足,绹索状提梁两端饰小兽首。盖面饰两组独立兽面纹,相背排列;颈部饰八条勾喙夔纹,腹部两组兽面纹以倒夔纹补空,圈足饰八条夔纹。H326位于刘家庄北地两座商代夯土建筑间的通道中,打破路土层,推测为房主仓促撤离时埋藏。此卣填补了商末周初青铜器形制演变的空白。卣的器型特征为:敛口,颈部两侧有提梁,鼓腹,上有盖,盖上有钮,下有圈足。它的命名来自宋人,器物上并不见自名。在不同时期的西周铭文,涉及到有大量赏赐时多有“秬鬯一卣”的记录,在这里“卣”是用作度量单位。鬯是酒的一种,是黍与香草酿制,用以降神。 +14:15:04.257 - INFO - --data_handler on_complete +14:15:04.257 - INFO - --tts task event set error = None +14:15:04.257 - INFO - UnifiedTTSEngine _run_tts_sync finally +INFO: connection closed +14:15:29.264 - INFO - Creating TTS request: {'text': '提梁卣,商,口径13.6厘米,最大腹径22.2厘米,通高26.8厘米,2008年安阳刘家庄北地H326出土,中国社会科学院考古研究所藏。器身扁圆形,直口微敛,长颈、鼓腹下垂,矮圈足,绹索状提梁两端饰小兽首。盖面饰两组独立兽面纹,相背排列;颈部饰八条勾喙夔纹,腹部两组兽面纹以倒夔纹补空,圈足饰八条夔纹。H326位于刘家庄北地两座商代夯土建筑间的通道中,打破路土层,推测为房主仓促撤离时埋藏。此卣填补了商末周初青铜器形制演变的空白。卣的器型特征为:敛口,颈部两侧有提梁,鼓腹,上有盖,盖上有钮,下有圈足。它的命名来自宋人,器物上并不见自名。在不同时期的西周铭文,涉及到有大量赏赐时多有“秬鬯一卣”的记录,在这里“卣”是用作度量单位。鬯是酒的一种,是黍与香草酿制,用以降神。', 'session_id': '0b4cdbaeaf9111efa53df171065841e8', 'delay_gen_audio': True, 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'model_name': 'sambert-zhichu-v1@Tongyi-Qianwen', 'sample_rate': 8000, 'stream_format': 'mp3'} +INFO: 43.140.60.33:0 - "POST /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts?device_id=17510123186117505004 HTTP/1.1" 200 OK +INFO: ('1.13.185.116', 45602) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/1b664897-37ab-46c9-9140-70c44bd23d15" [accepted] +14:15:29.562 - INFO - 新连接建立: 344077b2-343f-4ac6-a37d-b83e81403078 +14:15:29.562 - INFO - 已经启动 start tts task {audio_stream_id} +14:15:29.562 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +14:15:29.562 - INFO - Qwen text_tts_call 提梁卣,商,口径13.6厘米,最大腹径22.2厘米,通高26.8厘米,2008年安阳刘家庄北地H326出土,中国社会科学院考古研究所藏。器身扁圆形,直口微敛,长颈、鼓腹下垂,矮圈足,绹索状提梁两端饰小兽首。盖面饰两组独立兽面纹,相背排列;颈部饰八条勾喙夔纹,腹部两组兽面纹以倒夔纹补空,圈足饰八条夔纹。H326位于刘家庄北地两座商代夯土建筑间的通道中,打破路土层,推测为房主仓促撤离时埋藏。此卣填补了商末周初青铜器形制演变的空白。卣的器型特征为:敛口,颈部两侧有提梁,鼓腹,上有盖,盖上有钮,下有圈足。它的命名来自宋人,器物上并不见自名。在不同时期的西周铭文,涉及到有大量赏赐时多有“秬鬯一卣”的记录,在这里“卣”是用作度量单位。鬯是酒的一种,是黍与香草酿制,用以降神。 +INFO: connection open +14:15:34.797 - INFO - --data_handler on_complete +14:15:34.797 - INFO - --tts task event set error = None +14:15:34.797 - INFO - UnifiedTTSEngine _run_tts_sync finally +INFO: connection closed +14:15:34.848 - INFO - 发送时检测到断开连接: 344077b2-343f-4ac6-a37d-b83e81403078, +INFO: ('1.13.185.116', 48856) - "WebSocket /tts/chats/39e9a2ba5a4711f0865bbb55c66f9471/tts/54697bbc-a6ca-49c3-97a5-b73ac4e30bc9" [accepted] +14:16:25.962 - INFO - 新连接建立: e030143e-537e-4b14-a3ee-6594e50a1711 +INFO: connection open +14:16:25.995 - INFO - 代理文本流: completions_url=http://127.0.0.1:9380/api/v1/chats/39e9a2ba5a4711f0865bbb55c66f9471/completions {'question': '请简单介绍兽面纹铜觚,字符不超过150字', 'stream': True, 'tts_model': 'sambert-zhichu-v1@Tongyi-Qianwen', 'tts_sample_rate': 8000, 'tts_stream_format': 'mp3', 'tts_disable': True} +14:16:28.918 - INFO - HTTP Request: POST http://127.0.0.1:9380/api/v1/chats/39e9a2ba5a4711f0865bbb55c66f9471/completions "HTTP/1.1 200 OK" +14:16:28.918 - INFO - 响应状态: HTTP 200 +14:16:28.918 - INFO - ---begin--init QwenTTS-- mp3 8000 sambert-zhichu-v1@Tongyi-Qianwen +14:16:31.918 - ERROR - create_session Timeout acquiring lock for session creation +14:16:31.919 - INFO - 开始处理SSE流 +14:16:31.919 - INFO - --proxy_aichat_text_stream 发送audio_stream_url diff --git a/asr-monitor-test/app/config.py b/asr-monitor-test/app/config.py index e364696d..c7d4fe13 100644 --- a/asr-monitor-test/app/config.py +++ b/asr-monitor-test/app/config.py @@ -1,8 +1,8 @@ from pymysql.cursors import DictCursor DATABASE_CONFIG = { - "host": "localhost", - "port": 5455, + "host": "localhost",#"ragflow-mysql",#"localhost", + "port": 5455,#3306, #5455, "user": "root", "password": "infini_rag_flow", "database": "rag_flow", diff --git a/asr-monitor-test/app/database.py b/asr-monitor-test/app/database.py index 9f5f885a..fd91a5e1 100644 --- a/asr-monitor-test/app/database.py +++ b/asr-monitor-test/app/database.py @@ -2,7 +2,7 @@ import pymysql from pymysql import Connection from pymysql.err import OperationalError, InterfaceError from contextlib import contextmanager -from config import DATABASE_CONFIG +from app.config import DATABASE_CONFIG from datetime import datetime,timedelta import logging from zoneinfo import ZoneInfo # Python 3.9+ 内置 diff --git a/asr-monitor-test/app/login_service.py b/asr-monitor-test/app/login_service.py index 0598f003..c363ac63 100644 --- a/asr-monitor-test/app/login_service.py +++ b/asr-monitor-test/app/login_service.py @@ -9,7 +9,7 @@ from Crypto.Cipher import AES import base64,uuid,asyncio import requests from datetime import datetime,timedelta -from database import * +from app.database import * login_router = APIRouter() logger = logging.getLogger("login") diff --git a/asr-monitor-test/app/main.py b/asr-monitor-test/app/main.py index 79d67fa7..e79278a7 100644 --- a/asr-monitor-test/app/main.py +++ b/asr-monitor-test/app/main.py @@ -9,10 +9,18 @@ import json from contextlib import asynccontextmanager from dotenv import load_dotenv - - import uvicorn +# 加载 .env 文件中的环境变量 +load_dotenv() # 默认加载项目根目录的 .env 文件 + +from app.asr_service import asr_router +from app.monitor_service import monitor_router +from app.tts_service import tts_router,tts_lifespan +from app.login_service import login_router +from app.chat_service import chat_router +from app.payment_service import payment_router + @asynccontextmanager async def lifespan(app: FastAPI): """生命周期管理""" @@ -20,21 +28,17 @@ async def lifespan(app: FastAPI): print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") print(" ASR & Monitor Service Start") print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + # 启动TTS路由器的生命周期 + tts_lifespan_ctx = tts_lifespan(app) + await tts_lifespan_ctx.__aenter__() yield + # 关闭TTS路由器的生命周期 + await tts_lifespan_ctx.__aexit__(None, None, None) # 服务停止清理 print("\n━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") print(" Service Stopped Cleanly") print("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") -# 加载 .env 文件中的环境变量 -load_dotenv() # 默认加载项目根目录的 .env 文件 - -from app.asr_service import asr_router -from app.monitor_service import monitor_router -from app.tts_service import tts_router -from app.login_service import login_router -from app.chat_service import chat_router -from app.payment_service import payment_router # 创建应用实例 app = FastAPI(lifespan=lifespan) diff --git a/asr-monitor-test/app/payment_service.py b/asr-monitor-test/app/payment_service.py index 896ea082..0752026f 100644 --- a/asr-monitor-test/app/payment_service.py +++ b/asr-monitor-test/app/payment_service.py @@ -12,9 +12,8 @@ from decimal import Decimal from uuid import UUID import json import logging -from database import * +from app.database import * from jose import JWTError, jwt -from database import * import base64 from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding diff --git a/asr-monitor-test/app/run.sh b/asr-monitor-test/app/run.sh deleted file mode 100755 index 603f9b77..00000000 --- a/asr-monitor-test/app/run.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -export PYTHONPATH=. -python3 -m main.py diff --git a/asr-monitor-test/app/test.wav b/asr-monitor-test/app/test.wav deleted file mode 100644 index e689f964..00000000 Binary files a/asr-monitor-test/app/test.wav and /dev/null differ diff --git a/asr-monitor-test/app/tts_service.py b/asr-monitor-test/app/tts_service.py index af367aa5..9566ffe2 100644 --- a/asr-monitor-test/app/tts_service.py +++ b/asr-monitor-test/app/tts_service.py @@ -7,6 +7,8 @@ from datetime import timedelta import threading, time, queue, uuid, time, array from threading import Lock, Thread from concurrent.futures import ThreadPoolExecutor +from contextlib import asynccontextmanager +from multiprocessing import Manager import base64, gzip import os, io, re, json from io import BytesIO @@ -29,7 +31,20 @@ SAMPLE_WIDTH = 2 # 16-bit = 2字节 tts_router = APIRouter() +# 路由器专属的生命周期管理器 +@asynccontextmanager +async def tts_lifespan(app: FastAPI): + """tts_service路由模块的生命周期管理器""" + print("tts_service路由器正在启动...") + + try: + yield + finally: + print("tts_service路由器正在关闭...") + # logger = logging.getLogger(__name__) + + class MillisecondsFormatter(logging.Formatter): """自定义日志格式器,添加毫秒时间戳""" @@ -66,50 +81,61 @@ configure_logging() class StreamSessionManager: - def __init__(self): - self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}} - self.lock = threading.Lock() + def __init__(self,manager = None): + #self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}} + #self.lock = threading.Lock() + + self.sessions = {} + self.lock = threading.Lock() + self.executor = ThreadPoolExecutor(max_workers=30) # 固定大小线程池 self.gc_interval = 300 # 5分钟清理一次 - self.streaming_call_timeout = 15 # 20s + self.streaming_call_timeout = 10 # 10s self.gc_tts = 3 # 3s - self.sentence_timeout = 1.5 # 1500ms句子超时 + self.sentence_timeout = 2 # 2000ms句子超时 self.sentence_endings = set('。?!;.?!;') # 中英文结束符 # 增强版正则表达式:匹配中英文句子结束符(包含全角) self.sentence_pattern = re.compile( - r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])' + #r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])' + + r'((?:(? self.sentence_timeout: - with self.lock: - if session['text_buffer']: - # 直接处理剩余文本 - gen_tts_audio_func(session_id, session['text_buffer']) - session['text_buffer'] = "" - - if current_time - session['last_active'] > self.streaming_call_timeout: - if session.get('streaming_call'): - session['tts_model'].end_streaming_call() - session['streaming_call'] = False - - # 会话超时检查 - if current_time - session['last_active'] > self.gc_interval: - with self.lock: - if session['text_buffer']: - gen_tts_audio_func(session_id, session['text_buffer']) - session['text_buffer'] = "" - self.close_session(session_id) - break - - # 休眠避免CPU空转 - time.sleep(0.05) # 50ms检查间隔 def _process_tasks(self, session_id): # 20250718 新更新 """任务处理线程(每个会话独立)- 保留原有处理逻辑""" @@ -229,11 +211,11 @@ class StreamSessionManager: # 根据引擎类型选择处理函数 if session.get('streaming_call'): - gen_tts_audio_func = self._stream_audio + gen_tts_audio_func = self._generate_audio #self._stream_audio else: gen_tts_audio_func = self._generate_audio - while session['active']: + while session['active'] and session.get('should_stop', False) is False: current_time = time.time() text_to_process = "" @@ -244,6 +226,7 @@ class StreamSessionManager: # 2. 处理文本 + # 在Python中,空字符串在布尔上下文中被视为 False if text_to_process and not session['current_processing'] : session['text_buffer'] = "" # 分割完整句子 @@ -294,109 +277,96 @@ class StreamSessionManager: # 处理剩余的缓冲文本 if buffer: combined_text = "".join(buffer) - - # 重置完成事件状态 - session['sentence_complete_event'].clear() session['current_processing'] = True # 生成音频 - gen_tts_audio_func(session_id, combined_text) + gen_tts_audio_func(session_id, combined_text) #如果调用_stream_audio,则是同步调用,会阻塞,直到音频生成完成 - # 等待完成 - if not session['sentence_complete_event'].wait(timeout=120.0): - logging.warning(f"Timeout waiting for TTS completion: {combined_text}") # 重置处理状态 time.sleep(1.0) session['current_processing'] = False - logging.info(f"StreamSessionManager _process_tasks 转换结束!!!") + # 3. 检查超时未处理的文本 if current_time - session['last_text_time'] > self.sentence_timeout: + text_to_process = "" with self.lock: if session['text_buffer']: - # 直接处理剩余文本 - session['sentence_complete_event'].clear() - session['current_processing'] = True - gen_tts_audio_func(session_id, session['text_buffer']) - session['text_buffer'] = "" - - # 等待完成 - if not session['sentence_complete_event'].wait(timeout=120.0): - logging.warning(f"Timeout waiting for TTS completion: {combined_text}") - # 重置处理状态 - session['current_processing'] = False + text_to_process = session['text_buffer'] + if text_to_process: + # 直接处理剩余文本 + session['current_processing'] = True + gen_tts_audio_func(session_id, session['text_buffer']) + session['text_buffer'] = "" + # 重置处理状态 + session['current_processing'] = False + """ # 4. 会话超时检查 - if current_time - session['last_active'] > self.gc_interval: + if current_time - session['last_active'] > session['gc_interval']: # 处理剩余文本 with self.lock: if session['text_buffer']: - # 重置完成事件状态 - session['sentence_complete_event'].clear() session['current_processing'] = True # 处理最后一段文本 gen_tts_audio_func(session_id, session['text_buffer']) session['text_buffer'] = "" - - # 等待完成 - if not session['sentence_complete_event'].wait(timeout=120.0): - logging.warning(f"Timeout waiting for TTS completion: {combined_text}") # 重置处理状态 session['current_processing'] = False # 关闭会话 + logging.info(f"--_process_tasks-- timeout {current_time} {session['last_active']} {session['gc_interval']}") self.close_session(session_id) break - + """ # 5. 休眠避免CPU空转 time.sleep(0.05) # 50ms检查间隔 - def _generate_audio1(self, session_id, text): - """实际生成音频(线程池执行)""" - session = self.sessions.get(session_id) - if not session: return - logging.info(f"_generate_audio:{text}") - first_chunk = True - # logging.info(f"转换开始!!! {text}") - try: - """ - for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']): - if session['stream_format'] == 'wav': - if first_chunk: - chunk_len = len(chunk) - if chunk_len > 2048: - session['buffer'].put(audio_fade_in(chunk, 1024)) - else: - session['buffer'].put(audio_fade_in(chunk, chunk_len)) - first_chunk = False - else: - session['buffer'].put(chunk) - else: - session['buffer'].put(chunk) - """ - session['tts_model'].text_tts_call(text) - session['last_active'] = time.time() - session['audio_chunk_count'] = session['audio_chunk_count'] + 1 - if session['tts_chunk_data_valid'] is False: - session['tts_chunk_data_valid'] = True # 20250510 增加,表示连接TTS后台已经返回,可以通知前端了 - # logging.info(f"转换结束!!! {session['audio_chunk_count']}") - except Exception as e: - session['buffer'].put(f"ERROR:{str(e)}") def _generate_audio(self, session_id, text): # 20250718 新更新 """实际生成音频(顺序执行)- 用于非流式引擎""" session = self.sessions.get(session_id) - if not session: + if not session or session.get('should_stop', False): return try: - logging.info(f"StreamSessionManager _generate_audio--0 {text}") + #logging.info(f"StreamSessionManager _generate_audio--0 {text}") + + def on_data_whole(data: bytes): + if data: + try: + session['last_active'] = time.time() + session['buffer'].put({'type':'arraybuffer','data':data}) + #session['buffer'].put(data) + session['audio_chunk_size'] += len(data) + #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}") + except queue.Full: + logging.warning(f"Audio buffer full for session {session_id}") + + # 重置完成事件 + session['sentence_complete_event'].clear() + session['tts_model'].setup_tts(on_data = on_data_whole,completion_event=session['sentence_complete_event']) # 调用 TTS session['tts_model'].text_tts_call(text) session['last_active'] = time.time() session['audio_chunk_count'] += 1 + # 等待句子完成,但会检查停止标志 + start_time = time.time() + timeout_or_stopped = True + while not session['sentence_complete_event'].wait(timeout=0.5): # 每0.5秒检查一次 + # 检查是否超时或收到停止信号 + if time.time() - start_time > 30 : + timeout_or_stopped = True + break + if session.get('should_stop',False): + timeout_or_stopped = False + break + logging.info(f"StreamSessionManager _generate_audio 转换结束!!!" + f"{session['audio_chunk_size']} {session_id} " + f"收到停止信号: {not timeout_or_stopped}") + session['buffer'].put({'type': 'sentence_end', 'data': ""}) if not session['tts_chunk_data_valid']: session['tts_chunk_data_valid'] = True @@ -412,6 +382,8 @@ class StreamSessionManager: # logging.info(f"Streaming text to TTS: {text}") try: + # 重置完成事件状态 + session['sentence_complete_event'].clear() # 使用流式调用发送文本 session['tts_model'].streaming_call(text) session['last_active'] = time.time() @@ -430,20 +402,60 @@ class StreamSessionManager: buffer = session['buffer'] # 这里是 queue.Queue last_data_time = time.time() # 记录最后一次获取数据的时间 - + get_tts_audio_size = 0 + get_tts_audio_return = 0 while session['active']: try: + # 检查会话是否被标记停止 + if session.get('should_stop', False): + logging.info(f"会话被标记停止: {session_id}") + break + # 使用 run_in_executor + wait_for 设置 10 秒超时 data = await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, buffer.get), + asyncio.get_event_loop().run_in_executor(self.executor, buffer.get), timeout=10.0 # 10 秒超时 ) + get_tts_audio_return += 1 + # 检查停止信号 + if isinstance(data, dict) and data.get('type') == 'stop_signal': + logging.info(f"StreamSessionManager get_tts_buffer_data收到停止信号: {session_id}") + break + if isinstance(data, dict) and data.get('data'): + get_tts_audio_size += len(data['data']) + #logging.info(f"get_tts_buffer_data {data}") last_data_time = time.time() # 更新最后数据时间 yield data except asyncio.TimeoutError: # 10 秒内没有新数据,检查是否超时 if time.time() - last_data_time >= 10.0: + logging.info(f"get_tts_buffer_data {session_id} Timeout after 10 seconds " + f"data_size={get_tts_audio_size} qsize={buffer.qsize()} {get_tts_audio_return}") + if buffer.qsize() >= 10: + # 获取线程池状态 - 使用自定义的 executor + active_threads = len(threading.enumerate()) # 当前系统线程数 + + # 获取线程池特定信息 + pool_threads = [ + t for t in threading.enumerate() + if t.name.startswith("ThreadPoolExecutor") + ] + + # 安全地获取等待任务数 + pending_tasks = 0 + if hasattr(self.executor, '_work_queue'): + pending_tasks = self.executor._work_queue.qsize() + + logging.warning( + f"[{threading.current_thread().name}] Timeout: " + f"System threads={active_threads}, " + f"ThreadPool threads={len(pool_threads)}, " + f"Pending tasks={pending_tasks}" + ) + # 202507 调试发现,偶尔存在队列中已经生成了TTS音频数据,但是上述从队列中获取数据时,不能成功 + # 所以做如下的复位关键资源的操作 dale_yxc + self.reset_manager() # 出现取不到队列数据的异常,重置关键资源 break else: continue # 未超时,继续等待 @@ -456,33 +468,68 @@ class StreamSessionManager: logging.error(f"Error in get_tts_buffer_data: {e}") break + def close_session(self, session_id): with self.lock: if session_id in self.sessions: - # 结束流式传输 - try: - # if self.sessions[session_id].get('streaming_call'): - # self.sessions[session_id]['tts_model'].end_streaming_call() - logging.info(f"Ended streaming for session {session_id}") - except Exception as e: - logging.error(f"Error ending streaming call: {str(e)}") + session = self.sessions[session_id] + session['active'] = False - # 标记会话为不活跃 - self.sessions[session_id]['active'] = False - # 设置完成事件(确保任何等待的线程被唤醒) - self.sessions[session_id]['sentence_complete_event'].set() - # 延迟2秒后清理资源 + # 设置完成事件 + session['sentence_complete_event'].set() + + # 清理TTS资源 + try: + if session.get('streaming_call'): + session['tts_model'].end_streaming_call() + except: + pass + + # 延迟清理会话 threading.Timer(1, self._clean_session, args=[session_id]).start() def _clean_session(self, session_id): with self.lock: if session_id in self.sessions: - # 确保流完全关闭 + del self.sessions[session_id] + + def _cleanup_expired(self): + """优化后的清理方法""" + while True: + time.sleep(30) + now = time.time() + to_remove = [] + + # 快速收集需要清理的会话ID(减少锁持有时间) + with self.lock: + for sid, session in list(self.sessions.items()): + if now - session['last_active'] > self.gc_interval: + to_remove.append(sid) + + # 在锁外执行实际清理 + for sid in to_remove: + self._clean_session(sid) + + logging.info(f"清理资源: {len(to_remove)}会话") + + def stop_session(self, session_id: str): + """停止指定会话的音频生成""" + if session_id in self.sessions: + session = self.sessions[session_id] + + # 设置停止标志 + session['should_stop'] = True + + # 如果使用队列,放入停止标记 + if 'buffer' in session: try: - self.sessions[session_id]['tts_model'].end_streaming_call() + # 放入特殊值通知生成循环退出 + session['buffer'].put({"type": "stop_signal"}) except: pass - del self.sessions[session_id] + # 设置完成事件,确保任务处理线程能退出 + session['sentence_complete_event'].set() + logging.info(f"StreamSessionManager stop_session: {session_id}") def get_session(self, session_id): return self.sessions.get(session_id) @@ -506,6 +553,7 @@ class StreamSessionManager: """ 增强型句子分割器 返回: (完整句子列表, 剩余文本) + 限定返回长度 """ # 特殊处理:如果文本以逗号开头,先处理前面的部分 if text.startswith((",", ",")): @@ -531,7 +579,7 @@ class StreamSessionManager: continue # 检查是否为有效句子(最小长度或包含结束符) - if len(sentence) > 6 or any(char in "。.?!?!" for char in sentence): + if (len(sentence) > 6 or any(char in "。.?!?!" for char in sentence)) and (last_end<24): complete_sentences.append(sentence) last_end = end_pos else: @@ -542,34 +590,98 @@ class StreamSessionManager: # 3. 提取剩余文本 remaining_text = text[last_end:].strip() - return complete_sentences, remaining_text + def reset_manager(self): + """完全重置管理器 - 极简版本""" + logging.critical("Resetting StreamSessionManager...") + + # 步骤1: 关闭所有会话 + + for session_id in list(self.sessions.keys()): + try: + # 直接清理会话而不调用close_session + if session_id in self.sessions: + # 尝试释放TTS资源 + try: + if self.sessions[session_id].get('tts_model'): + self.sessions[session_id]['tts_model'].cleanup() + except Exception: + pass + del self.sessions[session_id] + except Exception: + pass + + # 步骤2: 重置关键资源 + # 重建线程池 + try: + self.executor.shutdown(wait=False) + except Exception: + pass + self.executor = ThreadPoolExecutor(max_workers=30) + + # 清空会话字典 + self.sessions = {} + + # 重建锁对象 + self.lock = threading.Lock() + + logging.critical("Reset completed") + + def get_self_thread_pool_status(self): + # 获取线程池状态 - 使用自定义的 executor + active_threads = len(threading.enumerate()) # 当前系统线程数 + + # 获取线程池特定信息 + pool_threads = [ + t for t in threading.enumerate() + if t.name.startswith("ThreadPoolExecutor") + ] + + # 安全地获取等待任务数 + pending_tasks = 0 + if hasattr(self.executor, '_work_queue'): + pending_tasks = self.executor._work_queue.qsize() + return ( + f"[{threading.current_thread().name}] Timeout: " + f"System threads={active_threads}, " + f"ThreadPool threads={len(pool_threads)}, " + f"Pending tasks={pending_tasks}" + ) + stream_manager = StreamSessionManager() + def allowed_file(filename): return '.' in filename and \ filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'} -audio_text_cache = {} -cache_lock = Lock() -CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 - # WebSocket 连接管理 class ConnectionManager: def __init__(self): self.active_connections = {} + self.aichat_audio_sessions = {} # 用于存储aichat音频会话,关联StreamSessionManager的session_id async def connect(self, websocket: WebSocket, connection_id: str): await websocket.accept() self.active_connections[connection_id] = websocket logging.info(f"新连接建立: {connection_id}") + # 注册音频会话,关联StreamSessionManager的session_id + def register_audio_session(self, connection_id: str, audio_session_id: str): + self.aichat_audio_sessions[connection_id] = audio_session_id + async def disconnect(self, connection_id: str, code=1000, reason: str = ""): + # 通知StreamSessionManager管理器停止生音频 + if connection_id in self.aichat_audio_sessions: + audio_session_id = self.aichat_audio_sessions[connection_id] + stream_manager.stop_session(audio_session_id) + del self.aichat_audio_sessions[connection_id] + if connection_id in self.active_connections: try: # 尝试正常关闭连接(非阻塞) @@ -587,7 +699,7 @@ class ConnectionManager: """安全发送的通用方法(核心修改)""" # 1. 检查连接是否存在 if connection_id not in self.active_connections: - logging.warning(f"尝试向不存在的连接发送数据: {connection_id}") + #logging.warning(f"尝试向不存在的连接发送数据: {connection_id}") return False websocket = self.active_connections[connection_id] @@ -853,11 +965,6 @@ def parse_markdown_json(json_string): return {'success': False, 'data': 'not a valid markdown json string'} -audio_text_cache = {} -cache_lock = Lock() -CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 - - # 全角字符到半角字符的映射 def fullwidth_to_halfwidth(s): full_to_half_map = { @@ -936,32 +1043,6 @@ def encode_gzip_base64(original_data: bytes) -> str: return base64.b64encode(compressed_bytes).decode('utf-8') # 默认不带换行符(等同于Android的Base64.NO_WRAP) -def clean_audio_cache(): - """定时清理过期缓存""" - with cache_lock: - now = time.time() - expired_keys = [ - k for k, v in audio_text_cache.items() - if now - v['created_at'] > CACHE_EXPIRE_SECONDS - ] - for k in expired_keys: - entry = audio_text_cache.pop(k, None) - if entry and entry.get('audio_stream'): - entry['audio_stream'].close() - - -def start_background_cleaner(): - """启动后台清理线程""" - - def cleaner_loop(): - while True: - time.sleep(180) # 每3分钟清理一次 - clean_audio_cache() - - cleaner_thread = Thread(target=cleaner_loop, daemon=True) - cleaner_thread.start() - - def test_qwen_chat(): messages = [ {'role': 'system', 'content': 'You are a helpful assistant.'}, @@ -1004,13 +1085,13 @@ class QwenTTS: special_characters: Optional[Dict[str, str]] = None): import dashscope import ssl - logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx self.model_name = model_name.split('@')[0] dashscope.api_key = key ssl._create_default_https_context = ssl._create_unverified_context # 禁用验证 self.synthesizer = None self.callback = None self.is_cosyvoice = False + self.cosyvoice = "" self.voice = "" self.format = format self.sample_rate = sample_rate @@ -1020,7 +1101,9 @@ class QwenTTS: # 返回分离后的两个字符串parts[0], parts[1] if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2': self.is_cosyvoice = True + self.cosyvoice = parts[0] self.voice = parts[1] + logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {self.cosyvoice} {self.voice}") # cyx self.completion_event = None # 新增:用于通知任务完成 # 特殊字符及其拼音映射 self.special_characters = special_characters or { @@ -1049,7 +1132,6 @@ class QwenTTS: pass def on_complete(self): - logging.info(f"---QwenTTS Callback on_complete--") self.dque.append(None) if self.data_callback: self.data_callback(None) # 发送结束信号 @@ -1094,7 +1176,7 @@ class QwenTTS: break def on_open(self): - logging.info("Qwen CosyVoice tts open ") + #logging.info("Qwen CosyVoice tts open ") pass def on_complete(self): @@ -1113,7 +1195,7 @@ class QwenTTS: def on_close(self): # print("---Qwen call back close") # cyx - logging.info("Qwen CosyVoice tts close") + #logging.info("Qwen CosyVoice tts close") pass """ canceled for test 语音大模型CosyVoice @@ -1192,12 +1274,14 @@ class QwenTTS: self.callback = self.Callback( data_callback=on_data, completion_event=completion_event) - format_val = self.get_audio_format(self.format, self.sample_rate) - logging.info(f"Qwen setup_tts {self.voice} {format_val}") + + if self.is_cosyvoice: + format_val = self.get_audio_format(self.format, self.sample_rate) + # logging.info(f"Qwen setup_tts {self.voice} {format_val}") self.synthesizer = CosySpeechSynthesizer( - model='cosyvoice-v1', - voice=self.voice, # voice="longyuan", #"longfei", + model=self.cosyvoice, #'cosyvoice-v1', + voice=self.voice, # voice="longyuan", #"longfei", "longyuan_2" callback=self.callback, format=format_val ) @@ -1227,11 +1311,21 @@ class QwenTTS: return f"{text}" def text_tts_call(self, text): - if self.special_characters and self.is_cosyvoice is False: + if self.special_characters : text = self.apply_phoneme_tags(text) #logging.info(f"Applied SSML phoneme tags to text: {text}") - + volume = 50 + if self.sample_rate < 10000: + volume = 70 if self.synthesizer and self.is_cosyvoice: + #logging.info(f"Qwen text_tts_call {text} {self.cosyvoice} {self.voice}") + format_val = self.get_audio_format(self.format, self.sample_rate) + self.synthesizer = CosySpeechSynthesizer( + model = self.cosyvoice, #'cosyvoice-v1', + voice = self.voice, # voice="longyuan", #"longfei", + callback=self.callback, + format=format_val + ) self.synthesizer.call(text) if self.is_cosyvoice is False: logging.info(f"Qwen text_tts_call {text}") @@ -1249,6 +1343,9 @@ class QwenTTS: # logging.info(f"---dale end_streaming_call") self.synthesizer.streaming_complete() + def cleanup(self): + pass + def get_audio_format(self, format: str, sample_rate: int): """动态获取音频格式""" from dashscope.audio.tts_v2 import AudioFormat @@ -1721,31 +1818,71 @@ def replace_domain(url: str) -> str: return f"http://localhost:9380/{url}" -async def proxy_aichat_audio_stream(client_id: str, audio_url: str): - """代理外部音频流请求""" + +async def proxy_aichat_audio_stream( + client_id: str, + audio_stream_id: str, + combined_state: dict = None # 新增组合状态参数 + ): try: - # 替换域名为本地地址 - local_url = audio_url - logging.info(f"代理音频流: {audio_url} -> {local_url}") + stream_session = stream_manager.get_session(audio_stream_id) + if not stream_session: + logging.warning(f"Audio session not found: {audio_stream_id}") + return + + # 注册当前连接的音频会话 + manager.register_audio_session(client_id, audio_stream_id) + + sample_rate = stream_session.get('sample_rate') + audio_data_size = 0 + + # 发送采样率 + await manager.send_json(client_id, {"command": "sample_rate", "params": sample_rate}) + + # 处理音频流 + async for data in stream_manager.get_tts_buffer_data(audio_stream_id): + # 检查websocket 连接是否仍然活跃 + if not manager.is_connected(client_id): + logging.info(f"audio websocket 连接已断开,停止音频生成: {client_id}") + # 通知会话管理器停止生成 + stream_manager.stop_session(audio_stream_id) + break + if isinstance(data, dict): + if data.get('type') == 'sentence_end': + await manager.send_json(client_id, {"command": "sentence_end"}) + + elif data.get('type') == 'arraybuffer': + audio_data = data.get('data') + audio_data_size += len(audio_data) + + if not await manager.send_bytes(client_id, audio_data): + break + + logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}") + + # 组合模式通知音频流结束 + if combined_state: + combined_state["audio_completed"].set() - async with httpx.AsyncClient(timeout=60.0) as client: - async with client.stream("GET", local_url) as response: - # 流式转发音频数据 - async for chunk in response.aiter_bytes(): - if not await manager.send_bytes(client_id, chunk): - logging.warning(f"Audio proxy interrupted for {client_id}") - return except Exception as e: - logging.error(f"Audio proxy failed: {str(e)}") - await manager.send_text(client_id, json.dumps({ - "type": "error", - "message": f"音频流获取失败: {str(e)}" - })) - + logging.error(f"音频流处理失败: {str(e)}") + # 仅在连接活跃时发送错误 + if manager.is_connected(client_id): + await manager.send_text(client_id, json.dumps({ + "type": "error", + "message": f"音频流错误: {str(e)}" + })) + finally: + # 确保取消注册 + if client_id in manager.aichat_audio_sessions: + del manager.aichat_audio_sessions[client_id] # 代理函数 - 文本流 # 在微信小程序中,原来APK使用的SSE机制不能正常工作,需要使用WebSocket -async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict): +async def proxy_aichat_text_stream(client_id: str, + completions_url: str, + payload: dict, + combined_state: dict = None): """代理大模型文本流请求 - 兼容现有Flask实现""" try: logging.info(f"代理文本流: completions_url={completions_url} {payload}") @@ -1755,32 +1892,15 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload "Content-Type": "application/json", 'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm' } + tts_model = None tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen') #if 'longyuan' in tts_model_name: # tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen" - # 创建TTS实例 - tts_model = QwenTTS( - key=ALI_KEY, - format=payload.get('tts_stream_format', 'mp3'), - sample_rate=payload.get('tts_sample_rate', 48000), - model_name=tts_model_name - ) - streaming_call = False - if tts_model.is_cosyvoice: - streaming_call = True - - # 创建流会话 - tts_stream_session_id = stream_manager.create_session( - tts_model=tts_model, - sample_rate=payload.get('tts_sample_rate', 48000), - stream_format=payload.get('tts_stream_format', 'mp3'), - session_id=None, - streaming_call= streaming_call - ) - # logging.info(f"---tts_stream_session_id = {tts_stream_session_id}") + #logging.info(f"---create tts_stream_session_id = {tts_stream_session_id}") tts_stream_session_id_sent = False + # 使用更长的超时时间 (5分钟) - timeout = httpx.Timeout(300.0, connect=60.0) + timeout = httpx.Timeout(30.0, connect=20.0,read=20, write=20) async with httpx.AsyncClient(timeout=timeout) as client: # 关键修改:使用流式请求模式 async with client.stream( # <-- 使用stream方法 @@ -1789,8 +1909,8 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload json=payload, headers=headers ) as response: - logging.info(f"响应状态: HTTP {response.status_code}") + logging.info(f"响应状态: HTTP {response.status_code}") if response.status_code != 200: # 读取错误信息(非流式) error_content = await response.aread() @@ -1810,7 +1930,32 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload })) return - logging.info("开始处理SSE流") + if tts_model is None: + # 创建TTS实例 + tts_model = QwenTTS( + key=ALI_KEY, + format=payload.get('tts_stream_format', 'mp3'), + sample_rate=payload.get('tts_sample_rate', 48000), + model_name=tts_model_name + ) + streaming_call = False + if tts_model.is_cosyvoice: + streaming_call = True + + # 创建流会话 + tts_stream_session_id = stream_manager.create_session( + tts_model=tts_model, + sample_rate=payload.get('tts_sample_rate', 48000), + stream_format=payload.get('tts_stream_format', 'mp3'), + session_id=None, + streaming_call=streaming_call + ) + # 关键修改:设置TTS会话ID并触发就绪事件 + if combined_state is not None and tts_stream_session_id: + combined_state["tts_session_id"] = tts_stream_session_id + combined_state["tts_ready_event"].set() # 触发事件通知主流程 + + logging.info(f"开始处理SSE流 {tts_stream_session_id}") event_count = 0 # 使用异步迭代器逐行处理 async for line in response.aiter_lines(): @@ -1829,8 +1974,10 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict): delta_text = data_obj.get('data', None).get('delta_ans', "") if tts_stream_session_id_sent is False: - data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}" - data_str = json.dumps(data_obj) + if tts_stream_session_id and combined_state is None: + logging.info(f"--proxy_aichat_text_stream 发送audio_stream_url") + data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}" + data_str = json.dumps(data_obj) tts_stream_session_id_sent = True # 直接转发原始数据 await manager.send_text(client_id, json.dumps({ @@ -1838,7 +1985,8 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload "data": data_str })) # 这里构建{"type":"text",'data':"data_str"}) 是为了前端websocket进行数据解析 - if delta_text: + if delta_text and tts_stream_session_id: + # 追加到会话管理器 stream_manager.append_text(tts_stream_session_id, delta_text) # logging.info(f"文本代理转发: {data_str}") @@ -1849,13 +1997,16 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload # 保持连接活性 await asyncio.sleep(0.001) # 避免CPU空转 - logging.info(f"SSE流处理完成,事件数: {event_count}") + #logging.info(f"SSE流处理完成,事件数: {event_count}") - # 发送结束信号 - await manager.send_text(client_id, json.dumps({"type": "end"})) + # 发送文本流结束信号 + await manager.send_text(client_id, json.dumps({"type": "AiChatTextEnd"})) + # 标记文本输入结束 + if tts_stream_session_id and stream_manager.finish_text_input: + stream_manager.finish_text_input(tts_stream_session_id) - except httpx.ReadTimeout: - logging.error("读取后端服务超时") + except httpx.ConnectTimeout: + logging.error("后端服务超时") await manager.send_text(client_id, json.dumps({ "type": "error", "message": "后端服务响应超时" @@ -1872,7 +2023,11 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload "type": "error", "message": f"文本流获取失败: {str(e)}" })) - + finally: + pass + # 记录连接池状态 + #pool_status = await HTTPXConnectionPool.get_pool_status() + #logging.debug(f"连接池状态: {json.dumps(pool_status)}") @tts_router.get("/audio/pcm_mp3") async def stream_mp3(): @@ -1917,6 +2072,7 @@ def generate_silence_header(duration_ms: int = 500) -> bytes: # ------------------------ API路由 ------------------------ + @tts_router.post("/chats/{chat_id}/tts") async def create_tts_request(chat_id: str, request: Request): try: @@ -1964,9 +2120,6 @@ async def create_tts_request(chat_id: str, request: Request): raise HTTPException(500, detail="Internal server error") -executor = ThreadPoolExecutor() - - @tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}") async def get_tts_audio( chat_id: str, @@ -2117,6 +2270,10 @@ async def websocket_tts_endpoint( # 接收 header 参数 headers = websocket.headers service_type = headers.get("x-tts-type") # 注意:header 名称转为小写 + #给H5代码特殊处理,H5代码中,x-tts-type的header不能工作 + # 浏览器内置WebSocket 的连接时不能附加额外的header传递参数 + if audio_stream_id == 'x-tts-type-is-TextToTts': + service_type = 'TextToTts' # audio_url = headers.get("x-audio-url") """ 前端示例 @@ -2146,31 +2303,131 @@ async def websocket_tts_endpoint( # 根据tts_type路由到不同的音频源 if service_type == "AiChatTts": # 音频代理服务 - audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}" - # await proxy_aichat_audio_stream(connection_id, audio_url) - sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate') - audio_data_size =0 - await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate}) - async for data in stream_manager.get_tts_buffer_data(audio_stream_id): - audio_data_size += len(data) - if not await manager.send_bytes(connection_id, data): - break + await proxy_aichat_audio_stream(connection_id, audio_stream_id, combined_state = None) completed_successfully = True - logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}") + elif service_type == "AiChatText": # 文本代理服务 # 等待客户端发送初始请求数据 进行大模型对话代理时,需要前端连接后发送payload payload = await websocket.receive_json() - completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions" - await proxy_aichat_text_stream(connection_id, completions_url, payload) + # 在代理前检查连接池状态 + completions_url = f"http://127.0.0.1:9380/api/v1/chats/{chat_id}/completions" + await proxy_aichat_text_stream(connection_id, completions_url, payload, combined_state = None) completed_successfully = True + elif service_type == "AiChatCombined": + # 接收初始请求数据 + payload = await websocket.receive_json() + + # 创建共享状态和同步事件 + combined_state = { + "tts_session_id": None, + "tts_ready_event": asyncio.Event(), # TTS准备就绪事件 + "audio_task": None, + "text_completed": asyncio.Event(), + "audio_completed": asyncio.Event() + } + + # 启动文本流任务 + text_task = asyncio.create_task( + proxy_aichat_text_stream( + client_id=connection_id, + completions_url=f"http://127.0.0.1:9380/api/v1/chats/{chat_id}/completions", + payload=payload, + combined_state=combined_state + ) + ) + + try: + # 等待TTS会话ID准备就绪(最多等待8秒) + await asyncio.wait_for(combined_state["tts_ready_event"].wait(), timeout=8.0) + + if combined_state["tts_session_id"]: + # 启动音频流任务 + combined_state["audio_task"] = asyncio.create_task( + proxy_aichat_audio_stream( + client_id=connection_id, + audio_stream_id=combined_state["tts_session_id"], + combined_state=combined_state + ) + ) + else: + logging.warning("TTS会话ID未生成,跳过音频流任务") + except asyncio.TimeoutError: + logging.warning("等待TTS会话ID超时,跳过音频流任务") + + # 等待两个任务完成(如果音频任务未启动,text_task会正常完成) + tasks = [text_task] + if combined_state["audio_task"]: + tasks.append(combined_state["audio_task"]) + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION) + + # 检查任务状态并处理异常 + for task in done: + if task.exception(): + logging.error(f"任务异常: {task.exception()}") + # 发送错误消息给客户端 + await manager.send_text(connection_id, json.dumps({ + "type": "error", + "message": str(task.exception()) + })) + + # 取消任何未完成的任务 + for task in pending: + task.cancel() + + # 发送完成信号 + if manager.is_connected(connection_id): + await manager.send_text(connection_id, json.dumps({"type": "end"})) + logging.info(" websocket_tts_endpoint AiChatCombined completed successfully") + elif service_type == "TextToTts": + # 前端将文本发送到后端,后端调用TTS引擎生成音频流 ,并且将生成音频的文本、生成音频的参数 + # 返回音频,在1个websocket调用中完成 + params_valid = True + payload = await websocket.receive_json() + # 参数校验 + text = payload.get("text", "").strip() + if not text: + params_valid = False + data = payload.get("params", {}) + logging.info(f"websocket_tts_endpoint TextToTts:{text} {data}") + format = data.get("tts_stream_format", "mp3") + if format not in ["mp3", "wav", "pcm"]: + params_valid = False + + sample_rate = data.get("tts_sample_rate", 48000) + if sample_rate not in [8000, 16000, 22050, 44100, 48000]: + params_valid = False + + model_name = data.get("model_name", "cosyvoice-v1/longxiaochun") + delay_gen_audio = data.get('delay_gen_audio', False) + + if params_valid: + # 创建TTS任务 + audio_stream_id = tts_engine.create_tts_task( + text=text, + format=format, + sample_rate=sample_rate, + model_name=model_name, + key=ALI_KEY, + delay_gen_audio=delay_gen_audio + ) + # 使用引擎的生成器直接获取音频流 + audio_data_size = 0 + async for data in tts_engine.get_audio_stream(audio_stream_id): + audio_data_size += len(data) + if not await manager.send_bytes(connection_id, data): + logging.warning(f"Send failed, connection closed: {connection_id}") + break + await manager.send_json(connection_id, {"command": "sentence_end"}) + logging.info(f"websocket_tts_endpoint TextToTts completed successfully {audio_data_size} bytes") else: # 使用引擎的生成器直接获取音频流 async for data in tts_engine.get_audio_stream(audio_stream_id): if not await manager.send_bytes(connection_id, data): logging.warning(f"Send failed, connection closed: {connection_id}") break - + await manager.send_json(connection_id, {"command": "sentence_end"}) completed_successfully = True # 发送完成信号前检查连接状态 @@ -2184,9 +2441,9 @@ async def websocket_tts_endpoint( # 主动关闭WebSocket连接 await manager.disconnect(connection_id, code=1000, reason="Audio stream completed") except WebSocketDisconnect: - logging.info(f"WebSocket disconnected: {connection_id}") + logging.info(f"websocket_tts_endpoint WebSocket disconnected: {connection_id}") except Exception as e: - logging.error(f"WebSocket TTS error: {str(e)}") + logging.error(f"websocket_tts_endpoint WebSocket TTS error: {str(e)}") if manager.is_connected(connection_id): await manager.send_json(connection_id, {"error": str(e)}) finally: @@ -2194,15 +2451,17 @@ async def websocket_tts_endpoint( # await manager.disconnect(connection_id) -def cleanup_cache(): - """清理过期缓存""" - with cache_lock: - now = datetime.datetime.now() - expired = [k for k, v in audio_text_cache.items() - if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS] - for key in expired: - logging.info(f"del audio_text_cache= {audio_text_cache[key]}") - del audio_text_cache[key] +@tts_router.get("/debug/get_threadpool") +async def get_threadpool(request: Request): + params = dict(request.query_params) + if params.get('reset'): + stream_manager.reset_manager() + + return JSONResponse( + status_code=200, + content={ + "status":stream_manager.get_self_thread_pool_status() + } + ) + -# 应用启动时启动清理线程 -# start_background_cleaner() diff --git a/asr-monitor-test/bk/tts_service.py b/asr-monitor-test/bk/tts_service.py new file mode 100644 index 00000000..e2531863 --- /dev/null +++ b/asr-monitor-test/bk/tts_service.py @@ -0,0 +1,2328 @@ +import logging +import binascii +from copy import deepcopy +from timeit import default_timer as timer +import datetime +from datetime import timedelta +import threading, time, queue, uuid, time, array +from threading import Lock, Thread +from concurrent.futures import ThreadPoolExecutor +import base64, gzip +import os, io, re, json +from io import BytesIO +from typing import Optional, Dict, Any +import asyncio, httpx +from collections import deque + +import websockets +import uuid + +from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query +from fastapi import FastAPI, UploadFile, File, Form, Header +from fastapi.responses import StreamingResponse, JSONResponse, Response + +TTS_SAMPLERATE = 44100 # 22050 # 16000 +FORMAT = "mp3" +CHANNELS = 1 # 单声道 +SAMPLE_WIDTH = 2 # 16-bit = 2字节 + +tts_router = APIRouter() + + +# logger = logging.getLogger(__name__) +class MillisecondsFormatter(logging.Formatter): + """自定义日志格式器,添加毫秒时间戳""" + + def formatTime(self, record, datefmt=None): + # 将时间戳转换为本地时间元组 + ct = self.converter(record.created) + # 格式化为 "小时:分钟:秒" + t = time.strftime("%H:%M:%S", ct) + # 添加毫秒(3位) + return f"{t}.{int(record.msecs):03d}" + + +# 配置全局日志格式 +def configure_logging(): + # 创建 Formatter + log_format = "%(asctime)s - %(levelname)s - %(message)s" + formatter = MillisecondsFormatter(log_format) + + # 获取根 Logger 并清除已有配置 + root_logger = logging.getLogger() + root_logger.handlers = [] + + # 创建并配置 Handler(输出到控制台) + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + # 设置日志级别并添加 Handler + root_logger.setLevel(logging.INFO) + root_logger.addHandler(console_handler) + + +# 调用配置函数(程序启动时运行一次) +configure_logging() + + +class StreamSessionManager: + def __init__(self): + self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}} + self.lock = threading.Lock() + self.executor = ThreadPoolExecutor(max_workers=30) # 固定大小线程池 + self.gc_interval = 300 # 5分钟清理一次 + self.streaming_call_timeout = 10 # 10s + self.gc_tts = 3 # 3s + self.sentence_timeout = 2 # 2000ms句子超时 + self.sentence_endings = set('。?!;.?!;') # 中英文结束符 + # 增强版正则表达式:匹配中英文句子结束符(包含全角) + self.sentence_pattern = re.compile( + r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])' + ) + self.sentence_audio_store = {} # {sentence_id: {'data': bytes, 'text': str, 'created_at': float}} + self.sentence_lock = threading.Lock() + + threading.Thread(target=self._cleanup_expired, daemon=True).start() + + def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', session_id=None, streaming_call=False): + if not session_id: + session_id = str(uuid.uuid4()) + with self.lock: + # 创建TTS实例并设置流式回调 + tts_instance = tts_model + + # 定义音频数据回调函数 + def on_data(data: bytes): + if data: + try: + self.sessions[session_id]['last_active'] = time.time() + self.sessions[session_id]['buffer'].put(data) + self.sessions[session_id]['audio_chunk_size'] += len(data) + #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}") + except queue.Full: + logging.warning(f"Audio buffer full for session {session_id}") + """ + elif data is None: # 结束信号 + # 仅对非流式引擎触发完成事件 + if not streaming_call: + logging.info(f"StreamSessionManager on_data sentence_complete_event set") + self.sessions[session_id]['sentence_complete_event'].set() + self.sessions[session_id]['current_processing'] = False + """ + # 创建完成事件 + completion_event = threading.Event() + # 设置TTS流式传输 + tts_instance.setup_tts(on_data,completion_event) + # 创建会话 + self.sessions[session_id] = { + 'tts_model': tts_model, + 'buffer': queue.Queue(maxsize=300), # 线程安全队列 + 'task_queue': queue.Queue(), + 'active': True, + 'last_active': time.time(), + 'audio_chunk_count': 0, + 'audio_chunk_size': 0, + 'finished': threading.Event(), # 添加事件对象 + 'sample_rate': sample_rate, + 'stream_format': stream_format, + "tts_chunk_data_valid": False, + "text_buffer": "", # 新增文本缓冲区 + "last_text_time": time.time(), # 最后文本到达时间 + "streaming_call": streaming_call, + "tts_stream_started": False, # 标记是否已启动流 + "current_processing": False, # 标记是否正在处理句子 + "sentence_complete_event": completion_event, #threading.Event(), + 'sentences': [], # 存储句子ID列表 + 'current_sentence_index': 0, + 'gc_interval':300, # 5分钟清理一次 + } + # 启动任务处理线程 + threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start() + return session_id + + def append_text(self, session_id, text): + with self.lock: + session = self.sessions.get(session_id) + if not session: return + # 更新文本缓冲区和时间戳 + session['text_buffer'] += text + session['last_text_time'] = time.time() + # 将文本放入任务队列(非阻塞) + try: + session['task_queue'].put(text, block=False) + except queue.Full: + logging.warning(f"Session {session_id} task queue full") + + def finish_text_input(self, session_id): + """标记文本输入结束,通知任务处理线程""" + with self.lock: + session = self.sessions.get(session_id) + if not session: + return + session['gc_interval'] = 100 # 所有的文本输入已经结束,可以将超时检查时间缩短 + + + def _process_tasks(self, session_id): # 20250718 新更新 + """任务处理线程(每个会话独立)- 保留原有处理逻辑""" + session = self.sessions.get(session_id) + if not session or not session['active']: + return + + # 根据引擎类型选择处理函数 + if session.get('streaming_call'): + gen_tts_audio_func = self._generate_audio #self._stream_audio + else: + gen_tts_audio_func = self._generate_audio + + while session['active']: + current_time = time.time() + text_to_process = "" + + # 1. 获取待处理文本 + with self.lock: + if session['text_buffer']: + text_to_process = session['text_buffer'] + + + # 2. 处理文本 + if text_to_process and not session['current_processing'] : + session['text_buffer'] = "" + # 分割完整句子 + complete_sentences, remaining_text = self._split_and_extract(text_to_process) + + # 保存剩余文本 + if remaining_text: + with self.lock: + session['text_buffer'] = remaining_text + session['text_buffer'] + + # 合并并处理完整句子 + if complete_sentences: + # 智能合并句子(最长300字符) + buffer = [] + current_length = 0 + + # 处理每个句子 + for sentence in complete_sentences: + sent_length = len(sentence) + + # 添加到当前缓冲区 + if current_length + sent_length <= 300: + buffer.append(sentence) + current_length += sent_length + else: + # 处理已缓冲的文本 + if buffer: + combined_text = "".join(buffer) + + # 重置完成事件状态 + session['sentence_complete_event'].clear() + session['current_processing'] = True + + # 生成音频 + gen_tts_audio_func(session_id, combined_text) + + # 等待完成 + if not session['sentence_complete_event'].wait(timeout=120.0): + logging.warning(f"Timeout waiting for TTS completion: {combined_text}") + # 重置处理状态 + time.sleep(5.0) + session['current_processing'] = False + logging.info(f"StreamSessionManager _process_tasks 转换结束!!!") + # 重置缓冲区 + buffer = [sentence] + current_length = sent_length + + # 处理剩余的缓冲文本 + if buffer: + combined_text = "".join(buffer) + session['current_processing'] = True + + # 生成音频 + gen_tts_audio_func(session_id, combined_text) #如果调用_stream_audio,则是同步调用,会阻塞,直到音频生成完成 + + # 重置处理状态 + time.sleep(1.0) + session['current_processing'] = False + + + # 3. 检查超时未处理的文本 + if current_time - session['last_text_time'] > self.sentence_timeout: + with self.lock: + if session['text_buffer']: + # 直接处理剩余文本 + session['current_processing'] = True + gen_tts_audio_func(session_id, session['text_buffer']) + session['text_buffer'] = "" + # 重置处理状态 + session['current_processing'] = False + + # 4. 会话超时检查 + if current_time - session['last_active'] > session['gc_interval']: + # 处理剩余文本 + with self.lock: + if session['text_buffer']: + session['current_processing'] = True + + # 处理最后一段文本 + gen_tts_audio_func(session_id, session['text_buffer']) + session['text_buffer'] = "" + # 重置处理状态 + session['current_processing'] = False + + # 关闭会话 + logging.info(f"--_process_tasks-- timeout {session['last_active']} {session['gc_interval']}") + self.close_session(session_id) + break + + # 5. 休眠避免CPU空转 + time.sleep(0.05) # 50ms检查间隔 + + + def _generate_audio(self, session_id, text): # 20250718 新更新 + """实际生成音频(顺序执行)- 用于非流式引擎""" + session = self.sessions.get(session_id) + if not session: + return + + try: + #logging.info(f"StreamSessionManager _generate_audio--0 {text}") + # 创建内存流 + audio_stream = io.BytesIO() + + # 定义回调函数:直接写入流 + def on_data_sentence(data: bytes): + if data: + audio_stream.write(data) + + def on_data_whole(data: bytes): + if data: + try: + session['last_active'] = time.time() + session['buffer'].put({'type':'arraybuffer','data':data}) + #session['buffer'].put(data) + session['audio_chunk_size'] += len(data) + #logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}") + except queue.Full: + logging.warning(f"Audio buffer full for session {session_id}") + + # 重置完成事件 + session['sentence_complete_event'].clear() + session['tts_model'].setup_tts(on_data = on_data_whole,completion_event=session['sentence_complete_event']) + # 调用 TTS + session['tts_model'].text_tts_call(text) + session['last_active'] = time.time() + session['audio_chunk_count'] += 1 + # 等待句子完成 + if not session['sentence_complete_event'].wait(timeout=30): # 30秒超时 + logging.warning(f"Timeout generating audio for: {text[:20]}...") + logging.info(f"StreamSessionManager _generate_audio 转换结束!!!") + session['buffer'].put({'type': 'sentence_end', 'data': ""}) + + # 获取音频数据 + audio_data = audio_stream.getvalue() + audio_stream.close() + + # 保存到句子存储 + self.add_sentence_audio(session_id, text, audio_data) + + if not session['tts_chunk_data_valid']: + session['tts_chunk_data_valid'] = True + + except Exception as e: + session['buffer'].put(f"ERROR:{str(e)}".encode()) + session['sentence_complete_event'].set() # 确保事件被设置 + + def _stream_audio(self, session_id, text): + """流式传输文本到TTS服务""" + session = self.sessions.get(session_id) + if not session: + return + # logging.info(f"Streaming text to TTS: {text}") + + try: + # 重置完成事件状态 + session['sentence_complete_event'].clear() + # 使用流式调用发送文本 + session['tts_model'].streaming_call(text) + session['last_active'] = time.time() + # 流式引擎不需要等待完成事件 + session['sentence_complete_event'].set() + except Exception as e: + logging.error(f"Error in streaming_call: {str(e)}") + session['buffer'].put(f"ERROR:{str(e)}".encode()) + session['sentence_complete_event'].set() + + async def get_tts_buffer_data(self, session_id): + """异步流式返回 TTS 音频数据(适配同步 queue.Queue,带 10 秒超时)""" + session = self.sessions.get(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + buffer = session['buffer'] # 这里是 queue.Queue + last_data_time = time.time() # 记录最后一次获取数据的时间 + + while session['active']: + try: + # 使用 run_in_executor + wait_for 设置 10 秒超时 + data = await asyncio.wait_for( + asyncio.get_event_loop().run_in_executor(None, buffer.get), + timeout=10.0 # 10 秒超时 + ) + #logging.info(f"get_tts_buffer_data {data}") + last_data_time = time.time() # 更新最后数据时间 + yield data + + except asyncio.TimeoutError: + # 10 秒内没有新数据,检查是否超时 + if time.time() - last_data_time >= 10.0: + break + else: + continue # 未超时,继续等待 + + except asyncio.CancelledError: + logging.info(f"Session {session_id} stream cancelled") + break + + except Exception as e: + logging.error(f"Error in get_tts_buffer_data: {e}") + break + + def close_session(self, session_id): + with self.lock: + if session_id in self.sessions: + session = self.sessions[session_id] + session['active'] = False + + # 清理关联的音频数据 + with self.sentence_lock: + # 标记会话相关的句子为过期 + expired_sentences = [ + sid for sid, data in self.sentence_audio_store.items() + if data.get('session_id') == session_id + ] + + # 设置完成事件 + session['sentence_complete_event'].set() + + # 清理TTS资源 + try: + if session.get('streaming_call'): + session['tts_model'].end_streaming_call() + except: + pass + + # 延迟清理会话 + threading.Timer(1, self._clean_session, args=[session_id]).start() + + def _clean_session(self, session_id): + with self.lock: + if session_id in self.sessions: + del self.sessions[session_id] + + def _cleanup_expired(self): + """定时清理过期资源""" + while True: + time.sleep(30) + now = time.time() + + # 清理过期会话 + with self.lock: + expired_sessions = [ + sid for sid, session in self.sessions.items() + if now - session['last_active'] > self.gc_interval + ] + for sid in expired_sessions: + self._clean_session(sid) + + # 清理过期句子音频 + with self.sentence_lock: + expired_sentences = [ + sid for sid, data in self.sentence_audio_store.items() + if now - data['created_at'] > self.gc_interval + ] + for sid in expired_sentences: + del self.sentence_audio_store[sid] + + logging.info(f"清理资源: {len(expired_sessions)}会话, {len(expired_sentences)}句子") + + def get_session(self, session_id): + return self.sessions.get(session_id) + + def _has_sentence_ending(self, text): + """检测文本是否包含句子结束符""" + if not text: + return False + + # 检查常见结束符(包含全角字符) + if any(char in self.sentence_endings for char in text[-3:]): + return True + + # 检查中文段落结束(换行符前有结束符) + if '\n' in text and any(char in self.sentence_endings for char in text.split('\n')[-2:-1]): + return True + + return False + + def _split_and_extract(self, text): + """ + 增强型句子分割器 + 返回: (完整句子列表, 剩余文本) + """ + # 特殊处理:如果文本以逗号开头,先处理前面的部分 + if text.startswith((",", ",")): + return [text[0]], text[1:] + + # 1. 查找所有可能的句子结束位置 + matches = list(self.sentence_pattern.finditer(text)) + + if not matches: + return [], text # 没有找到结束符 + + # 2. 确定最后一个完整句子的结束位置 + last_end = 0 + complete_sentences = [] + + for match in matches: + end_pos = match.end() + sentence = text[last_end:end_pos].strip() + + # 跳过空句子 + if not sentence: + last_end = end_pos + continue + + # 检查是否为有效句子(最小长度或包含结束符) + if (len(sentence) > 6 or any(char in "。.?!?!" for char in sentence)) and (last_end<24): + complete_sentences.append(sentence) + last_end = end_pos + else: + # 短文本但包含结束符,可能是特殊符号 + if any(char in "。.?!?!" for char in sentence): + complete_sentences.append(sentence) + last_end = end_pos + + # 3. 提取剩余文本 + remaining_text = text[last_end:].strip() + return complete_sentences, remaining_text + + def add_sentence_audio(self, session_id, sentence_text, audio_data: bytes): + """添加句子音频到存储""" + with self.lock: + if session_id not in self.sessions: + return None + + # 生成唯一句子ID + sentence_id = str(uuid.uuid4()) + + # 存储音频数据 + with self.sentence_lock: + self.sentence_audio_store[sentence_id] = { + 'data': audio_data, + 'text': sentence_text, + 'created_at': time.time(), + 'session_id': session_id, + 'format': self.sessions[session_id]['stream_format'] + } + logging.info(f" StreamSessionManager add_sentence_audio") + # 添加到会话的句子列表 + self.sessions[session_id]['sentences'].append(sentence_id) + + return sentence_id + + def get_sentence_audio(self, sentence_id): + """获取句子音频数据""" + with self.sentence_lock: + if sentence_id not in self.sentence_audio_store: + return None + return self.sentence_audio_store[sentence_id]['data'] + + def get_sentence_info(self, sentence_id): + """获取句子信息""" + with self.sentence_lock: + if sentence_id not in self.sentence_audio_store: + return None + return self.sentence_audio_store[sentence_id] + + def get_next_sentence(self, session_id): + """获取下一个句子的信息""" + with self.lock: + session = self.sessions.get(session_id) + if not session or not session['active']: + return None + + if session['current_sentence_index'] < len(session['sentences']): + sentence_id = session['sentences'][session['current_sentence_index']] + session['current_sentence_index'] += 1 + return { + 'id': sentence_id, + 'url': f"/tts_sentence/{sentence_id}" # 虚拟URL + } + return None + def get_sentence_audio_data(self, session_id): + with self.lock: + session = self.sessions.get(session_id) + if not session or not session['active']: + return None + return self.sentence_audio_store + +stream_manager = StreamSessionManager() + + +def allowed_file(filename): + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'} + + +audio_text_cache = {} +cache_lock = Lock() +CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 + + +# WebSocket 连接管理 +class ConnectionManager: + def __init__(self): + self.active_connections = {} + + async def connect(self, websocket: WebSocket, connection_id: str): + await websocket.accept() + self.active_connections[connection_id] = websocket + logging.info(f"新连接建立: {connection_id}") + + async def disconnect(self, connection_id: str, code=1000, reason: str = ""): + if connection_id in self.active_connections: + try: + # 尝试正常关闭连接(非阻塞) + await self.active_connections[connection_id].close(code=code, reason=reason) + except: + pass # 忽略关闭错误 + finally: + del self.active_connections[connection_id] + + def is_connected(self, connection_id: str) -> bool: + """检查连接是否仍然活跃""" + return connection_id in self.active_connections + + async def _safe_send(self, connection_id: str, send_func, *args): + """安全发送的通用方法(核心修改)""" + # 1. 检查连接是否存在 + if connection_id not in self.active_connections: + logging.warning(f"尝试向不存在的连接发送数据: {connection_id}") + return False + + websocket = self.active_connections[connection_id] + + try: + # 2. 检查连接状态(关键修改) + if websocket.client_state.name != "CONNECTED": + logging.warning(f"连接 {connection_id} 状态为 {websocket.client_state.name}") + await self.disconnect(connection_id) + return False + + # 3. 执行发送操作 + await send_func(websocket, *args) + return True + + except (WebSocketDisconnect, RuntimeError) as e: + # 4. 处理连接断开异常 + logging.info(f"发送时检测到断开连接: {connection_id}, {str(e)}") + await self.disconnect(connection_id) + return False + except Exception as e: + # 5. 处理其他异常 + logging.error(f"发送数据出错: {connection_id}, {str(e)}") + await self.disconnect(connection_id) + return False + + async def send_bytes(self, connection_id: str, data: bytes): + """安全发送字节数据""" + return await self._safe_send( + connection_id, + lambda ws, d: ws.send_bytes(d), + data + ) + + async def send_text(self, connection_id: str, message: str): + """安全发送文本数据""" + return await self._safe_send( + connection_id, + lambda ws, m: ws.send_text(m), + message + ) + + async def send_json(self, connection_id: str, data: dict): + """安全发送JSON数据""" + return await self._safe_send( + connection_id, + lambda ws, d: ws.send_json(d), + data + ) + + +manager = ConnectionManager() + + +def generate_mp3_header( + sample_rate: int, + bitrate_kbps: int, + channels: int = 1, + layer: str = "III" # 新增参数,支持 "I"/"II"/"III" +) -> bytes: + """ + 动态生成 MP3 帧头(4字节),支持 Layer I/II/III + + :param sample_rate: 采样率 (8000, 16000, 22050, 44100) + :param bitrate_kbps: 比特率(单位 kbps) + :param channels: 声道数 (1: 单声道, 2: 立体声) + :param layer: 编码层 ("I", "II", "III") + :return: 4字节的帧头数据 + """ + # ---------------------------------- + # 参数校验 + # ---------------------------------- + valid_sample_rates = {8000, 16000, 22050, 44100, 48000} + if sample_rate not in valid_sample_rates: + raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}") + + valid_layers = {"I", "II", "III"} + if layer not in valid_layers: + raise ValueError(f"不支持的层,可选:{valid_layers}") + + # ---------------------------------- + # 确定 MPEG 版本和采样率索引 + # ---------------------------------- + if sample_rate == 44100: + mpeg_version = 0b11 # MPEG-1 + sample_rate_index = 0b00 + elif sample_rate == 22050: + mpeg_version = 0b10 # MPEG-2 + sample_rate_index = 0b00 + elif sample_rate == 16000: + mpeg_version = 0b10 # MPEG-2 + sample_rate_index = 0b10 + elif sample_rate == 8000: + mpeg_version = 0b00 # MPEG-2.5 + sample_rate_index = 0b10 + else: + raise ValueError("采样率与版本不匹配") + + # ---------------------------------- + # 动态选择比特率表(关键扩展) + # ---------------------------------- + # Layer 编码映射(I:0b11, II:0b10, III:0b01) + layer_code = { + "I": 0b11, + "II": 0b10, + "III": 0b01 + }[layer] + + # 比特率表(覆盖所有 Layer) + bitrate_tables = { + # ------------------------------- + # MPEG-1 (0b11) + # ------------------------------- + # Layer I + (0b11, 0b11): { + 32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011, + 160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111, + 288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011, + 416: 0b1100, 448: 0b1101 + }, + # Layer II + (0b11, 0b10): { + 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, + 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, + 160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011, + 320: 0b1100, 384: 0b1101 + }, + # Layer III + (0b11, 0b01): { + 32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011, + 64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111, + 128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011, + 256: 0b0100, 320: 0b0101 + }, + + # ------------------------------- + # MPEG-2 (0b10) / MPEG-2.5 (0b00) + # ------------------------------- + # Layer I + (0b10, 0b11): { + 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, + 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, + 144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011, + 224: 0b1100, 256: 0b1101 + }, + (0b00, 0b11): { + 32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011, + 80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111, + 144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011, + 224: 0b1100, 256: 0b1101 + }, + + # Layer II + (0b10, 0b10): { + 8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011, + 40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111, + 80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011, + 144: 0b1100, 160: 0b1101 + }, + (0b00, 0b10): { + 8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011, + 40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111, + 80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011, + 144: 0b1100, 160: 0b1101 + }, + + # Layer III + (0b10, 0b01): { + 8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011, + 40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111, + 80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011, + 144: 0b0100, 160: 0b0101 + }, + (0b00, 0b01): { + 8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011, + 40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111 + } + } + + # 获取当前版本的比特率表 + key = (mpeg_version, layer_code) + if key not in bitrate_tables: + raise ValueError(f"不支持的版本和层组合: MPEG={mpeg_version}, Layer={layer}") + bitrate_table = bitrate_tables[key] + + if bitrate_kbps not in bitrate_table: + raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}") + bitrate_index = bitrate_table[bitrate_kbps] + + # ---------------------------------- + # 确定声道模式 + # ---------------------------------- + if channels == 1: + channel_mode = 0b11 # 单声道 + elif channels == 2: + channel_mode = 0b00 # 立体声 + else: + raise ValueError("声道数必须为1或2") + + # ---------------------------------- + # 组合帧头字段(修正层编码) + # ---------------------------------- + sync = 0x7FF << 21 # 同步字 11位 (0x7FF = 0b11111111111) + version = mpeg_version << 19 # MPEG 版本 2位 + layer_bits = layer_code << 17 # Layer 编码(I:0b11, II:0b10, III:0b01) + protection = 0 << 16 # 无 CRC + bitrate_bits = bitrate_index << 12 + sample_rate_bits = sample_rate_index << 10 + padding = 0 << 9 # 无填充 + private = 0 << 8 + mode = channel_mode << 6 + mode_ext = 0 << 4 # 扩展模式(单声道无需设置) + copyright = 0 << 3 + original = 0 << 2 + emphasis = 0b00 # 无强调 + + frame_header = ( + sync | + version | + layer_bits | + protection | + bitrate_bits | + sample_rate_bits | + padding | + private | + mode | + mode_ext | + copyright | + original | + emphasis + ) + + return frame_header.to_bytes(4, byteorder='big') + + +# ------------------------------------------------ +def audio_fade_in(audio_data, fade_length): + # 假设音频数据是16位单声道PCM + # 将二进制数据转换为整数数组 + samples = array.array('h', audio_data) + + # 对前fade_length个样本进行淡入处理 + for i in range(fade_length): + fade_factor = i / fade_length + samples[i] = int(samples[i] * fade_factor) + + # 将整数数组转换回二进制数据 + return samples.tobytes() + + +def parse_markdown_json(json_string): + # 使用正则表达式匹配Markdown中的JSON代码块 + match = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL) + if match: + try: + # 尝试解析JSON字符串 + data = json.loads(match[1]) + return {'success': True, 'data': data} + except json.JSONDecodeError as e: + # 如果解析失败,返回错误信息 + return {'success': False, 'data': str(e)} + else: + return {'success': False, 'data': 'not a valid markdown json string'} + + +audio_text_cache = {} +cache_lock = Lock() +CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 + + +# 全角字符到半角字符的映射 +def fullwidth_to_halfwidth(s): + full_to_half_map = { + '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'", + '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.', + '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?', + '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`', + '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「', + '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」', + '、': '、', '・': '・', ':': ':' + } + return ''.join(full_to_half_map.get(char, char) for char in s) + + +def split_text_at_punctuation(text, chunk_size=100): + # 使用正则表达式找到所有的标点符号和特殊字符 + punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+' + tokens = re.split(punctuation_pattern, text) + + # 移除空字符串 + tokens = [token for token in tokens if token] + + # 存储最终的文本块 + chunks = [] + current_chunk = '' + + for token in tokens: + if len(current_chunk) + len(token) <= chunk_size: + # 如果添加当前token后长度不超过chunk_size,则添加到当前块 + current_chunk += (token + ' ') + else: + # 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块 + chunks.append(current_chunk.strip()) + current_chunk = token + ' ' + + # 添加最后一个块(如果有剩余) + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def extract_text_from_markdown(markdown_text): + # 移除Markdown标题 + text = re.sub(r'#\s*[^#]+', '', markdown_text) + # 移除内联代码块 + text = re.sub(r'`[^`]+`', '', text) + # 移除代码块 + text = re.sub(r'```[\s\S]*?```', '', text) + # 移除加粗和斜体 + text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text) + # 移除链接 + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + # 移除图片 + text = re.sub(r'!\[.*?\]\(.*?\)', '', text) + # 移除HTML标签 + text = re.sub(r'<[^>]+>', '', text) + # 转换标点符号 + # text = re.sub(r'[^\w\s]', '', text) + text = fullwidth_to_halfwidth(text) + # 移除多余的空格 + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def encode_gzip_base64(original_data: bytes) -> str: + """核心编码过程:二进制数据 → Gzip压缩 → Base64编码""" + # Step 1: Gzip 压缩 + with BytesIO() as buf: + with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file: + gz_file.write(original_data) + compressed_bytes = buf.getvalue() + + # Step 2: Base64 编码(配置与Android端匹配) + return base64.b64encode(compressed_bytes).decode('utf-8') # 默认不带换行符(等同于Android的Base64.NO_WRAP) + + +def clean_audio_cache(): + """定时清理过期缓存""" + with cache_lock: + now = time.time() + expired_keys = [ + k for k, v in audio_text_cache.items() + if now - v['created_at'] > CACHE_EXPIRE_SECONDS + ] + for k in expired_keys: + entry = audio_text_cache.pop(k, None) + if entry and entry.get('audio_stream'): + entry['audio_stream'].close() + + +def start_background_cleaner(): + """启动后台清理线程""" + + def cleaner_loop(): + while True: + time.sleep(180) # 每3分钟清理一次 + clean_audio_cache() + + cleaner_thread = Thread(target=cleaner_loop, daemon=True) + cleaner_thread.start() + + +def test_qwen_chat(): + messages = [ + {'role': 'system', 'content': 'You are a helpful assistant.'}, + {'role': 'user', 'content': '你是谁?'} + ] + response = Generation.call( + # 若没有配置环境变量,请用百炼API Key将下行替换为:api_key = "sk-xxx", + api_key=ALI_KEY, + model="qwen-plus", # 模型列表:https://help.aliyun.com/zh/model-studio/getting-started/models + messages=messages, + result_format="message" + ) + + if response.status_code == 200: + print(response.output.choices[0].message.content) + else: + print(f"HTTP返回码:{response.status_code}") + print(f"错误码:{response.code}") + print(f"错误信息:{response.message}") + print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code") + + +ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75" +from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse +from dashscope.audio.tts import ( + ResultCallback as TTSResultCallback, + SpeechSynthesizer as TTSSpeechSynthesizer, + SpeechSynthesisResult as TTSSpeechSynthesisResult, +) +# cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本 +from dashscope.audio.tts_v2 import ( + ResultCallback as CosyResultCallback, + SpeechSynthesizer as CosySpeechSynthesizer, + AudioFormat, +) + + +class QwenTTS: + def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", + special_characters: Optional[Dict[str, str]] = None): + import dashscope + import ssl + logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx + self.model_name = model_name.split('@')[0] + dashscope.api_key = key + ssl._create_default_https_context = ssl._create_unverified_context # 禁用验证 + self.synthesizer = None + self.callback = None + self.is_cosyvoice = False + self.voice = "" + self.format = format + self.sample_rate = sample_rate + self.first_chunk = True + if '/' in self.model_name: + parts = self.model_name.split('/', 1) + # 返回分离后的两个字符串parts[0], parts[1] + if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2': + self.is_cosyvoice = True + self.voice = parts[1] + self.completion_event = None # 新增:用于通知任务完成 + # 特殊字符及其拼音映射 + self.special_characters = special_characters or { + "㼽": "chuang3", + "䡇": "yue4" + # 可以添加更多特殊字符的映射 + } + + class Callback(TTSResultCallback): + def __init__(self,data_callback=None,completion_event=None) -> None: + self.dque = deque() + self.data_callback = data_callback + self.completion_event = completion_event # 新增完成事件引用 + def _run(self): + while True: + if not self.dque: + time.sleep(0) + continue + val = self.dque.popleft() + if val: + yield val + else: + break + + def on_open(self): + pass + + def on_complete(self): + self.dque.append(None) + if self.data_callback: + self.data_callback(None) # 发送结束信号 + # 通知任务完成 + if self.completion_event: + self.completion_event.set() + + def on_error(self, response: SpeechSynthesisResponse): + print("Qwen tts error", str(response)) + raise RuntimeError(str(response)) + + def on_close(self): + pass + + def on_event(self, result: TTSSpeechSynthesisResult): + data =result.get_audio_frame() + if data is not None: + if len(data) > 0: + if self.data_callback: + self.data_callback(data) + else: + self.dque.append(data) + #self.dque.append(result.get_audio_frame()) + + # -------------------------- + + class Callback_Cosy(CosyResultCallback): + def __init__(self, data_callback=None,completion_event=None) -> None: + self.dque = deque() + self.data_callback = data_callback + self.completion_event = completion_event # 新增完成事件引用 + + def _run(self): + while True: + if not self.dque: + time.sleep(0) + continue + val = self.dque.popleft() + if val: + yield val + else: + break + + def on_open(self): + logging.info("Qwen CosyVoice tts open ") + pass + + def on_complete(self): + self.dque.append(None) + if self.data_callback: + self.data_callback(None) # 发送结束信号 + # 通知任务完成 + if self.completion_event: + self.completion_event.set() + + def on_error(self, response: SpeechSynthesisResponse): + print("Qwen tts error", str(response)) + if self.data_callback: + self.data_callback(f"ERROR:{str(response)}".encode()) + raise RuntimeError(str(response)) + + def on_close(self): + # print("---Qwen call back close") # cyx + logging.info("Qwen CosyVoice tts close") + pass + + """ canceled for test 语音大模型CosyVoice + def on_event(self, result: SpeechSynthesisResult): + if result.get_audio_frame() is not None: + self.dque.append(result.get_audio_frame()) + """ + + def on_event(self, message): + # logging.info(f"recv speech synthsis message {message}") + pass + + # 以下适合语音大模型CosyVoice + def on_data(self, data: bytes) -> None: + if len(data) > 0: + if self.data_callback: + self.data_callback(data) + else: + self.dque.append(data) + + # -------------------------- + + def tts(self, text, on_data = None,completion_event=None): + # logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx + # text = self.normalize_text(text) + print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx + # text = self.normalize_text(text) + + try: + # if self.model_name != 'cosyvoice-v1': + if self.is_cosyvoice is False: + self.callback = self.Callback( + data_callback=on_data, + completion_event=completion_event + ) + TTSSpeechSynthesizer.call(model=self.model_name, + text=text, + callback=self.callback, + format=self.format) # format="mp3") + else: + self.callback = self.Callback_Cosy() + format = self.get_audio_format(self.format, self.sample_rate) + self.synthesizer = CosySpeechSynthesizer( + model='cosyvoice-v2', + # voice="longyuan", #"longfei", + voice=self.voice, + callback=self.callback, + format=format + ) + self.synthesizer.call(text) + except Exception as e: + print(f"---dale---20 error {e}") # cyx + # ----------------------------------- + try: + for data in self.callback._run(): + # logging.info(f"dashcope return data {len(data)}") + yield data + # print(f"---Qwen return data {num_tokens_from_string(text)}") + # yield num_tokens_from_string(text) + + except Exception as e: + raise RuntimeError(f"**ERROR**: {e}") + + def setup_tts(self, on_data,completion_event=None): + + """设置 TTS 回调,返回配置好的 synthesizer""" + #if not self.is_cosyvoice: + # raise NotImplementedError("Only CosyVoice supported") + + if self.is_cosyvoice: + # 创建 CosyVoice 回调 + self.callback = self.Callback_Cosy( + data_callback=on_data, + completion_event=completion_event) + else: + self.callback = self.Callback( + data_callback=on_data, + completion_event=completion_event) + + + if self.is_cosyvoice: + format_val = self.get_audio_format(self.format, self.sample_rate) + # logging.info(f"Qwen setup_tts {self.voice} {format_val}") + self.synthesizer = CosySpeechSynthesizer( + model='cosyvoice-v1', + voice=self.voice, # voice="longyuan", #"longfei", + callback=self.callback, + format=format_val + ) + + return self.synthesizer + + def apply_phoneme_tags(self, text: str) -> str: + """ + 在文本中查找特殊字符并用标签包裹它们 + """ + # 如果文本已经是SSML格式,直接返回 + if text.strip().startswith("") and text.strip().endswith(""): + return text + + # 为特殊字符添加SSML标签 + for char, pinyin in self.special_characters.items(): + # 使用正则表达式确保只替换整个字符(避免部分匹配) + pattern = r'([^<]|^)' + re.escape(char) + r'([^>]|$)' + replacement = r'\1' + char + r'\2' + text = re.sub(pattern, replacement, text) + + # 如果文本中已有SSML标签,直接返回 + if "" in text: + return text + + # 否则包裹在标签中 + return f"{text}" + + def text_tts_call(self, text): + if self.special_characters and self.is_cosyvoice is False: + text = self.apply_phoneme_tags(text) + #logging.info(f"Applied SSML phoneme tags to text: {text}") + + if self.synthesizer and self.is_cosyvoice: + logging.info(f"Qwen text_tts_call {text} {self.is_cosyvoice}") + format_val = self.get_audio_format(self.format, self.sample_rate) + self.synthesizer = CosySpeechSynthesizer( + model='cosyvoice-v1', + voice=self.voice, # voice="longyuan", #"longfei", + callback=self.callback, + format=format_val + ) + self.synthesizer.call(text) + if self.is_cosyvoice is False: + logging.info(f"Qwen text_tts_call {text}") + TTSSpeechSynthesizer.call(model=self.model_name, + text=text, + callback=self.callback, + format=self.format) + + def streaming_call(self, text): + if self.synthesizer: + self.synthesizer.streaming_call(text) + + def end_streaming_call(self): + if self.synthesizer: + # logging.info(f"---dale end_streaming_call") + self.synthesizer.streaming_complete() + + def get_audio_format(self, format: str, sample_rate: int): + """动态获取音频格式""" + from dashscope.audio.tts_v2 import AudioFormat + logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}") + format_map = { + (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS, + (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT, + (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT, + (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT, + (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS, + (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT, + (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT, + (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS, + (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT, + (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT, + (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS, + (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT, + (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT + + } + return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS) + + +class DoubaoTTS: + def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"): + logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}") + # 解析豆包认证信息 (appid, token, cluster, voice_type) + try: + self.appid = "7282190702" + self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9" + self.cluster = "volcano_tts" + self.voice_type ="zh_female_qingxinnvsheng_mars_bigtts" # "zh_male_jieshuonansheng_mars_bigtts" #"zh_male_ruyaqingnian_mars_bigtts" #"zh_male_jieshuonansheng_mars_bigtts" + except Exception as e: + raise ValueError(f"Invalid Doubao key format: {str(e)}") + + self.format = format + self.sample_rate = sample_rate + self.model_name = model_name + self.callback = None + self.ws = None + self.loop = None + self.task = None + self.event = threading.Event() + self.data_queue = deque() + self.host = "openspeech.bytedance.com" + self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary" + self.default_header = bytearray(b'\x11\x10\x11\x00') + self.total_data_size = 0 + self.completion_event = None # 新增:用于通知任务完成 + + + class Callback: + def __init__(self, data_callback=None,completion_event=None): + self.data_callback = data_callback + self.data_queue = deque() + self.completion_event = completion_event # 完成事件引用 + + def on_data(self, data): + if self.data_callback: + self.data_callback(data) + else: + self.data_queue.append(data) + # 通知任务完成 + if self.completion_event: + self.completion_event.set() + + def on_complete(self): + if self.data_callback: + self.data_callback(None) + + def on_error(self, error): + if self.data_callback: + self.data_callback(f"ERROR:{error}".encode()) + + def setup_tts(self, on_data,completion_event): + + """设置回调,返回自身(因为豆包需要异步启动)""" + self.callback = self.Callback( + data_callback=on_data, + completion_event=completion_event + ) + return self + + def text_tts_call(self, text): + """同步调用,启动异步任务并等待完成""" + self.total_data_size = 0 + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + self.task = self.loop.create_task(self._async_tts(text)) + try: + self.loop.run_until_complete(self.task) + except Exception as e: + logging.error(f"DoubaoTTS--0 call error: {e}") + self.callback.on_error(str(e)) + + async def _async_tts(self, text): + """异步执行TTS请求""" + header = {"Authorization": f"Bearer; {self.token}"} + request_json = { + "app": { + "appid": self.appid, + "token": "access_token", # 固定值 + "cluster": self.cluster + }, + "user": { + "uid": str(uuid.uuid4()) # 随机用户ID + }, + "audio": { + "voice_type": self.voice_type, + "encoding": self.format, + "speed_ratio": 1.0, + "volume_ratio": 1.0, + "pitch_ratio": 1.0, + }, + "request": { + "reqid": str(uuid.uuid4()), + "text": text, + "text_type": "plain", + "operation": "submit" # 使用submit模式支持流式 + } + } + + # 构建请求数据 + payload_bytes = str.encode(json.dumps(request_json)) + payload_bytes = gzip.compress(payload_bytes) + full_client_request = bytearray(self.default_header) + full_client_request.extend(len(payload_bytes).to_bytes(4, 'big')) + full_client_request.extend(payload_bytes) + + try: + async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws: + self.ws = ws + await ws.send(full_client_request) + + # 接收音频数据 + while True: + res = await ws.recv() + done = self._parse_response(res) + if done: + self.callback.on_complete() + break + except Exception as e: + logging.error(f"DoubaoTTS--1 WebSocket error: {e}") + self.callback.on_error(str(e)) + finally: + # 通知任务完成 + if self.completion_event: + self.completion_event.set() + + + def _parse_response(self, res): + """解析豆包返回的二进制响应""" + # 协议头解析 (4字节) + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + payload = res[header_size * 4:] + + # 音频数据响应 + if message_type == 0xb: # audio-only server response + message_flags = res[1] & 0x0f + + # ACK消息,忽略 + if message_flags == 0: + return False + + # 音频数据消息 + sequence_number = int.from_bytes(payload[:4], "big", signed=True) + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + audio_data = payload[8:8 + payload_size] + + if audio_data: + self.total_data_size = self.total_data_size + len(audio_data) + self.callback.on_data(audio_data) + + #logging.info(f"doubao _parse_response: {sequence_number} {len(audio_data)} {self.total_data_size}") + # 序列号为负表示结束 + return sequence_number < 0 + + # 错误响应 + elif message_type == 0xf: + code = int.from_bytes(payload[:4], "big", signed=False) + msg_size = int.from_bytes(payload[4:8], "big", signed=False) + error_msg = payload[8:8 + msg_size] + + try: + # 尝试解压错误消息 + error_msg = gzip.decompress(error_msg).decode() + except: + error_msg = error_msg.decode(errors='ignore') + + logging.error(f"DoubaoTTS error: {error_msg}") + self.callback.on_error(error_msg) + return False + + return False + + +class UnifiedTTSEngine: + def __init__(self): + self.lock = threading.Lock() + self.tasks = {} + self.executor = ThreadPoolExecutor(max_workers=10) + self.cache_expire = 300 # 5分钟缓存 + # 启动清理过期任务的定时器 + self.cleanup_timer = None + self.start_cleanup_timer() + + def _cleanup_old_tasks(self): + """清理过期任务""" + now = time.time() + with self.lock: + expired_ids = [task_id for task_id, task in self.tasks.items() + if now - task['created_at'] > self.cache_expire] + for task_id in expired_ids: + self._remove_task(task_id) + + def _remove_task(self, task_id): + """移除任务""" + if task_id in self.tasks: + task = self.tasks.pop(task_id) + # 取消可能的后台任务 + if 'future' in task and not task['future'].done(): + task['future'].cancel() + # 其他资源在任务被移除后会被垃圾回收 + # 资源释放机制总结: + # 移除任务引用:self.tasks.pop() 解除任务对象引用,触发垃圾回收。 + # 取消后台线程:future.cancel() 终止未完成线程,释放线程资源。 + # 自动内存回收:Python GC 回收任务对象及其队列、缓冲区占用的内存。 + # 线程池管理:执行器自动回收线程至池中,避免资源泄漏。 + + def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False): + """创建TTS任务(同步方法)""" + self._cleanup_old_tasks() + audio_stream_id = str(uuid.uuid4()) + + # 创建任务数据结构 + task_data = { + 'id': audio_stream_id, + 'text': text, + 'format': format, + 'sample_rate': sample_rate, + 'model_name': model_name, + 'key': key, + 'delay_gen_audio': delay_gen_audio, + 'created_at': time.time(), + 'status': 'pending', + 'data_queue': deque(), + 'event': threading.Event(), + 'completed': False, + 'error': None + } + with self.lock: + self.tasks[audio_stream_id] = task_data + + # 如果不是延迟模式,立即启动任务 + if not delay_gen_audio: + self._start_tts_task(audio_stream_id) + + return audio_stream_id + + def _start_tts_task(self, audio_stream_id): + # 启动TTS任务(后台线程) + + task = self.tasks.get(audio_stream_id) + if not task or task['status'] != 'pending': + return + logging.info("已经启动 start tts task {audio_stream_id}") + task['status'] = 'processing' + + # 在后台线程中执行TTS + future = self.executor.submit(self._run_tts_sync, audio_stream_id) + task['future'] = future + + # 如果需要等待任务完成 + if not task.get('delay_gen_audio', True): + try: + # 等待任务完成(最多5分钟) + future.result(timeout=300) + logging.info(f"TTS任务 {audio_stream_id} 已完成") + self._merge_audio_data(audio_stream_id) + except concurrent.futures.TimeoutError: + task['error'] = "TTS生成超时" + task['completed'] = True + logging.error(f"TTS任务 {audio_stream_id} 超时") + except Exception as e: + task['error'] = f"ERROR:{str(e)}" + task['completed'] = True + logging.exception(f"TTS任务执行异常: {str(e)}") + + def _run_tts_sync(self, audio_stream_id): + # 同步执行TTS生成 在后台线程中执行 + task = self.tasks.get(audio_stream_id) + if not task: + return + + try: + # 创建完成事件 + completion_event = threading.Event() + # 创建TTS实例 + # 根据model_name选择TTS引擎 + # 前端传入 cosyvoice-v1/longhua@Tongyi-Qianwen + model_name_wo_brand = task['model_name'].split('@')[0] + model_name_version = model_name_wo_brand.split('/')[0] + if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']: + # 豆包TTS + tts = DoubaoTTS( + key=task['key'], + format=task['format'], + sample_rate=task['sample_rate'], + model_name=task['model_name'] + ) + else: + # 通义千问TTS + tts = QwenTTS( + key=task['key'], + format=task['format'], + sample_rate=task['sample_rate'], + model_name=task['model_name'] + ) + + # 定义同步数据处理函数 + def data_handler(data): + if data is None: # 结束信号 + task['completed'] = True + task['event'].set() + logging.info(f"--data_handler on_complete") + elif data.startswith(b"ERROR"): # 错误信号 + task['error'] = data.decode() + task['completed'] = True + task['event'].set() + else: # 音频数据 + task['data_queue'].append(data) + + + # 设置并执行TTS + synthesizer = tts.setup_tts(data_handler,completion_event) + #synthesizer.call(task['text']) + tts.text_tts_call(task['text']) + # 等待完成或超时 + # 等待完成或超时 + if not completion_event.wait(timeout=300): # 5分钟超时 + task['error'] = "TTS generation timeout" + task['completed'] = True + + logging.info(f"--tts task event set error = {task['error']}") + + except Exception as e: + logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}") + task['error'] = f"ERROR:{str(e)}" + task['completed'] = True + finally: + # 确保清理TTS资源 + logging.info("UnifiedTTSEngine _run_tts_sync finally") + if hasattr(tts, 'loop') and tts.loop: + tts.loop.close() + + def _merge_audio_data(self, audio_stream_id): + """将任务的所有音频数据合并到ByteIO缓冲区""" + task = self.tasks.get(audio_stream_id) + if not task or not task.get('completed'): + return + + try: + logging.info(f"开始合并音频数据: {audio_stream_id}") + + # 创建内存缓冲区 + buffer = io.BytesIO() + + # 合并所有数据块 + for data_chunk in task['data_queue']: + buffer.write(data_chunk) + + # 重置指针位置以便读取 + buffer.seek(0) + + # 保存到任务对象 + task['buffer'] = buffer + logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节") + + # 可选:清理原始数据队列以节省内存 + task['data_queue'].clear() + + except Exception as e: + logging.error(f"合并音频数据失败: {str(e)}") + task['error'] = f"合并错误: {str(e)}" + + async def get_audio_stream(self, audio_stream_id): + """获取音频流(异步生成器)""" + task = self.tasks.get(audio_stream_id) + if not task: + raise RuntimeError("Audio stream not found") + + # 如果是延迟任务且未启动,现在启动 status 为 pending + if task['delay_gen_audio'] and task['status'] == 'pending': + self._start_tts_task(audio_stream_id) + total_audio_data_size = 0 + # 等待任务启动 + while task['status'] == 'pending': + await asyncio.sleep(0.1) + + # 流式返回数据 + while not task['completed'] or task['data_queue']: + while task['data_queue']: + data = task['data_queue'].popleft() + total_audio_data_size += len(data) + #logging.info(f"yield audio data {len(data)} {total_audio_data_size}") + yield data + + # 短暂等待新数据 + await asyncio.sleep(0.05) + + # 检查错误 + if task['error']: + raise RuntimeError(task['error']) + + def start_cleanup_timer(self): + """启动定时清理任务""" + if self.cleanup_timer: + self.cleanup_timer.cancel() + + self.cleanup_timer = threading.Timer(30.0, self.cleanup_task) # 每30秒清理一次 + self.cleanup_timer.daemon = True # 设置为守护线程 + self.cleanup_timer.start() + + def cleanup_task(self): + """执行清理任务""" + try: + self._cleanup_old_tasks() + except Exception as e: + logging.error(f"清理任务时出错: {str(e)}") + finally: + self.start_cleanup_timer() # 重新启动定时器 + + +# 全局 TTS 引擎实例 +tts_engine = UnifiedTTSEngine() + + +def replace_domain(url: str) -> str: + """替换URL中的域名为本地地址,不使用urllib.parse""" + # 定义需要替换的域名列表 + domains_to_replace = [ + "http://1.13.185.116:9380", + "https://ragflow.szzysztech.com", + "1.13.185.116:9380", + "ragflow.szzysztech.com" + ] + + # 尝试替换每个可能的域名 + for domain in domains_to_replace: + if domain in url: + # 直接替换域名部分 + return url.replace(domain, "http://localhost:9380", 1) + + # 如果未匹配到特定域名,尝试智能替换 + if "://" in url: + # 分割协议和路径 + protocol, path = url.split("://", 1) + + # 查找第一个斜杠位置来确定域名结束位置 + slash_pos = path.find("/") + if slash_pos > 0: + # 替换域名部分 + return f"http://localhost:9380{path[slash_pos:]}" + else: + # 没有路径部分,直接返回本地地址 + return "http://localhost:9380" + else: + # 没有协议部分,直接添加本地地址 + return f"http://localhost:9380/{url}" + + +async def proxy_aichat_audio_stream(client_id: str, audio_url: str): + """代理外部音频流请求""" + try: + # 替换域名为本地地址 + local_url = audio_url + logging.info(f"代理音频流: {audio_url} -> {local_url}") + + async with httpx.AsyncClient(timeout=60.0) as client: + async with client.stream("GET", local_url) as response: + # 流式转发音频数据 + async for chunk in response.aiter_bytes(): + if not await manager.send_bytes(client_id, chunk): + logging.warning(f"Audio proxy interrupted for {client_id}") + return + except Exception as e: + logging.error(f"Audio proxy failed: {str(e)}") + await manager.send_text(client_id, json.dumps({ + "type": "error", + "message": f"音频流获取失败: {str(e)}" + })) + + +# 代理函数 - 文本流 +# 在微信小程序中,原来APK使用的SSE机制不能正常工作,需要使用WebSocket +async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict): + """代理大模型文本流请求 - 兼容现有Flask实现""" + try: + logging.info(f"代理文本流: completions_url={completions_url} {payload}") + logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}") + + headers = { + "Content-Type": "application/json", + 'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm' + } + tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen') + #if 'longyuan' in tts_model_name: + # tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen" + # 创建TTS实例 + tts_model = QwenTTS( + key=ALI_KEY, + format=payload.get('tts_stream_format', 'mp3'), + sample_rate=payload.get('tts_sample_rate', 48000), + model_name=tts_model_name + ) + streaming_call = False + if tts_model.is_cosyvoice: + streaming_call = True + + # 创建流会话 + tts_stream_session_id = stream_manager.create_session( + tts_model=tts_model, + sample_rate=payload.get('tts_sample_rate', 48000), + stream_format=payload.get('tts_stream_format', 'mp3'), + session_id=None, + streaming_call= streaming_call + ) + # logging.info(f"---tts_stream_session_id = {tts_stream_session_id}") + tts_stream_session_id_sent = False + send_sentence_tts_url = False + # 添加一个事件来标记所有句子已发送 + all_sentences_sent = asyncio.Event() + + # 任务:监听并发送新生成的句子 + async def send_new_sentences(): + """监听并发送新生成的句子""" + try: + while True: + # 获取下一个句子 + sentence_info = stream_manager.get_next_sentence(tts_stream_session_id) + + if sentence_info: + logging.info(f"--proxy_aichat_text_stream 发送sentence_info\r\n") + # 发送句子信息 + await manager.send_json(client_id, { + "type": "tts_sentence", + "id": sentence_info['id'], + "text": stream_manager.get_sentence_info(sentence_info['id'])['text'], + "url": sentence_info['url'] + }) + else: + # 检查会话是否结束且没有更多句子 + session = stream_manager.get_session(tts_stream_session_id) + if not session or (not session['active']): + #and session['current_sentence_index'] >= len(session['sentences'])): + # 标记所有句子已发送 + all_sentences_sent.set() + break + + # 等待新句子生成 + await asyncio.sleep(0.1) + except asyncio.CancelledError: + logging.info("句子监听任务被取消") + except Exception as e: + logging.error(f"句子监听任务出错: {str(e)}") + all_sentences_sent.set() + + + if send_sentence_tts_url: + # 启动句子监听任务 + sentence_task = asyncio.create_task(send_new_sentences()) + # 使用更长的超时时间 (5分钟) + timeout = httpx.Timeout(300.0, connect=60.0) + async with httpx.AsyncClient(timeout=timeout) as client: + # 关键修改:使用流式请求模式 + async with client.stream( # <-- 使用stream方法 + "POST", + completions_url, + json=payload, + headers=headers + ) as response: + logging.info(f"响应状态: HTTP {response.status_code}") + + if response.status_code != 200: + # 读取错误信息(非流式) + error_content = await response.aread() + error_msg = f"后端错误: HTTP {response.status_code}" + error_msg += f" - {error_content[:200].decode()}" if error_content else "" + await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg})) + return + + # 验证SSE流 + content_type = response.headers.get("content-type", "").lower() + if "text/event-stream" not in content_type: + logging.warning("非流式响应,转发完整内容") + full_content = await response.aread() + await manager.send_text(client_id, json.dumps({ + "type": "text", + "data": full_content.decode('utf-8') + })) + return + + logging.info("开始处理SSE流") + event_count = 0 + # 使用异步迭代器逐行处理 + async for line in response.aiter_lines(): + # 跳过空行和注释行 + if not line or line.startswith(':'): + continue + + # 处理SSE事件 + if line.startswith("data:"): + data_str = line[5:].strip() + if data_str: # 过滤空数据 + try: + # 解析并提取增量文本 + data_obj = json.loads(data_str) + delta_text = None + if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict): + delta_text = data_obj.get('data', None).get('delta_ans', "") + if tts_stream_session_id_sent is False: + logging.info(f"--proxy_aichat_text_stream 发送audio_stream_url") + data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}" + data_str = json.dumps(data_obj) + tts_stream_session_id_sent = True + # 直接转发原始数据 + await manager.send_text(client_id, json.dumps({ + "type": "text", + "data": data_str + })) + # 这里构建{"type":"text",'data':"data_str"}) 是为了前端websocket进行数据解析 + if delta_text: + # 追加到会话管理器 + stream_manager.append_text(tts_stream_session_id, delta_text) + # logging.info(f"文本代理转发: {data_str}") + event_count += 1 + except Exception as e: + logging.error(f"事件发送失败: {str(e)}") + + # 保持连接活性 + await asyncio.sleep(0.001) # 避免CPU空转 + + logging.info(f"SSE流处理完成,事件数: {event_count}") + + # 发送文本流结束信号 + await manager.send_text(client_id, json.dumps({"type": "end"})) + # 标记文本输入结束 + if stream_manager.finish_text_input: + stream_manager.finish_text_input(tts_stream_session_id) + + if send_sentence_tts_url: + # 等待所有句子生成并发送(最多等待300秒) + try: + await asyncio.wait_for(all_sentences_sent.wait(), timeout=300.0) + logging.info(f"所有TTS句子已发送") + except asyncio.TimeoutError: + logging.warning("等待TTS句子发送超时") + + # 取消句子监听任务(如果仍在运行) + if not sentence_task.done(): + sentence_task.cancel() + try: + await sentence_task + except asyncio.CancelledError: + pass + + except httpx.ReadTimeout: + logging.error("读取后端服务超时") + await manager.send_text(client_id, json.dumps({ + "type": "error", + "message": "后端服务响应超时" + })) + except httpx.ConnectError as e: + logging.error(f"连接后端服务失败: {str(e)}") + await manager.send_text(client_id, json.dumps({ + "type": "error", + "message": f"无法连接到后端服务: {str(e)}" + })) + except Exception as e: + logging.exception(f"文本代理失败: {str(e)}") + await manager.send_text(client_id, json.dumps({ + "type": "error", + "message": f"文本流获取失败: {str(e)}" + })) + + +@tts_router.get("/audio/pcm_mp3") +async def stream_mp3(): + def audio_generator(): + path = './test.mp3' + try: + with open(path, 'rb') as f: + while True: + chunk = f.read(1024) + if not chunk: + break + yield chunk + except Exception as e: + logging.error(f"MP3 streaming error: {str(e)}") + + return StreamingResponse( + audio_generator(), + media_type="audio/mpeg", + headers={ + "Cache-Control": "no-store", + "Accept-Ranges": "bytes" + } + ) + + +def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes: + """动态生成WAV头(严格保持原有逻辑结构)""" + with BytesIO() as wav_buffer: + with wave.open(wav_buffer, 'wb') as wav_file: + wav_file.setnchannels(1) # 保持原单声道设置 + wav_file.setsampwidth(2) # 保持原16-bit设置 + wav_file.setframerate(sample_rate) + wav_file.writeframes(pcm_data) + wav_buffer.seek(0) + return wav_buffer.read() + + +def generate_silence_header(duration_ms: int = 500) -> bytes: + """生成静音数据(用于MP3流式传输预缓冲)""" + num_samples = int(TTS_SAMPLERATE * duration_ms / 1000) + return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS + + +# ------------------------ API路由 ------------------------ +@tts_router.get("/tts_sentence/{sentence_id}") +async def get_sentence_audio(sentence_id: str): + # 获取音频数据 + audio_data = stream_manager.get_sentence_audio(sentence_id) + if not audio_data: + raise HTTPException(status_code=404, detail="Audio not found") + + # 获取音频格式 + sentence_info = stream_manager.get_sentence_info(sentence_id) + if not sentence_info: + raise HTTPException(status_code=404, detail="Sentence info not found") + + # 确定MIME类型 + format = sentence_info['format'] + media_type = "audio/mpeg" if format == "mp3" else "audio/wav" + logging.info(f"--http get sentence tts audio stream {sentence_id}") + # 返回流式响应 + return StreamingResponse( + io.BytesIO(audio_data), + media_type=media_type, + headers={ + "Content-Disposition": f"attachment; filename=audio.{format}", + "Cache-Control": "max-age=3600" # 缓存1小时 + } + ) + +@tts_router.post("/chats/{chat_id}/tts") +async def create_tts_request(chat_id: str, request: Request): + try: + data = await request.json() + logging.info(f"Creating TTS request: {data}") + + # 参数校验 + text = data.get("text", "").strip() + if not text: + raise HTTPException(400, detail="Text cannot be empty") + + format = data.get("tts_stream_format", "mp3") + if format not in ["mp3", "wav", "pcm"]: + raise HTTPException(400, detail="Unsupported audio format") + + sample_rate = data.get("tts_sample_rate", 48000) + if sample_rate not in [8000, 16000, 22050, 44100, 48000]: + raise HTTPException(400, detail="Unsupported sample rate") + + model_name = data.get("model_name", "cosyvoice-v1/longxiaochun") + delay_gen_audio = data.get('delay_gen_audio', False) + + # 创建TTS任务 + audio_stream_id = tts_engine.create_tts_task( + text=text, + format=format, + sample_rate=sample_rate, + model_name=model_name, + key=ALI_KEY, + delay_gen_audio=delay_gen_audio + ) + + return JSONResponse( + status_code=200, + content={ + "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}", + "url": f"/chats/{chat_id}/tts/{audio_stream_id}", + "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}", # WebSocket URL 2025 0622新增 + "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat() + } + ) + + except Exception as e: + logging.error(f"Request failed: {str(e)}") + raise HTTPException(500, detail="Internal server error") + + +executor = ThreadPoolExecutor() + + +@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}") +async def get_tts_audio( + chat_id: str, + audio_stream_id: str, + range: str = Header(None) +): + try: + # 获取任务信息 + task = tts_engine.tasks.get(audio_stream_id) + if not task: + # 返回友好的错误信息而不是抛出异常 + return JSONResponse( + status_code=404, + content={ + "error": "Audio stream not found", + "message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired", + "suggestion": "Please create a new TTS request and try again" + } + ) + + # 获取媒体类型 + format = task['format'] + media_type = { + "mp3": "audio/mpeg", + "wav": "audio/wav", + "pcm": f"audio/L16; rate={task['sample_rate']}; channels=1" + }[format] + + # 如果任务已完成且有完整缓冲区,处理Range请求 + logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}") + + # 创建响应内容生成器 + def buffer_read(buffer): + content_length = buffer.getbuffer().nbytes + remaining = content_length + chunk_size = 4096 + buffer.seek(0) + while remaining > 0: + read_size = min(remaining, chunk_size) + data = buffer.read(read_size) + if not data: + break + yield data + remaining -= len(data) + + if task.get('completed') and task.get('buffer') is not None: + buffer = task['buffer'] + total_size = buffer.getbuffer().nbytes + # 强制小文件使用流式传输(避免206响应问题) + + if total_size < 1024 * 120: # 小于300KB + range = None + + if range: + # 处理范围请求 + return handle_range_request(range, buffer, total_size, media_type) + else: + return StreamingResponse( + buffer_read(buffer), + media_type=media_type, + headers={ + "Accept-Ranges": "bytes", + "Cache-Control": "no-store", + "Transfer-Encoding": "chunked" + }) + + # 创建流式响应 + logging.info("tts_engine.get_audio_stream--0") + return StreamingResponse( + tts_engine.get_audio_stream(audio_stream_id), + media_type=media_type, + headers={ + "Accept-Ranges": "bytes", + "Cache-Control": "no-store", + "Transfer-Encoding": "chunked" + } + ) + + except Exception as e: + logging.error(f"Audio streaming failed: {str(e)}") + raise HTTPException(500, detail="Audio generation error") + + +def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str): + """处理 HTTP Range 请求""" + try: + # 解析 Range 头部 (示例: "bytes=0-1023") + range_type, range_spec = range_header.split('=') + if range_type != 'bytes': + raise ValueError("Unsupported range type") + + start_str, end_str = range_spec.split('-') + start = int(start_str) + end = int(end_str) if end_str else total_size - 1 + logging.info(f"handle_range_request--1 {start_str}-{end_str} {end}") + # 验证范围有效性 + if start >= total_size or end >= total_size: + raise HTTPException(status_code=416, headers={ + "Content-Range": f"bytes */{total_size}" + }) + + # 计算内容长度 + content_length = end - start + 1 + + # 设置状态码 + status_code = 206 # Partial Content + if start == 0 and end == total_size - 1: + status_code = 200 # Full Content + + # 设置流读取位置 + buffer.seek(start) + + # 创建响应内容生成器 + def content_generator(): + remaining = content_length + chunk_size = 4096 + while remaining > 0: + read_size = min(remaining, chunk_size) + data = buffer.read(read_size) + if not data: + break + yield data + remaining -= len(data) + + # 返回分块响应 + return StreamingResponse( + content_generator(), + status_code=status_code, + media_type=media_type, + headers={ + "Content-Range": f"bytes {start}-{end}/{total_size}", + "Content-Length": str(content_length), + "Accept-Ranges": "bytes", + "Cache-Control": "public, max-age=3600" + } + ) + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}") +async def websocket_tts_endpoint( + websocket: WebSocket, + chat_id: str, + audio_stream_id: str +): + # 接收 header 参数 + headers = websocket.headers + service_type = headers.get("x-tts-type") # 注意:header 名称转为小写 + # audio_url = headers.get("x-audio-url") + """ + 前端示例 + websocketConnection = uni.connectSocket({ + url: url, + header: { + 'Authorization': token, + 'X-Tts-Type': 'AiChat', //'Ask' // 自定义参数1 + 'X-Device-Type': 'mobile', // 自定义参数2 + 'X-User-ID': '12345' // 自定义参数3 + }, + success: () => { + console.log('WebSocket connected'); + }, + fail: (err) => { + console.error('WebSocket connection failed:', err); + } + }); + """ + # 创建唯一连接 ID + connection_id = str(uuid.uuid4()) + # logging.info(f"---dale-- websocket connection_id = {connection_id} chat_id={chat_id}") + await manager.connect(websocket, connection_id) + + completed_successfully = False + try: + # 根据tts_type路由到不同的音频源 + if service_type == "AiChatTts": + # 音频代理服务 + audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}" + # await proxy_aichat_audio_stream(connection_id, audio_url) + sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate') + audio_data_size =0 + await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate}) + async for data in stream_manager.get_tts_buffer_data(audio_stream_id): + if data.get('type') == 'sentence_end': + await manager.send_json(connection_id, {"command": "sentence_end"}) + + if data.get('type') == 'arraybuffer': + audio_data_size += len(data.get('data')) + if not await manager.send_bytes(connection_id, data.get('data')): + break + completed_successfully = True + logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}") + elif service_type == "AiChatText": + # 文本代理服务 + # 等待客户端发送初始请求数据 进行大模型对话代理时,需要前端连接后发送payload + payload = await websocket.receive_json() + completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions" + await proxy_aichat_text_stream(connection_id, completions_url, payload) + completed_successfully = True + else: + # 使用引擎的生成器直接获取音频流 + async for data in tts_engine.get_audio_stream(audio_stream_id): + if not await manager.send_bytes(connection_id, data): + logging.warning(f"Send failed, connection closed: {connection_id}") + break + await manager.send_json(connection_id, {"command": "sentence_end"}) + completed_successfully = True + + # 发送完成信号前检查连接状态 + if manager.is_connected(connection_id): + # 发送完成信号 + await manager.send_json(connection_id, {"status": "completed"}) + + # 添加短暂延迟确保消息送达 + await asyncio.sleep(0.1) + + # 主动关闭WebSocket连接 + await manager.disconnect(connection_id, code=1000, reason="Audio stream completed") + except WebSocketDisconnect: + logging.info(f"WebSocket disconnected: {connection_id}") + except Exception as e: + logging.error(f"WebSocket TTS error: {str(e)}") + if manager.is_connected(connection_id): + await manager.send_json(connection_id, {"error": str(e)}) + finally: + pass + # await manager.disconnect(connection_id) + + +def cleanup_cache(): + """清理过期缓存""" + with cache_lock: + now = datetime.datetime.now() + expired = [k for k, v in audio_text_cache.items() + if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS] + for key in expired: + logging.info(f"del audio_text_cache= {audio_text_cache[key]}") + del audio_text_cache[key] + +# 应用启动时启动清理线程 +# start_background_cleaner() diff --git a/asr-monitor-test/docker-compose.yml b/asr-monitor-test/docker-compose.yml index 599dcea5..713e26b2 100644 --- a/asr-monitor-test/docker-compose.yml +++ b/asr-monitor-test/docker-compose.yml @@ -3,8 +3,8 @@ # 需要准备 docker compose .yml 和 Dockerfile 文件 networks: - docker_ragflow: - external: true # 声明使用外部网络(需提前创建) + docker_ragflow: + external: true # 声明使用外部网络(需提前创建) services: asr-monitor-test: build: diff --git a/asr-monitor-test/requirements.txt b/asr-monitor-test/requirements.txt index cd6adae4..9ba1650f 100644 --- a/asr-monitor-test/requirements.txt +++ b/asr-monitor-test/requirements.txt @@ -3,6 +3,7 @@ uvicorn>=0.15.0 websockets==12.0 python-multipart==0.0.20 +dashscope>=1.23.8 # JWT 相关 python-jose[cryptography]==3.3.0 # 兼容 Python3 的 JOSE 实现 @@ -15,6 +16,7 @@ httpx # MySQL 驱动 pymysql==1.1.0 +tzdata==2025.2 # 连接池管理 dbutils==3.1.1 diff --git a/asr-monitor-test/run_app.sh b/asr-monitor-test/run_app.sh index eae35d57..d2766efb 100644 --- a/asr-monitor-test/run_app.sh +++ b/asr-monitor-test/run_app.sh @@ -20,8 +20,10 @@ export PYTHONPATH=.:$PYTHONPATH # 3. 使用nohup运行并防止终端退出影响 echo "➜ 启动应用进程..." -nohup python app/main.py > app.log 2>&1 & +#nohup python app/main.py > app.log 2>&1 & +#只能使用1个worker,共享全局变量使用会有问题 +nohup uvicorn app.main:app --host 0.0.0.0 --port 9580 --workers 4 > app.log 2>&1 & # 4. 验证新进程 sleep 3 # 等待进程启动 NEW_PID=$(lsof -ti :${PORT}) diff --git a/asr-monitor-test/start.sh b/asr-monitor-test/start.sh index 12eeb3b7..ce44b1a0 100755 --- a/asr-monitor-test/start.sh +++ b/asr-monitor-test/start.sh @@ -13,4 +13,13 @@ else fi source venv/bin/activate -export PYTHONPATH=.:$PYTHONPATH && python app/main.py +export PYTHONPATH=.:$PYTHONPATH +#python app/main.py + +# 高性能启动 FastAPI +uvicorn app.main:app \ + --host 0.0.0.0 \ + --port $PORT \ + --workers 4 \ + --timeout-keep-alive 65 \ + --no-access-log \ No newline at end of file