准备对AI流式音频发回给前端的机制做较大的修改,先提交1个版本
This commit is contained in:
@@ -5,6 +5,7 @@ from contextlib import contextmanager
|
||||
from config import DATABASE_CONFIG
|
||||
from datetime import datetime,timedelta
|
||||
import logging
|
||||
from zoneinfo import ZoneInfo # Python 3.9+ 内置
|
||||
from typing import Union, List, Dict, Optional
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
|
||||
from dateutil.relativedelta import relativedelta
|
||||
@@ -415,6 +416,21 @@ def get_users(status: int = None, email: str = None, phone: str = None,openid: s
|
||||
|
||||
# 按用户ID获取用户
|
||||
def get_user_by_id(user_id: str):
|
||||
"""
|
||||
根据用户ID获取用户信息
|
||||
|
||||
功能说明:
|
||||
- 通过用户ID查询用户基本信息
|
||||
|
||||
参数说明:
|
||||
- user_id: 用户ID
|
||||
|
||||
返回:
|
||||
- 用户信息的字典,如果不存在则返回None
|
||||
|
||||
重要逻辑:
|
||||
- 直接查询用户表的所有字段
|
||||
"""
|
||||
sql = "SELECT * FROM rag_flow.users_info WHERE user_id = %s"
|
||||
result = execute_query(sql, (user_id,))
|
||||
return result[0] if result else None
|
||||
@@ -720,7 +736,7 @@ def update_order(order_id: str, update_data: dict) -> int:
|
||||
|
||||
from typing import Union, List, Dict, Optional
|
||||
|
||||
def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -> Union[Dict, List[Dict], None]:
|
||||
def get_order_by_id(order_id: str = None, user_id: str = None,combined:bool = False,museum_id:int = None) -> Union[Dict, List[Dict], None]:
|
||||
"""
|
||||
根据订单ID或用户ID查询订单信息
|
||||
|
||||
@@ -742,9 +758,25 @@ def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -
|
||||
- 使用参数化查询防止 SQL 注入
|
||||
- 当同时传入 order_id 和 user_id 时,优先使用 order_id
|
||||
"""
|
||||
if not order_id and not user_id:
|
||||
return None # 两个参数都未传入,直接返回 None
|
||||
sql_wo_condition= """
|
||||
# 无任何参数时返回 None
|
||||
if not any([order_id, user_id, museum_id]):
|
||||
return None
|
||||
# ========== 简单查询模式 ==========
|
||||
if not combined:
|
||||
# 优先使用 order_id 查询
|
||||
if order_id:
|
||||
sql = "SELECT * FROM subscription_orders WHERE order_id = %s"
|
||||
result = execute_query(sql, (order_id,))
|
||||
return result[0] if result and len(result) > 0 else None
|
||||
|
||||
# 使用 user_id 查询
|
||||
if user_id:
|
||||
sql = "SELECT * FROM subscription_orders WHERE user_id = %s"
|
||||
result = execute_query(sql, (user_id,))
|
||||
return result if result else []
|
||||
|
||||
# ========== 复杂查询模式 ==========
|
||||
base_sql= """
|
||||
SELECT
|
||||
o.order_id,
|
||||
o.user_id,
|
||||
@@ -774,25 +806,39 @@ def get_order_by_id(order_id: str = None, user_id: str = None,combined = None) -
|
||||
LEFT JOIN rag_flow.user_subscriptions us ON o.order_id = us.order_id
|
||||
LEFT JOIN rag_flow.mesum_overview mo ON ms.museum_id = mo.id
|
||||
"""
|
||||
# 优先使用 order_id 查询
|
||||
if order_id and not combined:
|
||||
sql = "SELECT * FROM subscription_orders WHERE order_id = %s"
|
||||
result = execute_query(sql, (order_id,))
|
||||
return result[0] if result and len(result) > 0 else None # 返回单个订单
|
||||
# 构建查询条件和参数
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
# 如果 order_id 不存在,使用 user_id 查询
|
||||
if user_id and not combined:
|
||||
sql = "SELECT * FROM subscription_orders WHERE user_id = %s"
|
||||
result = execute_query(sql, (user_id,))
|
||||
return result if result else [] # 返回所有订单(列表)
|
||||
if user_id and combined:
|
||||
sql = sql_wo_condition + f"\n WHERE o.user_id = %s"
|
||||
result = execute_query(sql, (user_id,))
|
||||
return result if result else [] # 返回所有订单(列表)
|
||||
if order_id and combined:
|
||||
sql = sql_wo_condition + f"\n WHERE o.order_id = %s"
|
||||
result = execute_query(sql, (order_id,))
|
||||
return result[0] if result and len(result) > 0 else None # 返回单个订单
|
||||
# 添加条件(按优先级)
|
||||
if order_id:
|
||||
conditions.append("o.order_id = %s")
|
||||
params.append(order_id)
|
||||
elif user_id:
|
||||
conditions.append("o.user_id = %s")
|
||||
params.append(user_id)
|
||||
|
||||
# 新增博物馆ID条件(可与其他条件组合)
|
||||
if museum_id:
|
||||
conditions.append("ms.museum_id = %s")
|
||||
params.append(museum_id)
|
||||
|
||||
# 构建完整SQL
|
||||
if conditions:
|
||||
where_clause = " WHERE " + " AND ".join(conditions)
|
||||
sql = base_sql + where_clause
|
||||
else:
|
||||
sql = base_sql
|
||||
|
||||
# 执行查询
|
||||
result = execute_query(sql, tuple(params))
|
||||
|
||||
# 处理返回结果
|
||||
if not result:
|
||||
return [] if user_id or museum_id else None
|
||||
|
||||
# 当有order_id时返回单个对象,否则返回列表
|
||||
return result[0] if order_id and not museum_id else result
|
||||
|
||||
def create_user_subscription(data: dict) -> int:
|
||||
"""
|
||||
@@ -875,27 +921,6 @@ def deactivate_previous_subscriptions(user_id: str, museum_subscription_id: str)
|
||||
return execute_query(sql, (user_id, museum_subscription_id), autocommit=True)
|
||||
|
||||
|
||||
def get_user_by_id(user_id: str) -> dict:
|
||||
"""
|
||||
根据用户ID获取用户信息
|
||||
|
||||
功能说明:
|
||||
- 通过用户ID查询用户基本信息
|
||||
|
||||
参数说明:
|
||||
- user_id: 用户ID
|
||||
|
||||
返回:
|
||||
- 用户信息的字典,如果不存在则返回None
|
||||
|
||||
重要逻辑:
|
||||
- 直接查询用户表的所有字段
|
||||
"""
|
||||
sql = "SELECT * FROM users_info WHERE user_id = %s"
|
||||
result = execute_query(sql, (user_id,))
|
||||
return result[0] if result else None
|
||||
|
||||
|
||||
def get_subscription_template_by_id(template_id: int) -> dict:
|
||||
"""
|
||||
根据模板ID获取订阅模板信息
|
||||
@@ -1175,6 +1200,75 @@ def check_user_subscription(user_id: str, museum_id: int) -> dict:
|
||||
return None
|
||||
|
||||
|
||||
def is_museum_free_period(museum_id: int) -> bool:
|
||||
"""
|
||||
检查博物馆当前是否处于免费时段
|
||||
|
||||
参数:
|
||||
- museum_id: 博物馆ID
|
||||
|
||||
返回:
|
||||
- True: 当前是免费时段
|
||||
- False: 当前不是免费时段
|
||||
"""
|
||||
# 查询博物馆的免费时段配置
|
||||
sql = """
|
||||
SELECT t.validity_type, t.valid_time_range, t.valid_week_days
|
||||
FROM museum_subscriptions ms
|
||||
JOIN subscription_templates t ON ms.template_id = t.id
|
||||
WHERE ms.museum_id = %s
|
||||
AND t.validity_type = 'free_interval'
|
||||
AND t.is_active = 1
|
||||
AND ms.is_active = 1
|
||||
LIMIT 1
|
||||
"""
|
||||
result = execute_query(sql, (museum_id,))
|
||||
if not result:
|
||||
return False
|
||||
|
||||
subscription = result[0]
|
||||
return is_subscription_valid(subscription)
|
||||
|
||||
|
||||
def get_user_valid_subscription(user_id: str, museum_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否有有效的博物馆订阅
|
||||
|
||||
参数:
|
||||
- user_id: 用户ID
|
||||
- museum_id: 博物馆ID
|
||||
|
||||
返回:
|
||||
- True: 用户有有效订阅
|
||||
- False: 用户无有效订阅
|
||||
"""
|
||||
# 查询用户的有效订阅
|
||||
sql = """
|
||||
SELECT
|
||||
t.validity_type,
|
||||
t.valid_time_range,
|
||||
t.valid_week_days,
|
||||
us.start_date,
|
||||
us.end_date
|
||||
FROM user_subscriptions us
|
||||
JOIN museum_subscriptions ms ON us.museum_subscription_id = ms.sub_id
|
||||
JOIN subscription_templates t ON ms.template_id = t.id
|
||||
WHERE us.user_id = %s
|
||||
AND ms.museum_id = %s
|
||||
AND us.is_active = 1
|
||||
AND ms.is_active = 1
|
||||
AND t.is_active = 1
|
||||
AND us.start_date <= NOW()
|
||||
AND us.end_date >= NOW()
|
||||
"""
|
||||
subscriptions = execute_query(sql, (user_id, museum_id))
|
||||
|
||||
# 检查每个订阅是否在当前时间有效
|
||||
for sub in subscriptions:
|
||||
if is_subscription_valid(sub):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> datetime:
|
||||
"""
|
||||
@@ -1203,3 +1297,64 @@ def calculate_subscription_expiry(start_date: datetime, validity_type: str) -> d
|
||||
# 未知类型默认30天
|
||||
logger.warning(f"未知有效期类型: {validity_type}, 使用默认30天")
|
||||
return start_date + timedelta(days=30)
|
||||
|
||||
|
||||
def is_subscription_valid(subscription: dict) -> bool:
|
||||
"""
|
||||
检查订阅在当前时间是否有效
|
||||
|
||||
参数:
|
||||
subscription: 包含订阅信息的字典,包含以下字段:
|
||||
- validity_type: 有效期类型
|
||||
- valid_time_range: 有效时间段 (格式: "08:00-20:00")
|
||||
- valid_week_days: 有效星期 (格式: "1,3,5")
|
||||
- start_date: 订阅开始日期 (datetime 对象)
|
||||
- end_date: 订阅结束日期 (datetime 对象)
|
||||
"""
|
||||
# 设置时区(根据服务器实际时区调整)
|
||||
tz = ZoneInfo('Asia/Shanghai')
|
||||
now = datetime.now(tz)
|
||||
|
||||
# 1. 检查永久免费订阅
|
||||
if subscription['validity_type'] == 'free':
|
||||
return True
|
||||
|
||||
# 2. 检查时间间隔类型订阅
|
||||
if subscription['validity_type'] == 'free_interval':
|
||||
# 时间间隔类型不需要检查有效期范围
|
||||
pass
|
||||
else:
|
||||
# 3. 检查有效期是否在范围内
|
||||
if subscription['validity_type'] in ['1month', '1year', 'permanent']:
|
||||
start_date = subscription['start_date'].astimezone(tz)
|
||||
end_date = subscription['end_date'].astimezone(tz)
|
||||
|
||||
if not (start_date <= now <= end_date):
|
||||
return False
|
||||
|
||||
# 4. 检查星期限制
|
||||
if subscription.get('valid_week_days'):
|
||||
week_day = now.isoweekday() # 1=周一, 7=周日
|
||||
valid_days = [int(d) for d in str(subscription['valid_week_days']).split(',')]
|
||||
if week_day not in valid_days:
|
||||
return False
|
||||
|
||||
# 5. 检查时间范围限制
|
||||
if subscription.get('valid_time_range'):
|
||||
try:
|
||||
start_str, end_str = subscription['valid_time_range'].split('-')
|
||||
start_time = datetime.strptime(start_str, '%H:%M').time()
|
||||
end_time = datetime.strptime(end_str, '%H:%M').time()
|
||||
current_time = now.time()
|
||||
# 处理跨夜时段
|
||||
if end_time < start_time:
|
||||
if not (current_time >= start_time or current_time <= end_time):
|
||||
return False
|
||||
else:
|
||||
if not (start_time <= current_time <= end_time):
|
||||
return False
|
||||
except (ValueError, AttributeError):
|
||||
# 时间格式无效,跳过时间检查
|
||||
pass
|
||||
|
||||
return True
|
||||
@@ -468,12 +468,14 @@ async def get_museum_subscriptions_by_museum_id(
|
||||
result = get_museum_subscriptions_by_museum(museum_id)
|
||||
return CustomJSONResponse(result)
|
||||
|
||||
@payment_router.get("/get_order_list")
|
||||
@payment_router.post("/get_order_list")
|
||||
async def get_order_list(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
result = get_order_by_id(user_id = current_user["user_id"],combined=True)
|
||||
data = await request.json()
|
||||
museum_id = data.get("museum_id")
|
||||
result = get_order_by_id(user_id = current_user["user_id"],combined=True,museum_id=museum_id)
|
||||
return CustomJSONResponse({
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
@@ -490,6 +492,35 @@ async def get_order_detial(
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": result})
|
||||
|
||||
@payment_router.post("/get_user_museum_subscriptions")
|
||||
async def get_user_museum_subscriptions(
|
||||
request: Request,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
data = await request.json()
|
||||
museum_id = data.get("museum_id")
|
||||
user_id = current_user["user_id"] # 用户id
|
||||
is_free = False
|
||||
museum_info = get_museum_by_id(museum_id=museum_id)
|
||||
if museum_info and museum_info['free']:
|
||||
is_free = True
|
||||
is_free_period = is_museum_free_period(museum_id)
|
||||
is_subscription_valid = get_user_valid_subscription(user_id, museum_id)
|
||||
can_access = False
|
||||
can_access = is_free or is_free_period or is_subscription_valid
|
||||
result = {
|
||||
'can_access': can_access,
|
||||
'is_free': is_free,
|
||||
'is_free_period': is_free_period,
|
||||
'is_subscription_valid': is_subscription_valid
|
||||
}
|
||||
return CustomJSONResponse({
|
||||
"code": 0,
|
||||
"msg": "success",
|
||||
"data": result})
|
||||
|
||||
|
||||
# --- 支付工具函数 ---
|
||||
|
||||
async def generate_wx_prepay_params_v2(order_id: str, total_fee: int, openid: str, body: str):
|
||||
|
||||
@@ -14,6 +14,9 @@ from typing import Optional, Dict, Any
|
||||
import asyncio, httpx
|
||||
from collections import deque
|
||||
|
||||
import websockets
|
||||
import uuid
|
||||
|
||||
from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
|
||||
from fastapi import FastAPI, UploadFile, File, Form, Header
|
||||
from fastapi.responses import StreamingResponse, JSONResponse, Response
|
||||
@@ -90,12 +93,23 @@ class StreamSessionManager:
|
||||
try:
|
||||
self.sessions[session_id]['last_active'] = time.time()
|
||||
self.sessions[session_id]['buffer'].put(data)
|
||||
self.sessions[session_id]['audio_chunk_size'] += len(data)
|
||||
#logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
|
||||
except queue.Full:
|
||||
logging.warning(f"Audio buffer full for session {session_id}")
|
||||
|
||||
"""
|
||||
elif data is None: # 结束信号
|
||||
# 仅对非流式引擎触发完成事件
|
||||
if not streaming_call:
|
||||
logging.info(f"StreamSessionManager on_data sentence_complete_event set")
|
||||
self.sessions[session_id]['sentence_complete_event'].set()
|
||||
self.sessions[session_id]['current_processing'] = False
|
||||
"""
|
||||
# 创建完成事件
|
||||
completion_event = threading.Event()
|
||||
# 设置TTS流式传输
|
||||
tts_instance.setup_tts(on_data)
|
||||
|
||||
tts_instance.setup_tts(on_data,completion_event)
|
||||
# 创建会话
|
||||
self.sessions[session_id] = {
|
||||
'tts_model': tts_model,
|
||||
'buffer': queue.Queue(maxsize=300), # 线程安全队列
|
||||
@@ -103,6 +117,7 @@ class StreamSessionManager:
|
||||
'active': True,
|
||||
'last_active': time.time(),
|
||||
'audio_chunk_count': 0,
|
||||
'audio_chunk_size': 0,
|
||||
'finished': threading.Event(), # 添加事件对象
|
||||
'sample_rate': sample_rate,
|
||||
'stream_format': stream_format,
|
||||
@@ -110,7 +125,9 @@ class StreamSessionManager:
|
||||
"text_buffer": "", # 新增文本缓冲区
|
||||
"last_text_time": time.time(), # 最后文本到达时间
|
||||
"streaming_call": streaming_call,
|
||||
"tts_stream_started": False # 标记是否已启动流
|
||||
"tts_stream_started": False, # 标记是否已启动流
|
||||
"sentence_complete_event": completion_event, #threading.Event(),
|
||||
"current_processing": False # 标记是否正在处理句子
|
||||
}
|
||||
# 启动任务处理线程
|
||||
threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
|
||||
@@ -129,7 +146,7 @@ class StreamSessionManager:
|
||||
except queue.Full:
|
||||
logging.warning(f"Session {session_id} task queue full")
|
||||
|
||||
def _process_tasks(self, session_id):
|
||||
def _process_tasks1(self, session_id):
|
||||
"""任务处理线程(每个会话独立)"""
|
||||
session = self.sessions.get(session_id)
|
||||
if not session or not session['active']:
|
||||
@@ -204,7 +221,138 @@ class StreamSessionManager:
|
||||
# 休眠避免CPU空转
|
||||
time.sleep(0.05) # 50ms检查间隔
|
||||
|
||||
def _generate_audio(self, session_id, text):
|
||||
def _process_tasks(self, session_id): # 20250718 新更新
|
||||
"""任务处理线程(每个会话独立)- 保留原有处理逻辑"""
|
||||
session = self.sessions.get(session_id)
|
||||
if not session or not session['active']:
|
||||
return
|
||||
|
||||
# 根据引擎类型选择处理函数
|
||||
if session.get('streaming_call'):
|
||||
gen_tts_audio_func = self._stream_audio
|
||||
else:
|
||||
gen_tts_audio_func = self._generate_audio
|
||||
|
||||
while session['active']:
|
||||
current_time = time.time()
|
||||
text_to_process = ""
|
||||
|
||||
# 1. 获取待处理文本
|
||||
with self.lock:
|
||||
if session['text_buffer']:
|
||||
text_to_process = session['text_buffer']
|
||||
|
||||
|
||||
# 2. 处理文本
|
||||
if text_to_process and not session['current_processing'] :
|
||||
session['text_buffer'] = ""
|
||||
# 分割完整句子
|
||||
complete_sentences, remaining_text = self._split_and_extract(text_to_process)
|
||||
|
||||
# 保存剩余文本
|
||||
if remaining_text:
|
||||
with self.lock:
|
||||
session['text_buffer'] = remaining_text + session['text_buffer']
|
||||
|
||||
# 合并并处理完整句子
|
||||
if complete_sentences:
|
||||
# 智能合并句子(最长300字符)
|
||||
buffer = []
|
||||
current_length = 0
|
||||
|
||||
# 处理每个句子
|
||||
for sentence in complete_sentences:
|
||||
sent_length = len(sentence)
|
||||
|
||||
# 添加到当前缓冲区
|
||||
if current_length + sent_length <= 300:
|
||||
buffer.append(sentence)
|
||||
current_length += sent_length
|
||||
else:
|
||||
# 处理已缓冲的文本
|
||||
if buffer:
|
||||
combined_text = "".join(buffer)
|
||||
|
||||
# 重置完成事件状态
|
||||
session['sentence_complete_event'].clear()
|
||||
session['current_processing'] = True
|
||||
|
||||
# 生成音频
|
||||
gen_tts_audio_func(session_id, combined_text)
|
||||
|
||||
# 等待完成
|
||||
if not session['sentence_complete_event'].wait(timeout=120.0):
|
||||
logging.warning(f"Timeout waiting for TTS completion: {combined_text}")
|
||||
# 重置处理状态
|
||||
time.sleep(5.0)
|
||||
session['current_processing'] = False
|
||||
logging.info(f"StreamSessionManager _process_tasks 转换结束!!!")
|
||||
# 重置缓冲区
|
||||
buffer = [sentence]
|
||||
current_length = sent_length
|
||||
|
||||
# 处理剩余的缓冲文本
|
||||
if buffer:
|
||||
combined_text = "".join(buffer)
|
||||
|
||||
# 重置完成事件状态
|
||||
session['sentence_complete_event'].clear()
|
||||
session['current_processing'] = True
|
||||
|
||||
# 生成音频
|
||||
gen_tts_audio_func(session_id, combined_text)
|
||||
|
||||
# 等待完成
|
||||
if not session['sentence_complete_event'].wait(timeout=120.0):
|
||||
logging.warning(f"Timeout waiting for TTS completion: {combined_text}")
|
||||
# 重置处理状态
|
||||
time.sleep(1.0)
|
||||
session['current_processing'] = False
|
||||
logging.info(f"StreamSessionManager _process_tasks 转换结束!!!")
|
||||
|
||||
# 3. 检查超时未处理的文本
|
||||
if current_time - session['last_text_time'] > self.sentence_timeout:
|
||||
with self.lock:
|
||||
if session['text_buffer']:
|
||||
# 直接处理剩余文本
|
||||
session['sentence_complete_event'].clear()
|
||||
session['current_processing'] = True
|
||||
gen_tts_audio_func(session_id, session['text_buffer'])
|
||||
session['text_buffer'] = ""
|
||||
|
||||
# 等待完成
|
||||
if not session['sentence_complete_event'].wait(timeout=120.0):
|
||||
logging.warning(f"Timeout waiting for TTS completion: {combined_text}")
|
||||
# 重置处理状态
|
||||
session['current_processing'] = False
|
||||
|
||||
# 4. 会话超时检查
|
||||
if current_time - session['last_active'] > self.gc_interval:
|
||||
# 处理剩余文本
|
||||
with self.lock:
|
||||
if session['text_buffer']:
|
||||
# 重置完成事件状态
|
||||
session['sentence_complete_event'].clear()
|
||||
session['current_processing'] = True
|
||||
|
||||
# 处理最后一段文本
|
||||
gen_tts_audio_func(session_id, session['text_buffer'])
|
||||
session['text_buffer'] = ""
|
||||
|
||||
# 等待完成
|
||||
if not session['sentence_complete_event'].wait(timeout=120.0):
|
||||
logging.warning(f"Timeout waiting for TTS completion: {combined_text}")
|
||||
# 重置处理状态
|
||||
session['current_processing'] = False
|
||||
|
||||
# 关闭会话
|
||||
self.close_session(session_id)
|
||||
break
|
||||
|
||||
# 5. 休眠避免CPU空转
|
||||
time.sleep(0.05) # 50ms检查间隔
|
||||
|
||||
def _generate_audio1(self, session_id, text):
|
||||
"""实际生成音频(线程池执行)"""
|
||||
session = self.sessions.get(session_id)
|
||||
if not session: return
|
||||
@@ -236,6 +384,26 @@ class StreamSessionManager:
|
||||
except Exception as e:
|
||||
session['buffer'].put(f"ERROR:{str(e)}")
|
||||
|
||||
def _generate_audio(self, session_id, text): # 20250718 新更新
|
||||
"""实际生成音频(顺序执行)- 用于非流式引擎"""
|
||||
session = self.sessions.get(session_id)
|
||||
if not session:
|
||||
return
|
||||
|
||||
try:
|
||||
logging.info(f"StreamSessionManager _generate_audio--0 {text}")
|
||||
# 调用 TTS
|
||||
session['tts_model'].text_tts_call(text)
|
||||
session['last_active'] = time.time()
|
||||
session['audio_chunk_count'] += 1
|
||||
|
||||
if not session['tts_chunk_data_valid']:
|
||||
session['tts_chunk_data_valid'] = True
|
||||
|
||||
except Exception as e:
|
||||
session['buffer'].put(f"ERROR:{str(e)}".encode())
|
||||
session['sentence_complete_event'].set() # 确保事件被设置
|
||||
|
||||
def _stream_audio(self, session_id, text):
|
||||
"""流式传输文本到TTS服务"""
|
||||
session = self.sessions.get(session_id)
|
||||
@@ -247,9 +415,12 @@ class StreamSessionManager:
|
||||
# 使用流式调用发送文本
|
||||
session['tts_model'].streaming_call(text)
|
||||
session['last_active'] = time.time()
|
||||
# 流式引擎不需要等待完成事件
|
||||
session['sentence_complete_event'].set()
|
||||
except Exception as e:
|
||||
logging.error(f"Error in streaming_call: {str(e)}")
|
||||
session['buffer'].put(f"ERROR:{str(e)}".encode())
|
||||
session['sentence_complete_event'].set()
|
||||
|
||||
async def get_tts_buffer_data(self, session_id):
|
||||
"""异步流式返回 TTS 音频数据(适配同步 queue.Queue,带 10 秒超时)"""
|
||||
@@ -298,6 +469,8 @@ class StreamSessionManager:
|
||||
|
||||
# 标记会话为不活跃
|
||||
self.sessions[session_id]['active'] = False
|
||||
# 设置完成事件(确保任何等待的线程被唤醒)
|
||||
self.sessions[session_id]['sentence_complete_event'].set()
|
||||
# 延迟2秒后清理资源
|
||||
threading.Timer(1, self._clean_session, args=[session_id]).start()
|
||||
|
||||
@@ -827,7 +1000,8 @@ from dashscope.audio.tts_v2 import (
|
||||
|
||||
|
||||
class QwenTTS:
|
||||
def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun"):
|
||||
def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun",
|
||||
special_characters: Optional[Dict[str, str]] = None):
|
||||
import dashscope
|
||||
import ssl
|
||||
logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx
|
||||
@@ -844,14 +1018,22 @@ class QwenTTS:
|
||||
if '/' in self.model_name:
|
||||
parts = self.model_name.split('/', 1)
|
||||
# 返回分离后的两个字符串parts[0], parts[1]
|
||||
if parts[0] == 'cosyvoice-v1':
|
||||
if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2':
|
||||
self.is_cosyvoice = True
|
||||
self.voice = parts[1]
|
||||
self.completion_event = None # 新增:用于通知任务完成
|
||||
# 特殊字符及其拼音映射
|
||||
self.special_characters = special_characters or {
|
||||
"㼽": "chuang3",
|
||||
"䡇": "yue4"
|
||||
# 可以添加更多特殊字符的映射
|
||||
}
|
||||
|
||||
class Callback(TTSResultCallback):
|
||||
def __init__(self) -> None:
|
||||
def __init__(self,data_callback=None,completion_event=None) -> None:
|
||||
self.dque = deque()
|
||||
|
||||
self.data_callback = data_callback
|
||||
self.completion_event = completion_event # 新增完成事件引用
|
||||
def _run(self):
|
||||
while True:
|
||||
if not self.dque:
|
||||
@@ -867,7 +1049,13 @@ class QwenTTS:
|
||||
pass
|
||||
|
||||
def on_complete(self):
|
||||
logging.info(f"---QwenTTS Callback on_complete--")
|
||||
self.dque.append(None)
|
||||
if self.data_callback:
|
||||
self.data_callback(None) # 发送结束信号
|
||||
# 通知任务完成
|
||||
if self.completion_event:
|
||||
self.completion_event.set()
|
||||
|
||||
def on_error(self, response: SpeechSynthesisResponse):
|
||||
print("Qwen tts error", str(response))
|
||||
@@ -877,15 +1065,22 @@ class QwenTTS:
|
||||
pass
|
||||
|
||||
def on_event(self, result: TTSSpeechSynthesisResult):
|
||||
if result.get_audio_frame() is not None:
|
||||
self.dque.append(result.get_audio_frame())
|
||||
data =result.get_audio_frame()
|
||||
if data is not None:
|
||||
if len(data) > 0:
|
||||
if self.data_callback:
|
||||
self.data_callback(data)
|
||||
else:
|
||||
self.dque.append(data)
|
||||
#self.dque.append(result.get_audio_frame())
|
||||
|
||||
# --------------------------
|
||||
|
||||
class Callback_Cosy(CosyResultCallback):
|
||||
def __init__(self, data_callback=None) -> None:
|
||||
def __init__(self, data_callback=None,completion_event=None) -> None:
|
||||
self.dque = deque()
|
||||
self.data_callback = data_callback
|
||||
self.completion_event = completion_event # 新增完成事件引用
|
||||
|
||||
def _run(self):
|
||||
while True:
|
||||
@@ -906,6 +1101,9 @@ class QwenTTS:
|
||||
self.dque.append(None)
|
||||
if self.data_callback:
|
||||
self.data_callback(None) # 发送结束信号
|
||||
# 通知任务完成
|
||||
if self.completion_event:
|
||||
self.completion_event.set()
|
||||
|
||||
def on_error(self, response: SpeechSynthesisResponse):
|
||||
print("Qwen tts error", str(response))
|
||||
@@ -938,23 +1136,28 @@ class QwenTTS:
|
||||
|
||||
# --------------------------
|
||||
|
||||
def tts(self, text):
|
||||
def tts(self, text, on_data = None,completion_event=None):
|
||||
# logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
|
||||
# text = self.normalize_text(text)
|
||||
print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
|
||||
# text = self.normalize_text(text)
|
||||
|
||||
try:
|
||||
# if self.model_name != 'cosyvoice-v1':
|
||||
if self.is_cosyvoice is False:
|
||||
self.callback = self.Callback()
|
||||
self.callback = self.Callback(
|
||||
data_callback=on_data,
|
||||
completion_event=completion_event
|
||||
)
|
||||
TTSSpeechSynthesizer.call(model=self.model_name,
|
||||
text=text,
|
||||
callback=self.callback,
|
||||
format="wav") # format="mp3")
|
||||
format=self.format) # format="mp3")
|
||||
else:
|
||||
self.callback = self.Callback_Cosy()
|
||||
format = self.get_audio_format(self.format, self.sample_rate)
|
||||
self.synthesizer = CosySpeechSynthesizer(
|
||||
model='cosyvoice-v1',
|
||||
model='cosyvoice-v2',
|
||||
# voice="longyuan", #"longfei",
|
||||
voice=self.voice,
|
||||
callback=self.callback,
|
||||
@@ -974,26 +1177,68 @@ class QwenTTS:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"**ERROR**: {e}")
|
||||
|
||||
def setup_tts(self, on_data):
|
||||
"""设置 TTS 回调,返回配置好的 synthesizer"""
|
||||
if not self.is_cosyvoice:
|
||||
raise NotImplementedError("Only CosyVoice supported")
|
||||
def setup_tts(self, on_data,completion_event=None):
|
||||
|
||||
# 创建 CosyVoice 回调
|
||||
self.callback = self.Callback_Cosy(on_data)
|
||||
"""设置 TTS 回调,返回配置好的 synthesizer"""
|
||||
#if not self.is_cosyvoice:
|
||||
# raise NotImplementedError("Only CosyVoice supported")
|
||||
|
||||
if self.is_cosyvoice:
|
||||
# 创建 CosyVoice 回调
|
||||
self.callback = self.Callback_Cosy(
|
||||
data_callback=on_data,
|
||||
completion_event=completion_event)
|
||||
else:
|
||||
self.callback = self.Callback(
|
||||
data_callback=on_data,
|
||||
completion_event=completion_event)
|
||||
format_val = self.get_audio_format(self.format, self.sample_rate)
|
||||
logging.info(f"setup_tts {self.voice} {format_val}")
|
||||
self.synthesizer = CosySpeechSynthesizer(
|
||||
model='cosyvoice-v1',
|
||||
voice=self.voice, # voice="longyuan", #"longfei",
|
||||
callback=self.callback,
|
||||
format=format_val
|
||||
)
|
||||
logging.info(f"Qwen setup_tts {self.voice} {format_val}")
|
||||
if self.is_cosyvoice:
|
||||
self.synthesizer = CosySpeechSynthesizer(
|
||||
model='cosyvoice-v1',
|
||||
voice=self.voice, # voice="longyuan", #"longfei",
|
||||
callback=self.callback,
|
||||
format=format_val
|
||||
)
|
||||
|
||||
return self.synthesizer
|
||||
|
||||
def apply_phoneme_tags(self, text: str) -> str:
|
||||
"""
|
||||
在文本中查找特殊字符并用<phoneme>标签包裹它们
|
||||
"""
|
||||
# 如果文本已经是SSML格式,直接返回
|
||||
if text.strip().startswith("<speak>") and text.strip().endswith("</speak>"):
|
||||
return text
|
||||
|
||||
# 为特殊字符添加SSML标签
|
||||
for char, pinyin in self.special_characters.items():
|
||||
# 使用正则表达式确保只替换整个字符(避免部分匹配)
|
||||
pattern = r'([^<]|^)' + re.escape(char) + r'([^>]|$)'
|
||||
replacement = r'\1<phoneme alphabet="py" ph="' + pinyin + r'">' + char + r'</phoneme>\2'
|
||||
text = re.sub(pattern, replacement, text)
|
||||
|
||||
# 如果文本中已有SSML标签,直接返回
|
||||
if "<speak>" in text:
|
||||
return text
|
||||
|
||||
# 否则包裹在<speak>标签中
|
||||
return f"<speak>{text}</speak>"
|
||||
|
||||
def text_tts_call(self, text):
|
||||
if self.synthesizer:
|
||||
if self.special_characters and self.is_cosyvoice is False:
|
||||
text = self.apply_phoneme_tags(text)
|
||||
#logging.info(f"Applied SSML phoneme tags to text: {text}")
|
||||
|
||||
if self.synthesizer and self.is_cosyvoice:
|
||||
self.synthesizer.call(text)
|
||||
if self.is_cosyvoice is False:
|
||||
logging.info(f"Qwen text_tts_call {text}")
|
||||
TTSSpeechSynthesizer.call(model=self.model_name,
|
||||
text=text,
|
||||
callback=self.callback,
|
||||
format=self.format)
|
||||
|
||||
def streaming_call(self, text):
|
||||
if self.synthesizer:
|
||||
@@ -1027,22 +1272,178 @@ class QwenTTS:
|
||||
return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
|
||||
|
||||
|
||||
import threading
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
class DoubaoTTS:
|
||||
def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"):
|
||||
logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}")
|
||||
# 解析豆包认证信息 (appid, token, cluster, voice_type)
|
||||
try:
|
||||
self.appid = "7282190702"
|
||||
self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9"
|
||||
self.cluster = "volcano_tts"
|
||||
self.voice_type ="zh_female_qingxinnvsheng_mars_bigtts" # "zh_male_jieshuonansheng_mars_bigtts" #"zh_male_ruyaqingnian_mars_bigtts" #"zh_male_jieshuonansheng_mars_bigtts"
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid Doubao key format: {str(e)}")
|
||||
|
||||
import threading
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import deque
|
||||
from io import BytesIO
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
self.model_name = model_name
|
||||
self.callback = None
|
||||
self.ws = None
|
||||
self.loop = None
|
||||
self.task = None
|
||||
self.event = threading.Event()
|
||||
self.data_queue = deque()
|
||||
self.host = "openspeech.bytedance.com"
|
||||
self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary"
|
||||
self.default_header = bytearray(b'\x11\x10\x11\x00')
|
||||
self.total_data_size = 0
|
||||
self.completion_event = None # 新增:用于通知任务完成
|
||||
|
||||
|
||||
class Callback:
|
||||
def __init__(self, data_callback=None,completion_event=None):
|
||||
self.data_callback = data_callback
|
||||
self.data_queue = deque()
|
||||
self.completion_event = completion_event # 完成事件引用
|
||||
|
||||
def on_data(self, data):
|
||||
if self.data_callback:
|
||||
self.data_callback(data)
|
||||
else:
|
||||
self.data_queue.append(data)
|
||||
# 通知任务完成
|
||||
if self.completion_event:
|
||||
self.completion_event.set()
|
||||
|
||||
def on_complete(self):
|
||||
if self.data_callback:
|
||||
self.data_callback(None)
|
||||
|
||||
def on_error(self, error):
|
||||
if self.data_callback:
|
||||
self.data_callback(f"ERROR:{error}".encode())
|
||||
|
||||
def setup_tts(self, on_data,completion_event):
|
||||
|
||||
"""设置回调,返回自身(因为豆包需要异步启动)"""
|
||||
self.callback = self.Callback(
|
||||
data_callback=on_data,
|
||||
completion_event=completion_event
|
||||
)
|
||||
return self
|
||||
|
||||
def text_tts_call(self, text):
|
||||
"""同步调用,启动异步任务并等待完成"""
|
||||
self.total_data_size = 0
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
self.task = self.loop.create_task(self._async_tts(text))
|
||||
try:
|
||||
self.loop.run_until_complete(self.task)
|
||||
except Exception as e:
|
||||
logging.error(f"DoubaoTTS--0 call error: {e}")
|
||||
self.callback.on_error(str(e))
|
||||
|
||||
async def _async_tts(self, text):
|
||||
"""异步执行TTS请求"""
|
||||
header = {"Authorization": f"Bearer; {self.token}"}
|
||||
request_json = {
|
||||
"app": {
|
||||
"appid": self.appid,
|
||||
"token": "access_token", # 固定值
|
||||
"cluster": self.cluster
|
||||
},
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4()) # 随机用户ID
|
||||
},
|
||||
"audio": {
|
||||
"voice_type": self.voice_type,
|
||||
"encoding": self.format,
|
||||
"speed_ratio": 1.0,
|
||||
"volume_ratio": 1.0,
|
||||
"pitch_ratio": 1.0,
|
||||
},
|
||||
"request": {
|
||||
"reqid": str(uuid.uuid4()),
|
||||
"text": text,
|
||||
"text_type": "plain",
|
||||
"operation": "submit" # 使用submit模式支持流式
|
||||
}
|
||||
}
|
||||
|
||||
# 构建请求数据
|
||||
payload_bytes = str.encode(json.dumps(request_json))
|
||||
payload_bytes = gzip.compress(payload_bytes)
|
||||
full_client_request = bytearray(self.default_header)
|
||||
full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||||
full_client_request.extend(payload_bytes)
|
||||
|
||||
try:
|
||||
async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws:
|
||||
self.ws = ws
|
||||
await ws.send(full_client_request)
|
||||
|
||||
# 接收音频数据
|
||||
while True:
|
||||
res = await ws.recv()
|
||||
done = self._parse_response(res)
|
||||
if done:
|
||||
self.callback.on_complete()
|
||||
break
|
||||
except Exception as e:
|
||||
logging.error(f"DoubaoTTS--1 WebSocket error: {e}")
|
||||
self.callback.on_error(str(e))
|
||||
finally:
|
||||
# 通知任务完成
|
||||
if self.completion_event:
|
||||
self.completion_event.set()
|
||||
|
||||
|
||||
def _parse_response(self, res):
|
||||
"""解析豆包返回的二进制响应"""
|
||||
# 协议头解析 (4字节)
|
||||
header_size = res[0] & 0x0f
|
||||
message_type = res[1] >> 4
|
||||
payload = res[header_size * 4:]
|
||||
|
||||
# 音频数据响应
|
||||
if message_type == 0xb: # audio-only server response
|
||||
message_flags = res[1] & 0x0f
|
||||
|
||||
# ACK消息,忽略
|
||||
if message_flags == 0:
|
||||
return False
|
||||
|
||||
# 音频数据消息
|
||||
sequence_number = int.from_bytes(payload[:4], "big", signed=True)
|
||||
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
|
||||
audio_data = payload[8:8 + payload_size]
|
||||
|
||||
if audio_data:
|
||||
self.total_data_size = self.total_data_size + len(audio_data)
|
||||
self.callback.on_data(audio_data)
|
||||
|
||||
#logging.info(f"doubao _parse_response: {sequence_number} {len(audio_data)} {self.total_data_size}")
|
||||
# 序列号为负表示结束
|
||||
return sequence_number < 0
|
||||
|
||||
# 错误响应
|
||||
elif message_type == 0xf:
|
||||
code = int.from_bytes(payload[:4], "big", signed=False)
|
||||
msg_size = int.from_bytes(payload[4:8], "big", signed=False)
|
||||
error_msg = payload[8:8 + msg_size]
|
||||
|
||||
try:
|
||||
# 尝试解压错误消息
|
||||
error_msg = gzip.decompress(error_msg).decode()
|
||||
except:
|
||||
error_msg = error_msg.decode(errors='ignore')
|
||||
|
||||
logging.error(f"DoubaoTTS error: {error_msg}")
|
||||
self.callback.on_error(error_msg)
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class UnifiedTTSEngine:
|
||||
@@ -1144,13 +1545,29 @@ class UnifiedTTSEngine:
|
||||
return
|
||||
|
||||
try:
|
||||
# 创建完成事件
|
||||
completion_event = threading.Event()
|
||||
# 创建TTS实例
|
||||
tts = QwenTTS(
|
||||
key=task['key'],
|
||||
format=task['format'],
|
||||
sample_rate=task['sample_rate'],
|
||||
model_name=task['model_name']
|
||||
)
|
||||
# 根据model_name选择TTS引擎
|
||||
# 前端传入 cosyvoice-v1/longhua@Tongyi-Qianwen
|
||||
model_name_wo_brand = task['model_name'].split('@')[0]
|
||||
model_name_version = model_name_wo_brand.split('/')[0]
|
||||
if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']:
|
||||
# 豆包TTS
|
||||
tts = DoubaoTTS(
|
||||
key=task['key'],
|
||||
format=task['format'],
|
||||
sample_rate=task['sample_rate'],
|
||||
model_name=task['model_name']
|
||||
)
|
||||
else:
|
||||
# 通义千问TTS
|
||||
tts = QwenTTS(
|
||||
key=task['key'],
|
||||
format=task['format'],
|
||||
sample_rate=task['sample_rate'],
|
||||
model_name=task['model_name']
|
||||
)
|
||||
|
||||
# 定义同步数据处理函数
|
||||
def data_handler(data):
|
||||
@@ -1165,20 +1582,28 @@ class UnifiedTTSEngine:
|
||||
else: # 音频数据
|
||||
task['data_queue'].append(data)
|
||||
|
||||
# 设置并执行TTS
|
||||
synthesizer = tts.setup_tts(data_handler)
|
||||
synthesizer.call(task['text'])
|
||||
|
||||
# 设置并执行TTS
|
||||
synthesizer = tts.setup_tts(data_handler,completion_event)
|
||||
#synthesizer.call(task['text'])
|
||||
tts.text_tts_call(task['text'])
|
||||
# 等待完成或超时
|
||||
if not task['event'].wait(timeout=300): # 5分钟超时
|
||||
# 等待完成或超时
|
||||
if not completion_event.wait(timeout=300): # 5分钟超时
|
||||
task['error'] = "TTS generation timeout"
|
||||
task['completed'] = True
|
||||
|
||||
logging.info(f"--tts task event set error = {task['error']}")
|
||||
|
||||
except Exception as e:
|
||||
logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}")
|
||||
task['error'] = f"ERROR:{str(e)}"
|
||||
task['completed'] = True
|
||||
finally:
|
||||
# 确保清理TTS资源
|
||||
logging.info("UnifiedTTSEngine _run_tts_sync finally")
|
||||
if hasattr(tts, 'loop') and tts.loop:
|
||||
tts.loop.close()
|
||||
|
||||
def _merge_audio_data(self, audio_stream_id):
|
||||
"""将任务的所有音频数据合并到ByteIO缓冲区"""
|
||||
@@ -1219,7 +1644,7 @@ class UnifiedTTSEngine:
|
||||
# 如果是延迟任务且未启动,现在启动 status 为 pending
|
||||
if task['delay_gen_audio'] and task['status'] == 'pending':
|
||||
self._start_tts_task(audio_stream_id)
|
||||
|
||||
total_audio_data_size = 0
|
||||
# 等待任务启动
|
||||
while task['status'] == 'pending':
|
||||
await asyncio.sleep(0.1)
|
||||
@@ -1228,7 +1653,8 @@ class UnifiedTTSEngine:
|
||||
while not task['completed'] or task['data_queue']:
|
||||
while task['data_queue']:
|
||||
data = task['data_queue'].popleft()
|
||||
# logging.info(f"yield data {len(data)}")
|
||||
total_audio_data_size += len(data)
|
||||
#logging.info(f"yield audio data {len(data)} {total_audio_data_size}")
|
||||
yield data
|
||||
|
||||
# 短暂等待新数据
|
||||
@@ -1318,6 +1744,7 @@ async def proxy_aichat_audio_stream(client_id: str, audio_url: str):
|
||||
|
||||
|
||||
# 代理函数 - 文本流
|
||||
# 在微信小程序中,原来APK使用的SSE机制不能正常工作,需要使用WebSocket
|
||||
async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict):
|
||||
"""代理大模型文本流请求 - 兼容现有Flask实现"""
|
||||
try:
|
||||
@@ -1328,13 +1755,19 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm'
|
||||
}
|
||||
tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
|
||||
#if 'longyuan' in tts_model_name:
|
||||
# tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen"
|
||||
# 创建TTS实例
|
||||
tts_model = QwenTTS(
|
||||
key=ALI_KEY,
|
||||
format=payload.get('tts_stream_format', 'mp3'),
|
||||
sample_rate=payload.get('tts_sample_rate', 48000),
|
||||
model_name=payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
|
||||
model_name=tts_model_name
|
||||
)
|
||||
streaming_call = False
|
||||
if tts_model.is_cosyvoice:
|
||||
streaming_call = True
|
||||
|
||||
# 创建流会话
|
||||
tts_stream_session_id = stream_manager.create_session(
|
||||
@@ -1342,7 +1775,7 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload
|
||||
sample_rate=payload.get('tts_sample_rate', 48000),
|
||||
stream_format=payload.get('tts_stream_format', 'mp3'),
|
||||
session_id=None,
|
||||
streaming_call=True
|
||||
streaming_call= streaming_call
|
||||
)
|
||||
# logging.info(f"---tts_stream_session_id = {tts_stream_session_id}")
|
||||
tts_stream_session_id_sent = False
|
||||
@@ -1399,7 +1832,6 @@ async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload
|
||||
data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}"
|
||||
data_str = json.dumps(data_obj)
|
||||
tts_stream_session_id_sent = True
|
||||
|
||||
# 直接转发原始数据
|
||||
await manager.send_text(client_id, json.dumps({
|
||||
"type": "text",
|
||||
@@ -1717,12 +2149,14 @@ async def websocket_tts_endpoint(
|
||||
audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}"
|
||||
# await proxy_aichat_audio_stream(connection_id, audio_url)
|
||||
sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate')
|
||||
audio_data_size =0
|
||||
await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate})
|
||||
async for data in stream_manager.get_tts_buffer_data(audio_stream_id):
|
||||
audio_data_size += len(data)
|
||||
if not await manager.send_bytes(connection_id, data):
|
||||
break
|
||||
completed_successfully = True
|
||||
|
||||
logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}")
|
||||
elif service_type == "AiChatText":
|
||||
# 文本代理服务
|
||||
# 等待客户端发送初始请求数据 进行大模型对话代理时,需要前端连接后发送payload
|
||||
|
||||
Reference in New Issue
Block a user