准备对AI流式音频发回给前端的机制做较大的修改,先提交1个版本

This commit is contained in:
qcloud
2025-07-19 22:44:28 +08:00
parent 74899acab9
commit 44cb7c0dca
9 changed files with 766 additions and 868 deletions

View File

@@ -5,6 +5,7 @@ from contextlib import contextmanager
from config import DATABASE_CONFIG
from datetime import datetime,timedelta
import logging
from zoneinfo import ZoneInfo # Python 3.9+ 内置
from typing import Union, List, Dict, Optional
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from dateutil.relativedelta import relativedelta
@@ -415,6 +416,21 @@ def get_users(status: int = None, email: str = None, phone: str = None,openid: s
# 按用户ID获取用户
def get_user_by_id(user_id: str):
    """
    Fetch a single user record by user ID.

    Args:
        user_id: ID of the user to look up.

    Returns:
        The first row returned for ``rag_flow.users_info`` (all columns are
        selected), or None when the query yields nothing.
    """
    rows = execute_query(
        "SELECT * FROM rag_flow.users_info WHERE user_id = %s",
        (user_id,),
    )
    if not rows:
        return None
    return rows[0]
@@ -720,7 +736,7 @@ def update_order(order_id: str, update_data: dict) -> int:
from typing import Union, List, Dict, Optional
def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -> Union[Dict, List[Dict], None]:
def get_order_by_id(order_id: str = None, user_id: str = None,combined:bool = False,museum_id:int = None) -> Union[Dict, List[Dict], None]:
"""
根据订单ID或用户ID查询订单信息
@@ -742,9 +758,25 @@ def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -
- 使用参数化查询防止 SQL 注入
- 当同时传入 order_id 和 user_id 时,优先使用 order_id
"""
if not order_id and not user_id:
return None # 两个参数都未传入,直接返回 None
sql_wo_condition= """
# 无任何参数时返回 None
if not any([order_id, user_id, museum_id]):
return None
# ========== 简单查询模式 ==========
if not combined:
# 优先使用 order_id 查询
if order_id:
sql = "SELECT * FROM subscription_orders WHERE order_id = %s"
result = execute_query(sql, (order_id,))
return result[0] if result and len(result) > 0 else None
# 使用 user_id 查询
if user_id:
sql = "SELECT * FROM subscription_orders WHERE user_id = %s"
result = execute_query(sql, (user_id,))
return result if result else []
# ========== 复杂查询模式 ==========
base_sql= """
SELECT
o.order_id,
o.user_id,
@@ -774,25 +806,39 @@ def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -
LEFT JOIN rag_flow.user_subscriptions us ON o.order_id = us.order_id
LEFT JOIN rag_flow.mesum_overview mo ON ms.museum_id = mo.id
"""
# 优先使用 order_id 查询
if order_id and not combined:
sql = "SELECT * FROM subscription_orders WHERE order_id = %s"
result = execute_query(sql, (order_id,))
return result[0] if result and len(result) > 0 else None # 返回单个订单
# 构建查询条件和参数
conditions = []
params = []
# 如果 order_id 不存在,使用 user_id 查询
if user_id and not combined:
sql = "SELECT * FROM subscription_orders WHERE user_id = %s"
result = execute_query(sql, (user_id,))
return result if result else [] # 返回所有订单(列表)
if user_id and combined:
sql = sql_wo_condition + f"\n WHERE o.user_id = %s"
result = execute_query(sql, (user_id,))
return result if result else [] # 返回所有订单(列表
if order_id and combined:
sql = sql_wo_condition + f"\n WHERE o.order_id = %s"
result = execute_query(sql, (order_id,))
return result[0] if result and len(result) > 0 else None # 返回单个订单
# 添加条件(按优先级)
if order_id:
conditions.append("o.order_id = %s")
params.append(order_id)
elif user_id:
conditions.append("o.user_id = %s")
params.append(user_id)
# 新增博物馆ID条件可与其他条件组合
if museum_id:
conditions.append("ms.museum_id = %s")
params.append(museum_id)
# 构建完整SQL
if conditions:
where_clause = " WHERE " + " AND ".join(conditions)
sql = base_sql + where_clause
else:
sql = base_sql
# 执行查询
result = execute_query(sql, tuple(params))
# 处理返回结果
if not result:
return [] if user_id or museum_id else None
# 当有order_id时返回单个对象否则返回列表
return result[0] if order_id and not museum_id else result
def create_user_subscription(data: dict) -> int:
"""
@@ -875,27 +921,6 @@ def deactivate_previous_subscriptions(user_id: str, museum_subscription_id: str)
return execute_query(sql, (user_id, museum_subscription_id), autocommit=True)
def get_user_by_id(user_id: str) -> dict:
    """
    Look up one user by ID in the ``users_info`` table.

    Args:
        user_id: ID of the user to look up.

    Returns:
        The first matching row (every column is selected), or None when the
        query returns no rows.
    """
    query = "SELECT * FROM users_info WHERE user_id = %s"
    matches = execute_query(query, (user_id,))
    if matches:
        return matches[0]
    return None
def get_subscription_template_by_id(template_id: int) -> dict:
"""
根据模板ID获取订阅模板信息
@@ -1175,6 +1200,75 @@ def check_user_subscription(user_id: str, museum_id: int) -> dict:
return None
def is_museum_free_period(museum_id: int) -> bool:
    """
    Tell whether the museum is currently inside a free-admission window.

    Fetches at most one active ``free_interval`` template attached to an
    active subscription of the museum and evaluates it with
    ``is_subscription_valid``.

    Args:
        museum_id: ID of the museum to check.

    Returns:
        False when the query yields no row; otherwise the result of
        ``is_subscription_valid`` for the first row.
    """
    # The statement text is kept exactly as-is; only the first row is used.
    free_interval_sql = """
SELECT t.validity_type, t.valid_time_range, t.valid_week_days
FROM museum_subscriptions ms
JOIN subscription_templates t ON ms.template_id = t.id
WHERE ms.museum_id = %s
AND t.validity_type = 'free_interval'
AND t.is_active = 1
AND ms.is_active = 1
LIMIT 1
"""
    rows = execute_query(free_interval_sql, (museum_id,))
    return is_subscription_valid(rows[0]) if rows else False
def get_user_valid_subscription(user_id: str, museum_id: int) -> bool:
    """
    Tell whether the user holds a subscription to the museum that is valid now.

    The query keeps only rows where the user subscription, the museum
    subscription and the template are all active and where ``NOW()`` lies
    between ``start_date`` and ``end_date``; each remaining row is then
    checked with ``is_subscription_valid``.

    Args:
        user_id: ID of the user.
        museum_id: ID of the museum.

    Returns:
        True as soon as one row passes ``is_subscription_valid``, else False.
    """
    active_subscriptions_sql = """
SELECT
t.validity_type,
t.valid_time_range,
t.valid_week_days,
us.start_date,
us.end_date
FROM user_subscriptions us
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.sub_id
JOIN subscription_templates t ON ms.template_id = t.id
WHERE us.user_id = %s
AND ms.museum_id = %s
AND us.is_active = 1
AND ms.is_active = 1
AND t.is_active = 1
AND us.start_date <= NOW()
AND us.end_date >= NOW()
"""
    candidate_rows = execute_query(active_subscriptions_sql, (user_id, museum_id))
    # any() stops at the first valid row, exactly like an early-returning loop.
    return any(is_subscription_valid(row) for row in candidate_rows)
def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> datetime:
"""
@@ -1203,3 +1297,64 @@ def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> d
# 未知类型默认30天
logger.warning(f"未知有效期类型: {validity_type}, 使用默认30天")
return start_date + timedelta(days=30)
def is_subscription_valid(subscription: dict, now: Optional[datetime] = None) -> bool:
    """
    Check whether a subscription grants access at a given moment.

    Args:
        subscription: Mapping describing the subscription. Keys read here:
            - validity_type: 'free' is always valid; '1month', '1year' and
              'permanent' are additionally checked against the date bounds;
              any other value (e.g. 'free_interval') skips the date check.
            - valid_week_days: optional comma-separated ISO weekdays
              (1 = Monday ... 7 = Sunday), e.g. "1,3,5".
            - valid_time_range: optional daily window "HH:MM-HH:MM"; a window
              whose end is earlier than its start wraps past midnight.
            - start_date / end_date: datetime bounds, read only for the dated
              validity types. A missing or None bound is treated as open.
        now: Moment to evaluate. Defaults to the current time; it is
            converted to Asia/Shanghai before any comparison.

    Returns:
        True when every applicable restriction is satisfied, otherwise False.
    """
    # Time zone used for all comparisons (adjust to the deployment's zone).
    tz = ZoneInfo('Asia/Shanghai')
    now = datetime.now(tz) if now is None else now.astimezone(tz)

    validity_type = subscription['validity_type']

    # 1. Permanently free subscriptions have no further restrictions.
    if validity_type == 'free':
        return True

    # 2. Dated subscriptions must contain "now"; a NULL bound (e.g. a
    #    permanent subscription without end date) no longer raises.
    if validity_type in ('1month', '1year', 'permanent'):
        start_date = subscription.get('start_date')
        end_date = subscription.get('end_date')
        if start_date is not None and now < start_date.astimezone(tz):
            return False
        if end_date is not None and now > end_date.astimezone(tz):
            return False

    # 3. Weekday restriction. Blank items (trailing comma, spaces) are
    #    ignored; an unparsable list is skipped the same way an unparsable
    #    time range is, instead of raising out of the access check.
    week_days = subscription.get('valid_week_days')
    if week_days:
        try:
            allowed_days = {
                int(token) for token in str(week_days).split(',') if token.strip()
            }
        except ValueError:
            logging.getLogger(__name__).warning(
                "Ignoring malformed valid_week_days: %r", week_days
            )
            allowed_days = set()
        if allowed_days and now.isoweekday() not in allowed_days:
            return False

    # 4. Daily time-window restriction.
    time_range = subscription.get('valid_time_range')
    if time_range:
        try:
            start_str, end_str = str(time_range).split('-')
            start_time = datetime.strptime(start_str.strip(), '%H:%M').time()
            end_time = datetime.strptime(end_str.strip(), '%H:%M').time()
        except ValueError:
            # Invalid format: the time check is skipped (kept from the
            # original behaviour), but the problem is now logged.
            logging.getLogger(__name__).warning(
                "Ignoring malformed valid_time_range: %r", time_range
            )
        else:
            current_time = now.time()
            if end_time < start_time:
                # Window wraps past midnight, e.g. 22:00-06:00.
                in_window = current_time >= start_time or current_time <= end_time
            else:
                in_window = start_time <= current_time <= end_time
            if not in_window:
                return False

    return True

View File

@@ -468,12 +468,14 @@ async def get_museum_subscriptions_by_museum_id(
result = get_museum_subscriptions_by_museum(museum_id)
return CustomJSONResponse(result)
@payment_router.get("/get_order_list")
@payment_router.post("/get_order_list")
async def get_order_list(
request: Request,
current_user: dict = Depends(get_current_user)
):
result = get_order_by_id(user_id = current_user["user_id"],combined=True)
data = await request.json()
museum_id = data.get("museum_id")
result = get_order_by_id(user_id = current_user["user_id"],combined=True,museum_id=museum_id)
return CustomJSONResponse({
"code": 0,
"msg": "success",
@@ -490,6 +492,35 @@ async def get_order_detial(
"code": 0,
"msg": "success",
"data": result})
@payment_router.post("/get_user_museum_subscriptions")
async def get_user_museum_subscriptions(
    request: Request,
    current_user: dict = Depends(get_current_user)
):
    """
    Report whether the current user may access a museum, and why.

    Reads ``museum_id`` from the JSON request body and the user ID from
    ``current_user``. Access is granted when the museum record is flagged
    free, when the museum is in a free period, or when the user holds a
    valid subscription.

    Returns:
        A ``CustomJSONResponse`` with ``code`` 0 whose ``data`` carries the
        combined ``can_access`` flag plus the three individual flags.
    """
    payload = await request.json()
    museum_id = payload.get("museum_id")
    user_id = current_user["user_id"]

    # The three checks run in the same order as before.
    museum_info = get_museum_by_id(museum_id=museum_id)
    is_free = bool(museum_info and museum_info['free'])
    is_free_period = is_museum_free_period(museum_id)
    has_valid_subscription = get_user_valid_subscription(user_id, museum_id)

    access_flags = {
        'can_access': is_free or is_free_period or has_valid_subscription,
        'is_free': is_free,
        'is_free_period': is_free_period,
        'is_subscription_valid': has_valid_subscription
    }
    return CustomJSONResponse({
        "code": 0,
        "msg": "success",
        "data": access_flags})
# --- 支付工具函数 ---
async def generate_wx_prepay_params_v2(order_id: str, total_fee: int, openid: str, body: str):

View File

@@ -14,6 +14,9 @@ from typing import Optional, Dict, Any
import asyncio, httpx
from collections import deque
import websockets
import uuid
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi.responses import StreamingResponse, JSONResponse, Response
@@ -90,12 +93,23 @@ class StreamSessionManager:
try:
self.sessions[session_id]['last_active'] = time.time()
self.sessions[session_id]['buffer'].put(data)
self.sessions[session_id]['audio_chunk_size'] += len(data)
#logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
except queue.Full:
logging.warning(f"Audio buffer full for session {session_id}")
"""
elif data is None: # 结束信号
# 仅对非流式引擎触发完成事件
if not streaming_call:
logging.info(f"StreamSessionManager on_data sentence_complete_event set")
self.sessions[session_id]['sentence_complete_event'].set()
self.sessions[session_id]['current_processing'] = False
"""
# 创建完成事件
completion_event = threading.Event()
# 设置TTS流式传输
tts_instance.setup_tts(on_data)
tts_instance.setup_tts(on_data,completion_event)
# 创建会话
self.sessions[session_id] = {
'tts_model': tts_model,
'buffer': queue.Queue(maxsize=300), # 线程安全队列
@@ -103,6 +117,7 @@ class StreamSessionManager:
'active': True,
'last_active': time.time(),
'audio_chunk_count': 0,
'audio_chunk_size': 0,
'finished': threading.Event(), # 添加事件对象
'sample_rate': sample_rate,
'stream_format': stream_format,
@@ -110,7 +125,9 @@ class StreamSessionManager:
"text_buffer": "", # 新增文本缓冲区
"last_text_time": time.time(), # 最后文本到达时间
"streaming_call": streaming_call,
"tts_stream_started": False # 标记是否已启动流
"tts_stream_started": False, # 标记是否已启动流
"sentence_complete_event": completion_event, #threading.Event(),
"current_processing": False # 标记是否正在处理句子
}
# 启动任务处理线程
threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
@@ -129,7 +146,7 @@ class StreamSessionManager:
except queue.Full:
logging.warning(f"Session {session_id} task queue full")
def _process_tasks(self, session_id):
def _process_tasks1(self, session_id):
"""任务处理线程(每个会话独立)"""
session = self.sessions.get(session_id)
if not session or not session['active']:
@@ -204,7 +221,138 @@ class StreamSessionManager:
# 休眠避免CPU空转
time.sleep(0.05) # 50ms检查间隔
def _generate_audio(self, session_id, text):
def _process_tasks(self, session_id):  # updated 2025-07-18
    """
    Per-session worker loop: turn buffered text into TTS requests.

    Each iteration (every 50 ms while the session is active):
      1. atomically takes the pending text out of the session buffer,
      2. splits it into complete sentences, pushes the incomplete tail back
         to the buffer, and synthesizes the sentences in batches of at most
         300 characters, waiting for each batch to finish,
      3. flushes whatever is left in the buffer once no new text has arrived
         for ``self.sentence_timeout`` seconds,
      4. after ``self.gc_interval`` seconds without activity, flushes the
         buffer one last time, closes the session and exits.

    Args:
        session_id: Key of the session in ``self.sessions``.
    """
    session = self.sessions.get(session_id)
    if not session or not session['active']:
        return

    # Streaming engines receive text incrementally; others get whole calls.
    if session.get('streaming_call'):
        gen_tts_audio_func = self._stream_audio
    else:
        gen_tts_audio_func = self._generate_audio

    # Pauses applied after each batch before the next one may start; the
    # values are kept from the previous implementation.
    # NOTE(review): 5 s vs 1 s looks inconsistent - confirm both are intended.
    overflow_batch_pause = 5.0
    final_batch_pause = 1.0

    def synthesize(text, pause_after=0.0):
        """Send one piece of text to the TTS engine and wait for completion."""
        session['sentence_complete_event'].clear()
        session['current_processing'] = True
        try:
            gen_tts_audio_func(session_id, text)
            if not session['sentence_complete_event'].wait(timeout=120.0):
                # Log the text that was actually sent (the old code referenced
                # a variable that could be unbound or stale here).
                logging.warning(f"Timeout waiting for TTS completion: {text}")
            if pause_after:
                time.sleep(pause_after)
        finally:
            session['current_processing'] = False

    def take_buffered_text():
        """Remove and return the buffered text without holding the lock afterwards."""
        with self.lock:
            pending = session['text_buffer']
            session['text_buffer'] = ""
        return pending

    while session['active']:
        current_time = time.time()

        # 1. Take pending text. Reading and clearing happen under the same
        #    lock so text appended concurrently cannot be dropped.
        text_to_process = ""
        with self.lock:
            if session['text_buffer'] and not session['current_processing']:
                text_to_process = session['text_buffer']
                session['text_buffer'] = ""

        # 2. Split into sentences and synthesize them in batches.
        if text_to_process:
            complete_sentences, remaining_text = self._split_and_extract(text_to_process)

            # The incomplete tail goes back in front of any newer text.
            if remaining_text:
                with self.lock:
                    session['text_buffer'] = remaining_text + session['text_buffer']

            if complete_sentences:
                batch = []
                batch_length = 0
                for sentence in complete_sentences:
                    sentence_length = len(sentence)
                    if batch_length + sentence_length <= 300:
                        batch.append(sentence)
                        batch_length += sentence_length
                        continue
                    # Batch would exceed 300 characters: flush it first.
                    if batch:
                        synthesize("".join(batch), pause_after=overflow_batch_pause)
                        logging.info(f"StreamSessionManager _process_tasks 转换结束!!!")
                    batch = [sentence]
                    batch_length = sentence_length

                if batch:
                    synthesize("".join(batch), pause_after=final_batch_pause)
                    logging.info(f"StreamSessionManager _process_tasks 转换结束!!!")

        # 3. No new text for a while: flush the leftover fragment. The lock
        #    is released before the (up to 120 s) synthesis wait.
        if current_time - session['last_text_time'] > self.sentence_timeout:
            stale_text = take_buffered_text()
            if stale_text:
                synthesize(stale_text)

        # 4. Session idle for too long: flush, close and stop.
        if current_time - session['last_active'] > self.gc_interval:
            leftover_text = take_buffered_text()
            if leftover_text:
                synthesize(leftover_text)
            self.close_session(session_id)
            break

        # 5. Sleep to avoid spinning (50 ms polling interval).
        time.sleep(0.05)
def _generate_audio1(self, session_id, text):
"""实际生成音频(线程池执行)"""
session = self.sessions.get(session_id)
if not session: return
@@ -236,6 +384,26 @@ class StreamSessionManager:
except Exception as e:
session['buffer'].put(f"ERROR:{str(e)}")
def _generate_audio(self, session_id, text):  # updated 2025-07-18
    """
    Run one TTS call for a non-streaming engine.

    Calls ``text_tts_call`` on the session's TTS model, then refreshes the
    session's activity timestamp and chunk counter. On any exception an
    ``ERROR:<message>`` byte string is put on the session buffer and the
    sentence-complete event is set so a waiting worker is released.

    Args:
        session_id: Key of the session in ``self.sessions``.
        text: Text to synthesize.
    """
    state = self.sessions.get(session_id)
    if not state:
        return

    try:
        logging.info(f"StreamSessionManager _generate_audio--0 {text}")
        engine = state['tts_model']
        engine.text_tts_call(text)

        state['last_active'] = time.time()
        state['audio_chunk_count'] = state['audio_chunk_count'] + 1
        if not state['tts_chunk_data_valid']:
            state['tts_chunk_data_valid'] = True
    except Exception as err:
        error_payload = f"ERROR:{str(err)}".encode()
        state['buffer'].put(error_payload)
        # Make sure whoever waits for this sentence is woken up.
        state['sentence_complete_event'].set()
def _stream_audio(self, session_id, text):
"""流式传输文本到TTS服务"""
session = self.sessions.get(session_id)
@@ -247,9 +415,12 @@ class StreamSessionManager:
# 使用流式调用发送文本
session['tts_model'].streaming_call(text)
session['last_active'] = time.time()
# 流式引擎不需要等待完成事件
session['sentence_complete_event'].set()
except Exception as e:
logging.error(f"Error in streaming_call: {str(e)}")
session['buffer'].put(f"ERROR:{str(e)}".encode())
session['sentence_complete_event'].set()
async def get_tts_buffer_data(self, session_id):
"""异步流式返回 TTS 音频数据(适配同步 queue.Queue带 10 秒超时)"""
@@ -298,6 +469,8 @@ class StreamSessionManager:
# 标记会话为不活跃
self.sessions[session_id]['active'] = False
# 设置完成事件(确保任何等待的线程被唤醒)
self.sessions[session_id]['sentence_complete_event'].set()
# 延迟2秒后清理资源
threading.Timer(1, self._clean_session, args=[session_id]).start()
@@ -827,7 +1000,8 @@ from dashscope.audio.tts_v2 import (
class QwenTTS:
def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun"):
def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun",
special_characters: Optional[Dict[str, str]] = None):
import dashscope
import ssl
logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx
@@ -844,14 +1018,22 @@ class QwenTTS:
if '/' in self.model_name:
parts = self.model_name.split('/', 1)
# 返回分离后的两个字符串parts[0], parts[1]
if parts[0] == 'cosyvoice-v1':
if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2':
self.is_cosyvoice = True
self.voice = parts[1]
self.completion_event = None # 新增:用于通知任务完成
# 特殊字符及其拼音映射
self.special_characters = special_characters or {
"": "chuang3",
"": "yue4"
# 可以添加更多特殊字符的映射
}
class Callback(TTSResultCallback):
def __init__(self) -> None:
def __init__(self,data_callback=None,completion_event=None) -> None:
self.dque = deque()
self.data_callback = data_callback
self.completion_event = completion_event # 新增完成事件引用
def _run(self):
while True:
if not self.dque:
@@ -867,7 +1049,13 @@ class QwenTTS:
pass
def on_complete(self):
logging.info(f"---QwenTTS Callback on_complete--")
self.dque.append(None)
if self.data_callback:
self.data_callback(None) # 发送结束信号
# 通知任务完成
if self.completion_event:
self.completion_event.set()
def on_error(self, response: SpeechSynthesisResponse):
print("Qwen tts error", str(response))
@@ -877,15 +1065,22 @@ class QwenTTS:
pass
def on_event(self, result: TTSSpeechSynthesisResult):
if result.get_audio_frame() is not None:
self.dque.append(result.get_audio_frame())
data =result.get_audio_frame()
if data is not None:
if len(data) > 0:
if self.data_callback:
self.data_callback(data)
else:
self.dque.append(data)
#self.dque.append(result.get_audio_frame())
# --------------------------
class Callback_Cosy(CosyResultCallback):
def __init__(self, data_callback=None) -> None:
def __init__(self, data_callback=None,completion_event=None) -> None:
self.dque = deque()
self.data_callback = data_callback
self.completion_event = completion_event # 新增完成事件引用
def _run(self):
while True:
@@ -906,6 +1101,9 @@ class QwenTTS:
self.dque.append(None)
if self.data_callback:
self.data_callback(None) # 发送结束信号
# 通知任务完成
if self.completion_event:
self.completion_event.set()
def on_error(self, response: SpeechSynthesisResponse):
print("Qwen tts error", str(response))
@@ -938,23 +1136,28 @@ class QwenTTS:
# --------------------------
def tts(self, text):
def tts(self, text, on_data = None,completion_event=None):
# logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
# text = self.normalize_text(text)
print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
# text = self.normalize_text(text)
try:
# if self.model_name != 'cosyvoice-v1':
if self.is_cosyvoice is False:
self.callback = self.Callback()
self.callback = self.Callback(
data_callback=on_data,
completion_event=completion_event
)
TTSSpeechSynthesizer.call(model=self.model_name,
text=text,
callback=self.callback,
format="wav") # format="mp3")
format=self.format) # format="mp3")
else:
self.callback = self.Callback_Cosy()
format = self.get_audio_format(self.format, self.sample_rate)
self.synthesizer = CosySpeechSynthesizer(
model='cosyvoice-v1',
model='cosyvoice-v2',
# voice="longyuan", #"longfei",
voice=self.voice,
callback=self.callback,
@@ -974,26 +1177,68 @@ class QwenTTS:
except Exception as e:
raise RuntimeError(f"**ERROR**: {e}")
def setup_tts(self, on_data):
"""设置 TTS 回调,返回配置好的 synthesizer"""
if not self.is_cosyvoice:
raise NotImplementedError("Only CosyVoice supported")
def setup_tts(self, on_data,completion_event=None):
# 创建 CosyVoice 回调
self.callback = self.Callback_Cosy(on_data)
"""设置 TTS 回调,返回配置好的 synthesizer"""
#if not self.is_cosyvoice:
# raise NotImplementedError("Only CosyVoice supported")
if self.is_cosyvoice:
# 创建 CosyVoice 回调
self.callback = self.Callback_Cosy(
data_callback=on_data,
completion_event=completion_event)
else:
self.callback = self.Callback(
data_callback=on_data,
completion_event=completion_event)
format_val = self.get_audio_format(self.format, self.sample_rate)
logging.info(f"setup_tts {self.voice} {format_val}")
self.synthesizer = CosySpeechSynthesizer(
model='cosyvoice-v1',
voice=self.voice, # voice="longyuan", #"longfei",
callback=self.callback,
format=format_val
)
logging.info(f"Qwen setup_tts {self.voice} {format_val}")
if self.is_cosyvoice:
self.synthesizer = CosySpeechSynthesizer(
model='cosyvoice-v1',
voice=self.voice, # voice="longyuan", #"longfei",
callback=self.callback,
format=format_val
)
return self.synthesizer
def apply_phoneme_tags(self, text: str) -> str:
    """
    Wrap every configured special character in an SSML ``<phoneme>`` tag.

    Text that is already a complete ``<speak>...</speak>`` document is
    returned untouched. Otherwise each key of ``self.special_characters``
    found in the text is replaced by
    ``<phoneme alphabet="py" ph="PINYIN">CHAR</phoneme>``, and the result is
    wrapped in ``<speak>`` unless it already contains that tag.

    Args:
        text: Plain text or SSML.

    Returns:
        The SSML string.
    """
    stripped = text.strip()
    if stripped.startswith("<speak>") and stripped.endswith("</speak>"):
        return text

    # Empty keys would match at every position, so they are ignored.
    # Longer keys come first so they win over keys that are their prefix.
    keys = sorted(
        (char for char in self.special_characters if char),
        key=len,
        reverse=True,
    )
    if keys:
        # Same guard as before (not right after '<', not right before '>'),
        # but expressed with look-arounds so the neighbouring characters are
        # not consumed: adjacent special characters are now all tagged.
        # One combined pattern means inserted markup is never rescanned.
        pattern = re.compile(
            r'(?<!<)(?:' + '|'.join(re.escape(char) for char in keys) + r')(?!>)'
        )

        def tag(match):
            char = match.group(0)
            pinyin = self.special_characters[char]
            return f'<phoneme alphabet="py" ph="{pinyin}">{char}</phoneme>'

        # A function replacement avoids backslash/group-reference
        # interpretation of the inserted text.
        text = pattern.sub(tag, text)

    # Text that already carries a <speak> tag is not wrapped again.
    if "<speak>" in text:
        return text
    return f"<speak>{text}</speak>"
def text_tts_call(self, text):
if self.synthesizer:
if self.special_characters and self.is_cosyvoice is False:
text = self.apply_phoneme_tags(text)
#logging.info(f"Applied SSML phoneme tags to text: {text}")
if self.synthesizer and self.is_cosyvoice:
self.synthesizer.call(text)
if self.is_cosyvoice is False:
logging.info(f"Qwen text_tts_call {text}")
TTSSpeechSynthesizer.call(model=self.model_name,
text=text,
callback=self.callback,
format=self.format)
def streaming_call(self, text):
if self.synthesizer:
@@ -1027,22 +1272,178 @@ class QwenTTS:
return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
import threading
import uuid
import time
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
class DoubaoTTS:
    """
    TTS client for the Doubao WebSocket endpoint
    (``wss://openspeech.bytedance.com/api/v1/tts/ws_binary``).

    Usage: call ``setup_tts`` to register a data callback and an optional
    completion event, then ``text_tts_call`` for each text. Audio chunks are
    delivered through the data callback, followed by ``None`` at the end of
    the stream, or by an ``ERROR:<message>`` byte string on failure.
    """

    def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"):
        """
        Args:
            key: Accepted for signature compatibility with the other TTS
                classes; not used by this implementation.
            format: Audio encoding requested from the service.
            sample_rate: Stored on the instance; not sent in the request.
            model_name: Stored on the instance; not sent in the request.
        """
        logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}")

        # SECURITY NOTE(review): credentials are hard-coded and ``key`` is
        # ignored. They should come from configuration / the ``key``
        # argument; left unchanged because the expected key format is not
        # visible here.
        self.appid = "7282190702"
        self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9"
        self.cluster = "volcano_tts"
        self.voice_type = "zh_female_qingxinnvsheng_mars_bigtts"

        self.format = format
        self.sample_rate = sample_rate
        self.model_name = model_name
        self.callback = None
        self.ws = None
        self.loop = None
        self.task = None
        self.event = threading.Event()
        self.data_queue = deque()
        self.host = "openspeech.bytedance.com"
        self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary"
        # Fixed 4-byte protocol header that precedes every client request.
        self.default_header = bytearray(b'\x11\x10\x11\x00')
        self.total_data_size = 0
        self.completion_event = None  # set by setup_tts; signals end of a call
        self._stream_failed = False   # True once the server sent an error frame

    class Callback:
        """Routes stream events to the data callback and completion event."""

        def __init__(self, data_callback=None, completion_event=None):
            self.data_callback = data_callback
            self.data_queue = deque()
            self.completion_event = completion_event

        def on_data(self, data):
            """Forward one audio chunk, or queue it when no callback is set."""
            # Completion is deliberately NOT signalled here: a chunk is not
            # the end of the stream.
            if self.data_callback:
                self.data_callback(data)
            else:
                self.data_queue.append(data)

        def on_complete(self):
            """Send the ``None`` end-of-stream marker and signal completion."""
            if self.data_callback:
                self.data_callback(None)
            if self.completion_event:
                self.completion_event.set()

        def on_error(self, error):
            """Report a failure as an ``ERROR:<message>`` byte string."""
            if self.data_callback:
                self.data_callback(f"ERROR:{error}".encode())

    def setup_tts(self, on_data, completion_event=None):
        """
        Register the data callback and completion event.

        Args:
            on_data: Callable receiving audio bytes, ``None`` at end of
                stream, or an ``ERROR:...`` byte string.
            completion_event: Optional ``threading.Event`` that is set when a
                ``text_tts_call`` finishes, whether it succeeded or failed.

        Returns:
            This instance (synthesis is started later via ``text_tts_call``).
        """
        # Remember the event on the instance as well, so the ``finally``
        # handlers below can release waiters on every exit path.
        self.completion_event = completion_event
        self.callback = self.Callback(
            data_callback=on_data,
            completion_event=completion_event
        )
        return self

    def text_tts_call(self, text):
        """Synthesize ``text`` synchronously on a private event loop."""
        self.total_data_size = 0
        self._stream_failed = False
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)
        try:
            self.task = self.loop.create_task(self._async_tts(text))
            self.loop.run_until_complete(self.task)
        except Exception as e:
            logging.error(f"DoubaoTTS--0 call error: {e}")
            self.callback.on_error(str(e))
        finally:
            # One loop is created per call; close it so loops do not pile up.
            self.loop.close()
            self.loop = None
            asyncio.set_event_loop(None)
            if self.completion_event:
                self.completion_event.set()

    async def _async_tts(self, text):
        """Send one synthesis request and pump the response frames."""
        header = {"Authorization": f"Bearer; {self.token}"}
        request_json = {
            "app": {
                "appid": self.appid,
                "token": "access_token",  # fixed placeholder value
                "cluster": self.cluster
            },
            "user": {
                "uid": str(uuid.uuid4())  # random user id per request
            },
            "audio": {
                "voice_type": self.voice_type,
                "encoding": self.format,
                "speed_ratio": 1.0,
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            },
            "request": {
                "reqid": str(uuid.uuid4()),
                "text": text,
                "text_type": "plain",
                "operation": "submit"  # "submit" mode streams the audio back
            }
        }

        # Wire format: header, 4-byte big-endian payload length, gzip JSON.
        payload_bytes = gzip.compress(str.encode(json.dumps(request_json)))
        full_client_request = bytearray(self.default_header)
        full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
        full_client_request.extend(payload_bytes)

        try:
            async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws:
                self.ws = ws
                await ws.send(full_client_request)

                while True:
                    res = await ws.recv()
                    if self._parse_response(res):
                        # An error frame has already been reported through
                        # on_error; only a clean end gets on_complete.
                        if not self._stream_failed:
                            self.callback.on_complete()
                        break
        except Exception as e:
            logging.error(f"DoubaoTTS--1 WebSocket error: {e}")
            self.callback.on_error(str(e))
        finally:
            if self.completion_event:
                self.completion_event.set()

    def _parse_response(self, res):
        """
        Decode one binary server frame.

        Audio payloads are forwarded to the callback and counted in
        ``total_data_size``; error frames are logged and reported through
        ``on_error``.

        Returns:
            True when no more frames should be read (final audio frame, i.e.
            negative sequence number, or an error frame), otherwise False.
        """
        # Low nibble of byte 0 is the header size in 4-byte words; high
        # nibble of byte 1 is the message type.
        header_size = res[0] & 0x0f
        message_type = res[1] >> 4
        payload = res[header_size * 4:]

        if message_type == 0xb:  # audio-only server response
            message_flags = res[1] & 0x0f
            if message_flags == 0:
                # ACK without audio: nothing to deliver.
                return False

            sequence_number = int.from_bytes(payload[:4], "big", signed=True)
            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
            audio_data = payload[8:8 + payload_size]

            if audio_data:
                self.total_data_size = self.total_data_size + len(audio_data)
                self.callback.on_data(audio_data)

            # A negative sequence number marks the last frame.
            return sequence_number < 0

        if message_type == 0xf:  # error response
            msg_size = int.from_bytes(payload[4:8], "big", signed=False)
            error_msg = payload[8:8 + msg_size]
            try:
                # The message may be gzip-compressed.
                error_msg = gzip.decompress(error_msg).decode()
            except Exception:
                error_msg = error_msg.decode(errors='ignore')

            logging.error(f"DoubaoTTS error: {error_msg}")
            self._stream_failed = True
            self.callback.on_error(error_msg)
            # Stop reading: previously the receive loop kept waiting on the
            # socket after the server had already reported a failure.
            return True

        return False
class UnifiedTTSEngine:
@@ -1144,13 +1545,29 @@ class UnifiedTTSEngine:
return
try:
# 创建完成事件
completion_event = threading.Event()
# 创建TTS实例
tts = QwenTTS(
key=task['key'],
format=task['format'],
sample_rate=task['sample_rate'],
model_name=task['model_name']
)
# 根据model_name选择TTS引擎
# 前端传入 cosyvoice-v1/longhua@Tongyi-Qianwen
model_name_wo_brand = task['model_name'].split('@')[0]
model_name_version = model_name_wo_brand.split('/')[0]
if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']:
# 豆包TTS
tts = DoubaoTTS(
key=task['key'],
format=task['format'],
sample_rate=task['sample_rate'],
model_name=task['model_name']
)
else:
# 通义千问TTS
tts = QwenTTS(
key=task['key'],
format=task['format'],
sample_rate=task['sample_rate'],
model_name=task['model_name']
)
# 定义同步数据处理函数
def data_handler(data):
@@ -1165,20 +1582,28 @@ class UnifiedTTSEngine:
else: # 音频数据
task['data_queue'].append(data)
# 设置并执行TTS
synthesizer = tts.setup_tts(data_handler)
synthesizer.call(task['text'])
# 设置并执行TTS
synthesizer = tts.setup_tts(data_handler,completion_event)
#synthesizer.call(task['text'])
tts.text_tts_call(task['text'])
# 等待完成或超时
if not task['event'].wait(timeout=300): # 5分钟超时
# 等待完成或超时
if not completion_event.wait(timeout=300): # 5分钟超时
task['error'] = "TTS generation timeout"
task['completed'] = True
logging.info(f"--tts task event set error = {task['error']}")
except Exception as e:
logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}")
task['error'] = f"ERROR:{str(e)}"
task['completed'] = True
finally:
# 确保清理TTS资源
logging.info("UnifiedTTSEngine _run_tts_sync finally")
if hasattr(tts, 'loop') and tts.loop:
tts.loop.close()
def _merge_audio_data(self, audio_stream_id):
"""将任务的所有音频数据合并到ByteIO缓冲区"""
@@ -1219,7 +1644,7 @@ class UnifiedTTSEngine:
# 如果是延迟任务且未启动,现在启动 status 为 pending
if task['delay_gen_audio'] and task['status'] == 'pending':
self._start_tts_task(audio_stream_id)
total_audio_data_size = 0
# 等待任务启动
while task['status'] == 'pending':
await asyncio.sleep(0.1)
@@ -1228,7 +1653,8 @@ class UnifiedTTSEngine:
while not task['completed'] or task['data_queue']:
while task['data_queue']:
data = task['data_queue'].popleft()
# logging.info(f"yield data {len(data)}")
total_audio_data_size += len(data)
#logging.info(f"yield audio data {len(data)} {total_audio_data_size}")
yield data
# 短暂等待新数据
@@ -1318,6 +1744,7 @@ async def proxy_aichat_audio_stream(client_id: str, audio_url: str):
# 代理函数 - 文本流
# 在微信小程序中原来APK使用的SSE机制不能正常工作需要使用WebSocket
async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict):
"""代理大模型文本流请求 - 兼容现有Flask实现"""
try:
@@ -1328,13 +1755,19 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload
"Content-Type": "application/json",
'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm'
}
tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
#if 'longyuan' in tts_model_name:
# tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen"
# 创建TTS实例
tts_model = QwenTTS(
key=ALI_KEY,
format=payload.get('tts_stream_format', 'mp3'),
sample_rate=payload.get('tts_sample_rate', 48000),
model_name=payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
model_name=tts_model_name
)
streaming_call = False
if tts_model.is_cosyvoice:
streaming_call = True
# 创建流会话
tts_stream_session_id = stream_manager.create_session(
@@ -1342,7 +1775,7 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload
sample_rate=payload.get('tts_sample_rate', 48000),
stream_format=payload.get('tts_stream_format', 'mp3'),
session_id=None,
streaming_call=True
streaming_call= streaming_call
)
# logging.info(f"---tts_stream_session_id = {tts_stream_session_id}")
tts_stream_session_id_sent = False
@@ -1399,7 +1832,6 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload
data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}"
data_str = json.dumps(data_obj)
tts_stream_session_id_sent = True
# 直接转发原始数据
await manager.send_text(client_id, json.dumps({
"type": "text",
@@ -1717,12 +2149,14 @@ async def websocket_tts_endpoint(
audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}"
# await proxy_aichat_audio_stream(connection_id, audio_url)
sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate')
audio_data_size =0
await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate})
async for data in stream_manager.get_tts_buffer_data(audio_stream_id):
audio_data_size += len(data)
if not await manager.send_bytes(connection_id, data):
break
completed_successfully = True
logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}")
elif service_type == "AiChatText":
# 文本代理服务
# 等待客户端发送初始请求数据 进行大模型对话代理时需要前端连接后发送payload