import logging
import binascii
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta
import threading, time, queue, uuid, array
from threading import Lock, Thread
import concurrent.futures  # needed for concurrent.futures.TimeoutError used below
from concurrent.futures import ThreadPoolExecutor
import base64, gzip
import os, io, re, json
from io import BytesIO
from typing import Optional, Dict, Any
import asyncio, httpx
from collections import deque

import websockets

from fastapi import WebSocket, APIRouter, WebSocketDisconnect, Request, Body, Query
from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi.responses import StreamingResponse, JSONResponse, Response

TTS_SAMPLERATE = 44100  # 22050 # 16000
FORMAT = "mp3"
CHANNELS = 1       # mono
SAMPLE_WIDTH = 2   # 16-bit = 2 bytes

tts_router = APIRouter()

# logger = logging.getLogger(__name__)
class MillisecondsFormatter(logging.Formatter):
    """Custom log formatter that appends a millisecond timestamp."""

    def formatTime(self, record, datefmt=None):
        # Convert the timestamp to a local time tuple
        ct = self.converter(record.created)
        # Format as "hour:minute:second"
        t = time.strftime("%H:%M:%S", ct)
        # Append milliseconds (3 digits)
        return f"{t}.{int(record.msecs):03d}"


# Configure the global logging format
def configure_logging():
    # Create the formatter
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    formatter = MillisecondsFormatter(log_format)

    # Get the root logger and clear any existing handlers
    root_logger = logging.getLogger()
    root_logger.handlers = []

    # Create and configure a handler (console output)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    # Set the log level and attach the handler
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)


# Run the configuration once at program start
configure_logging()

class StreamSessionManager:
    def __init__(self):
        self.sessions = {}  # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
        self.gc_interval = 300  # clean up every 5 minutes
        self.streaming_call_timeout = 10  # 10 s
        self.gc_tts = 3  # 3 s
        self.sentence_timeout = 2  # 2000 ms sentence timeout
        self.sentence_endings = set('。?!;.?!;')  # Chinese and English sentence terminators
        # Enhanced regex: match Chinese and English sentence terminators (including full-width forms)
        self.sentence_pattern = re.compile(
            r'([,,。?!;.?!;?!;…]+["\'”’]?)(?=\s|$|[^,,。?!;.?!;?!;…])'
        )
        self.sentence_audio_store = {}  # {sentence_id: {'data': bytes, 'text': str, 'created_at': float}}
        self.sentence_lock = threading.Lock()

        threading.Thread(target=self._cleanup_expired, daemon=True).start()

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', session_id=None, streaming_call=False):
        if not session_id:
            session_id = str(uuid.uuid4())
        with self.lock:
            # Create the TTS instance and configure its streaming callback
            tts_instance = tts_model

            # Audio-data callback
            def on_data(data: bytes):
                if data:
                    try:
                        self.sessions[session_id]['last_active'] = time.time()
                        self.sessions[session_id]['buffer'].put(data)
                        self.sessions[session_id]['audio_chunk_size'] += len(data)
                        # logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
                    except queue.Full:
                        logging.warning(f"Audio buffer full for session {session_id}")
                """
                elif data is None:  # end-of-stream signal
                    # Only non-streaming engines trigger the completion event here
                    if not streaming_call:
                        logging.info(f"StreamSessionManager on_data sentence_complete_event set")
                        self.sessions[session_id]['sentence_complete_event'].set()
                        self.sessions[session_id]['current_processing'] = False
                """

            # Completion event
            completion_event = threading.Event()
            # Configure TTS streaming
            tts_instance.setup_tts(on_data, completion_event)
            # Create the session
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'audio_chunk_size': 0,
                'finished': threading.Event(),       # completion event object
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                "text_buffer": "",                   # text buffer
                "last_text_time": time.time(),       # time the last text arrived
                "streaming_call": streaming_call,
                "tts_stream_started": False,         # whether the stream has started
                "current_processing": False,         # whether a sentence is currently being processed
                "sentence_complete_event": completion_event,  # threading.Event()
                'sentences': [],                     # list of sentence IDs
                'current_sentence_index': 0,
                'gc_interval': 300,                  # clean up every 5 minutes
            }
        # Start the per-session task-processing thread
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        with self.lock:
            session = self.sessions.get(session_id)
            if not session: return
            # Update the text buffer and timestamp
            session['text_buffer'] += text
            session['last_text_time'] = time.time()
            # Put the text on the task queue (non-blocking)
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def finish_text_input(self, session_id):
        """Mark text input as finished and notify the task-processing thread."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            session['gc_interval'] = 100  # all text has arrived, so the timeout check can be shortened

    def _process_tasks(self, session_id):  # updated 2025-07-18
        """Per-session task-processing thread; keeps the original processing logic."""
        session = self.sessions.get(session_id)
        if not session or not session['active']:
            return

        # Pick the processing function based on the engine type
        if session.get('streaming_call'):
            gen_tts_audio_func = self._generate_audio  # self._stream_audio
        else:
            gen_tts_audio_func = self._generate_audio

        while session['active']:
            current_time = time.time()
            text_to_process = ""

            # 1. Grab the pending text
            with self.lock:
                if session['text_buffer']:
                    text_to_process = session['text_buffer']

            # 2. Process the text
            if text_to_process and not session['current_processing']:
                session['text_buffer'] = ""
                # Split out the complete sentences
                complete_sentences, remaining_text = self._split_and_extract(text_to_process)

                # Keep the remaining text
                if remaining_text:
                    with self.lock:
                        session['text_buffer'] = remaining_text + session['text_buffer']

                # Merge and process the complete sentences
                if complete_sentences:
                    # Merge sentences up to 300 characters
                    buffer = []
                    current_length = 0

                    # Handle each sentence
                    for sentence in complete_sentences:
                        sent_length = len(sentence)

                        # Add it to the current buffer
                        if current_length + sent_length <= 300:
                            buffer.append(sentence)
                            current_length += sent_length
                        else:
                            # Process the buffered text
                            if buffer:
                                combined_text = "".join(buffer)

                                # Reset the completion-event state
                                session['sentence_complete_event'].clear()
                                session['current_processing'] = True

                                # Generate audio
                                gen_tts_audio_func(session_id, combined_text)

                                # Wait for completion
                                if not session['sentence_complete_event'].wait(timeout=120.0):
                                    logging.warning(f"Timeout waiting for TTS completion: {combined_text}")
                                # Reset the processing state
                                time.sleep(5.0)
                                session['current_processing'] = False
                                logging.info("StreamSessionManager _process_tasks conversion finished")
                            # Reset the buffer
                            buffer = [sentence]
                            current_length = sent_length

                    # Process whatever is left in the buffer
                    if buffer:
                        combined_text = "".join(buffer)
                        session['current_processing'] = True

                        # Generate audio
                        gen_tts_audio_func(session_id, combined_text)  # _stream_audio would be a synchronous call that blocks until the audio is generated

                        # Reset the processing state
                        time.sleep(1.0)
                        session['current_processing'] = False

            # 3. Check for text that timed out without being processed
            if current_time - session['last_text_time'] > self.sentence_timeout:
                with self.lock:
                    if session['text_buffer']:
                        # Process the remaining text directly
                        session['current_processing'] = True
                        gen_tts_audio_func(session_id, session['text_buffer'])
                        session['text_buffer'] = ""
                        # Reset the processing state
                        session['current_processing'] = False

            # 4. Session timeout check
            if current_time - session['last_active'] > session['gc_interval']:
                # Process any remaining text
                with self.lock:
                    if session['text_buffer']:
                        session['current_processing'] = True

                        # Process the last piece of text
                        gen_tts_audio_func(session_id, session['text_buffer'])
                        session['text_buffer'] = ""
                        # Reset the processing state
                        session['current_processing'] = False

                # Close the session
                logging.info(f"--_process_tasks-- timeout {session['last_active']} {session['gc_interval']}")
                self.close_session(session_id)
                break

            # 5. Sleep to avoid spinning the CPU
            time.sleep(0.05)  # 50 ms check interval

    def _generate_audio(self, session_id, text):  # updated 2025-07-18
        """Actually generate the audio (sequentially); used for non-streaming engines."""
        session = self.sessions.get(session_id)
        if not session:
            return

        try:
            # logging.info(f"StreamSessionManager _generate_audio--0 {text}")
            # In-memory stream
            audio_stream = io.BytesIO()

            # Callback: write directly into the stream
            def on_data_sentence(data: bytes):
                if data:
                    audio_stream.write(data)

            def on_data_whole(data: bytes):
                if data:
                    try:
                        session['last_active'] = time.time()
                        session['buffer'].put({'type': 'arraybuffer', 'data': data})
                        # session['buffer'].put(data)
                        session['audio_chunk_size'] += len(data)
                        # logging.info(f"StreamSessionManager on_data {len(data)} {self.sessions[session_id]['audio_chunk_size']}")
                    except queue.Full:
                        logging.warning(f"Audio buffer full for session {session_id}")

            # Reset the completion event
            session['sentence_complete_event'].clear()
            session['tts_model'].setup_tts(on_data=on_data_whole, completion_event=session['sentence_complete_event'])
            # Invoke the TTS engine
            session['tts_model'].text_tts_call(text)
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1
            # Wait for the sentence to finish
            if not session['sentence_complete_event'].wait(timeout=30):  # 30-second timeout
                logging.warning(f"Timeout generating audio for: {text[:20]}...")
            logging.info("StreamSessionManager _generate_audio conversion finished")
            session['buffer'].put({'type': 'sentence_end', 'data': ""})

            # Collect the audio data
            audio_data = audio_stream.getvalue()
            audio_stream.close()

            # Save it to the sentence store
            self.add_sentence_audio(session_id, text, audio_data)

            if not session['tts_chunk_data_valid']:
                session['tts_chunk_data_valid'] = True

        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}".encode())
            session['sentence_complete_event'].set()  # make sure the event gets set

    def _stream_audio(self, session_id, text):
        """Stream text to the TTS service."""
        session = self.sessions.get(session_id)
        if not session:
            return
        # logging.info(f"Streaming text to TTS: {text}")

        try:
            # Reset the completion-event state
            session['sentence_complete_event'].clear()
            # Send the text with the streaming call
            session['tts_model'].streaming_call(text)
            session['last_active'] = time.time()
            # Streaming engines do not need to wait for the completion event
            session['sentence_complete_event'].set()
        except Exception as e:
            logging.error(f"Error in streaming_call: {str(e)}")
            session['buffer'].put(f"ERROR:{str(e)}".encode())
            session['sentence_complete_event'].set()

    async def get_tts_buffer_data(self, session_id):
        """Asynchronously stream TTS audio data (adapts the synchronous queue.Queue, with a 10-second timeout)."""
        session = self.sessions.get(session_id)
        if not session:
            raise ValueError(f"Session {session_id} not found")

        buffer = session['buffer']  # a queue.Queue
        last_data_time = time.time()  # time the last chunk was received

        while session['active']:
            try:
                # run_in_executor + wait_for gives a 10-second timeout
                data = await asyncio.wait_for(
                    asyncio.get_event_loop().run_in_executor(None, buffer.get),
                    timeout=10.0
                )
                # logging.info(f"get_tts_buffer_data {data}")
                last_data_time = time.time()  # update the last-data time
                yield data

            except asyncio.TimeoutError:
                # No new data within 10 seconds; check whether we have really timed out
                if time.time() - last_data_time >= 10.0:
                    break
                else:
                    continue  # not timed out yet, keep waiting

            except asyncio.CancelledError:
                logging.info(f"Session {session_id} stream cancelled")
                break

            except Exception as e:
                logging.error(f"Error in get_tts_buffer_data: {e}")
                break

    def close_session(self, session_id):
        with self.lock:
            if session_id in self.sessions:
                session = self.sessions[session_id]
                session['active'] = False

                # Clean up the associated audio data
                with self.sentence_lock:
                    # Mark this session's sentences as expired
                    expired_sentences = [
                        sid for sid, data in self.sentence_audio_store.items()
                        if data.get('session_id') == session_id
                    ]

                # Set the completion event
                session['sentence_complete_event'].set()

                # Release TTS resources
                try:
                    if session.get('streaming_call'):
                        session['tts_model'].end_streaming_call()
                except:
                    pass

                # Remove the session after a short delay
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def _cleanup_expired(self):
        """Periodically clean up expired resources."""
        while True:
            time.sleep(30)
            now = time.time()

            # Remove expired sessions (collect under the lock, delete outside it,
            # since _clean_session acquires the same non-reentrant lock)
            with self.lock:
                expired_sessions = [
                    sid for sid, session in self.sessions.items()
                    if now - session['last_active'] > self.gc_interval
                ]
            for sid in expired_sessions:
                self._clean_session(sid)

            # Remove expired sentence audio
            with self.sentence_lock:
                expired_sentences = [
                    sid for sid, data in self.sentence_audio_store.items()
                    if now - data['created_at'] > self.gc_interval
                ]
                for sid in expired_sentences:
                    del self.sentence_audio_store[sid]

            logging.info(f"Cleaned up resources: {len(expired_sessions)} sessions, {len(expired_sentences)} sentences")

    def get_session(self, session_id):
        return self.sessions.get(session_id)

    def _has_sentence_ending(self, text):
        """Check whether the text contains a sentence terminator."""
        if not text:
            return False

        # Check the common terminators (including full-width characters)
        if any(char in self.sentence_endings for char in text[-3:]):
            return True

        # Check for a Chinese paragraph ending (terminator before a newline)
        if '\n' in text and any(char in self.sentence_endings for char in text.split('\n')[-2:-1]):
            return True

        return False

    def _split_and_extract(self, text):
        """
        Enhanced sentence splitter.
        Returns: (list of complete sentences, remaining text)
        """
        # Special case: if the text starts with a comma, handle the leading part first
        if text.startswith((",", ",")):
            return [text[0]], text[1:]

        # 1. Find all possible sentence end positions
        matches = list(self.sentence_pattern.finditer(text))

        if not matches:
            return [], text  # no terminator found

        # 2. Determine the end position of the last complete sentence
        last_end = 0
        complete_sentences = []

        for match in matches:
            end_pos = match.end()
            sentence = text[last_end:end_pos].strip()

            # Skip empty sentences
            if not sentence:
                last_end = end_pos
                continue

            # Check whether this is a valid sentence (minimum length or contains a terminator)
            if (len(sentence) > 6 or any(char in "。.?!?!" for char in sentence)) and (last_end < 24):
                complete_sentences.append(sentence)
                last_end = end_pos
            else:
                # Short text that still contains a terminator, possibly a special symbol
                if any(char in "。.?!?!" for char in sentence):
                    complete_sentences.append(sentence)
                    last_end = end_pos

        # 3. Extract the remaining text
        remaining_text = text[last_end:].strip()
        return complete_sentences, remaining_text

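    # Illustrative call to _split_and_extract (hypothetical input; the exact split depends on the
    # regex and length heuristics above, so this is a sketch rather than guaranteed output):
    #   sentences, rest = stream_manager._split_and_extract("你好。今天天气不错,我们出去")
    #   # `sentences` holds the complete sentences that were detected, `rest` the trailing fragment
    #   # that stays in the session's text_buffer until more text arrives or a timeout fires.
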
    def add_sentence_audio(self, session_id, sentence_text, audio_data: bytes):
        """Add a sentence's audio to the store."""
        with self.lock:
            if session_id not in self.sessions:
                return None

            # Generate a unique sentence ID
            sentence_id = str(uuid.uuid4())

            # Store the audio data
            with self.sentence_lock:
                self.sentence_audio_store[sentence_id] = {
                    'data': audio_data,
                    'text': sentence_text,
                    'created_at': time.time(),
                    'session_id': session_id,
                    'format': self.sessions[session_id]['stream_format']
                }
            logging.info("StreamSessionManager add_sentence_audio")
            # Append it to the session's sentence list
            self.sessions[session_id]['sentences'].append(sentence_id)

            return sentence_id

    def get_sentence_audio(self, sentence_id):
        """Get a sentence's audio data."""
        with self.sentence_lock:
            if sentence_id not in self.sentence_audio_store:
                return None
            return self.sentence_audio_store[sentence_id]['data']

    def get_sentence_info(self, sentence_id):
        """Get a sentence's metadata."""
        with self.sentence_lock:
            if sentence_id not in self.sentence_audio_store:
                return None
            return self.sentence_audio_store[sentence_id]

    def get_next_sentence(self, session_id):
        """Get the next sentence's info."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                return None

            if session['current_sentence_index'] < len(session['sentences']):
                sentence_id = session['sentences'][session['current_sentence_index']]
                session['current_sentence_index'] += 1
                return {
                    'id': sentence_id,
                    'url': f"/tts_sentence/{sentence_id}"  # virtual URL
                }
            return None

    def get_sentence_audio_data(self, session_id):
        with self.lock:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                return None
            return self.sentence_audio_store


stream_manager = StreamSessionManager()

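# Minimal lifecycle sketch for StreamSessionManager (assumptions: a configured TTS engine
# instance `tts_model` exposing setup_tts()/text_tts_call(), and a FastAPI route wrapping the
# async generator; nothing here is executed by the module):
#
#   session_id = stream_manager.create_session(tts_model, sample_rate=44100, stream_format="mp3")
#   stream_manager.append_text(session_id, "你好,")          # incremental LLM output
#   stream_manager.append_text(session_id, "今天天气不错。")
#   stream_manager.finish_text_input(session_id)             # no more text will arrive
#
#   # Inside a route handler, the buffered audio can be forwarded as a streaming response:
#   #   return StreamingResponse(stream_manager.get_tts_buffer_data(session_id), media_type="audio/mpeg")
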
def allowed_file(filename):
    return '.' in filename and \
           filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}


audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # expires after 10 minutes


# WebSocket connection management
class ConnectionManager:
    def __init__(self):
        self.active_connections = {}

    async def connect(self, websocket: WebSocket, connection_id: str):
        await websocket.accept()
        self.active_connections[connection_id] = websocket
        logging.info(f"New connection established: {connection_id}")

    async def disconnect(self, connection_id: str, code=1000, reason: str = ""):
        if connection_id in self.active_connections:
            try:
                # Try to close the connection cleanly (non-blocking)
                await self.active_connections[connection_id].close(code=code, reason=reason)
            except:
                pass  # ignore close errors
            finally:
                del self.active_connections[connection_id]

    def is_connected(self, connection_id: str) -> bool:
        """Check whether the connection is still active."""
        return connection_id in self.active_connections

    async def _safe_send(self, connection_id: str, send_func, *args):
        """Common safe-send helper (core change)."""
        # 1. Check that the connection exists
        if connection_id not in self.active_connections:
            logging.warning(f"Attempted to send to a non-existent connection: {connection_id}")
            return False

        websocket = self.active_connections[connection_id]

        try:
            # 2. Check the connection state (key change)
            if websocket.client_state.name != "CONNECTED":
                logging.warning(f"Connection {connection_id} is in state {websocket.client_state.name}")
                await self.disconnect(connection_id)
                return False

            # 3. Perform the send
            await send_func(websocket, *args)
            return True

        except (WebSocketDisconnect, RuntimeError) as e:
            # 4. Handle disconnects
            logging.info(f"Disconnect detected while sending: {connection_id}, {str(e)}")
            await self.disconnect(connection_id)
            return False
        except Exception as e:
            # 5. Handle other exceptions
            logging.error(f"Error sending data: {connection_id}, {str(e)}")
            await self.disconnect(connection_id)
            return False

    async def send_bytes(self, connection_id: str, data: bytes):
        """Safely send binary data."""
        return await self._safe_send(
            connection_id,
            lambda ws, d: ws.send_bytes(d),
            data
        )

    async def send_text(self, connection_id: str, message: str):
        """Safely send text data."""
        return await self._safe_send(
            connection_id,
            lambda ws, m: ws.send_text(m),
            message
        )

    async def send_json(self, connection_id: str, data: dict):
        """Safely send JSON data."""
        return await self._safe_send(
            connection_id,
            lambda ws, d: ws.send_json(d),
            data
        )


manager = ConnectionManager()

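# Hypothetical route sketch showing how ConnectionManager is typically used (the endpoint name
# and message handling are illustrative assumptions, not part of the original file):
#
#   @tts_router.websocket("/ws/{client_id}")
#   async def ws_endpoint(websocket: WebSocket, client_id: str):
#       await manager.connect(websocket, client_id)
#       try:
#           while True:
#               msg = await websocket.receive_text()
#               await manager.send_json(client_id, {"type": "echo", "data": msg})
#       except WebSocketDisconnect:
#           await manager.disconnect(client_id)
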
def generate_mp3_header(
        sample_rate: int,
        bitrate_kbps: int,
        channels: int = 1,
        layer: str = "III"  # additional parameter, supports "I"/"II"/"III"
) -> bytes:
    """
    Dynamically build a 4-byte MP3 frame header, supporting Layer I/II/III.

    :param sample_rate: sample rate (8000, 16000, 22050, 44100, 48000)
    :param bitrate_kbps: bitrate in kbps
    :param channels: channel count (1: mono, 2: stereo)
    :param layer: encoding layer ("I", "II", "III")
    :return: the 4-byte frame header
    """
    # ----------------------------------
    # Parameter validation
    # ----------------------------------
    valid_sample_rates = {8000, 16000, 22050, 44100, 48000}
    if sample_rate not in valid_sample_rates:
        raise ValueError(f"Unsupported sample rate, choose one of: {valid_sample_rates}")

    valid_layers = {"I", "II", "III"}
    if layer not in valid_layers:
        raise ValueError(f"Unsupported layer, choose one of: {valid_layers}")

    # ----------------------------------
    # Determine the MPEG version and sample-rate index
    # ----------------------------------
    if sample_rate == 48000:
        mpeg_version = 0b11      # MPEG-1
        sample_rate_index = 0b01
    elif sample_rate == 44100:
        mpeg_version = 0b11      # MPEG-1
        sample_rate_index = 0b00
    elif sample_rate == 22050:
        mpeg_version = 0b10      # MPEG-2
        sample_rate_index = 0b00
    elif sample_rate == 16000:
        mpeg_version = 0b10      # MPEG-2
        sample_rate_index = 0b10
    elif sample_rate == 8000:
        mpeg_version = 0b00      # MPEG-2.5
        sample_rate_index = 0b10
    else:
        raise ValueError("Sample rate does not match a supported MPEG version")

    # ----------------------------------
    # Pick the bitrate table dynamically (key extension)
    # ----------------------------------
    # Layer code mapping (I: 0b11, II: 0b10, III: 0b01)
    layer_code = {
        "I": 0b11,
        "II": 0b10,
        "III": 0b01
    }[layer]

    # Bitrate tables (covering every layer)
    bitrate_tables = {
        # -------------------------------
        # MPEG-1 (0b11)
        # -------------------------------
        # Layer I
        (0b11, 0b11): {
            32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011,
            160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111,
            288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011,
            416: 0b1100, 448: 0b1101
        },
        # Layer II
        (0b11, 0b10): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011,
            320: 0b1100, 384: 0b1101
        },
        # Layer III
        (0b11, 0b01): {
            32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011,
            64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111,
            128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011,
            256: 0b0100, 320: 0b0101
        },

        # -------------------------------
        # MPEG-2 (0b10) / MPEG-2.5 (0b00)
        # -------------------------------
        # Layer I
        (0b10, 0b11): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
            224: 0b1100, 256: 0b1101
        },
        (0b00, 0b11): {
            32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
            80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
            144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
            224: 0b1100, 256: 0b1101
        },

        # Layer II
        (0b10, 0b10): {
            8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
            40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
            80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
            144: 0b1100, 160: 0b1101
        },
        (0b00, 0b10): {
            8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
            40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
            80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
            144: 0b1100, 160: 0b1101
        },

        # Layer III
        (0b10, 0b01): {
            8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
            40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111,
            80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011,
            144: 0b0100, 160: 0b0101
        },
        (0b00, 0b01): {
            8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
            40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111
        }
    }

    # Look up the bitrate table for this version/layer combination
    key = (mpeg_version, layer_code)
    if key not in bitrate_tables:
        raise ValueError(f"Unsupported version/layer combination: MPEG={mpeg_version}, Layer={layer}")
    bitrate_table = bitrate_tables[key]

    if bitrate_kbps not in bitrate_table:
        raise ValueError(f"Unsupported bitrate, choose one of: {list(bitrate_table.keys())}")
    bitrate_index = bitrate_table[bitrate_kbps]

    # ----------------------------------
    # Determine the channel mode
    # ----------------------------------
    if channels == 1:
        channel_mode = 0b11  # mono
    elif channels == 2:
        channel_mode = 0b00  # stereo
    else:
        raise ValueError("Channel count must be 1 or 2")

    # ----------------------------------
    # Assemble the header fields (corrected layer coding)
    # ----------------------------------
    sync = 0x7FF << 21                 # 11-bit sync word (0x7FF = 0b11111111111)
    version = mpeg_version << 19       # 2-bit MPEG version
    layer_bits = layer_code << 17      # layer code (I: 0b11, II: 0b10, III: 0b01)
    protection = 0 << 16               # no CRC
    bitrate_bits = bitrate_index << 12
    sample_rate_bits = sample_rate_index << 10
    padding = 0 << 9                   # no padding
    private = 0 << 8
    mode = channel_mode << 6
    mode_ext = 0 << 4                  # mode extension (not needed for mono)
    copyright = 0 << 3
    original = 0 << 2
    emphasis = 0b00                    # no emphasis

    frame_header = (
        sync |
        version |
        layer_bits |
        protection |
        bitrate_bits |
        sample_rate_bits |
        padding |
        private |
        mode |
        mode_ext |
        copyright |
        original |
        emphasis
    )

    return frame_header.to_bytes(4, byteorder='big')

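# Illustrative check of the generated header (a sketch with assumed parameters, not used elsewhere):
#
#   header = generate_mp3_header(sample_rate=44100, bitrate_kbps=128, channels=1, layer="III")
#   # The leading 11 bits are the sync word, so the first byte is 0xFF and the top
#   # three bits of the second byte are all set.
#   assert header[0] == 0xFF and (header[1] & 0xE0) == 0xE0
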
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    # Assumes the audio data is 16-bit mono PCM
    # Convert the binary data into an array of integers
    samples = array.array('h', audio_data)

    # Apply a fade-in to the first fade_length samples
    for i in range(fade_length):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)

    # Convert the integer array back to binary data
    return samples.tobytes()


def parse_markdown_json(json_string):
    # Use a regular expression to match a JSON code block inside Markdown
    match = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if match:
        try:
            # Try to parse the JSON string
            data = json.loads(match[1])
            return {'success': True, 'data': data}
        except json.JSONDecodeError as e:
            # If parsing fails, return the error message
            return {'success': False, 'data': str(e)}
    else:
        return {'success': False, 'data': 'not a valid markdown json string'}

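# Sketch of how audio_fade_in would typically be applied to raw PCM (hypothetical values; the
# 16-bit mono assumption comes from the function itself):
#
#   pcm = b"\x00\x10" * 4410                       # ~0.1 s of 16-bit mono samples at 44.1 kHz
#   faded = audio_fade_in(pcm, fade_length=441)    # ramp the first 441 samples from 0 to full volume
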
# Full-width to half-width character mapping
def fullwidth_to_halfwidth(s):
    full_to_half_map = {
        '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
        '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
        '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
        '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
        '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
        '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
        '、': '、', '・': '・', ':': ':'
    }
    return ''.join(full_to_half_map.get(char, char) for char in s)


def split_text_at_punctuation(text, chunk_size=100):
    # Use a regular expression to find all punctuation and special characters
    punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
    tokens = re.split(punctuation_pattern, text)

    # Drop empty strings
    tokens = [token for token in tokens if token]

    # Collected text chunks
    chunks = []
    current_chunk = ''

    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            # If adding this token keeps the chunk within chunk_size, append it
            current_chunk += (token + ' ')
        else:
            # Otherwise finish the current chunk and start a new one
            chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    # Append the last chunk (if anything is left)
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks

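# Quick sketch of split_text_at_punctuation (illustrative input; chunk boundaries fall at
# punctuation and each chunk stays at or below chunk_size characters):
#
#   chunks = split_text_at_punctuation("Hello, world! This is a long sentence.", chunk_size=12)
#   # e.g. ["Hello world", "This is a", "long", "sentence"] - the punctuation itself is discarded.
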
def extract_text_from_markdown(markdown_text):
    # Remove Markdown headings
    text = re.sub(r'#\s*[^#]+', '', markdown_text)
    # Remove inline code
    text = re.sub(r'`[^`]+`', '', text)
    # Remove code blocks
    text = re.sub(r'```[\s\S]*?```', '', text)
    # Remove bold and italics
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Remove links
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # Remove images
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Normalize punctuation
    # text = re.sub(r'[^\w\s]', '', text)
    text = fullwidth_to_halfwidth(text)
    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text

def encode_gzip_base64(original_data: bytes) -> str:
    """Core encoding pipeline: binary data -> Gzip compression -> Base64 encoding."""
    # Step 1: Gzip compression
    with BytesIO() as buf:
        with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file:
            gz_file.write(original_data)
        compressed_bytes = buf.getvalue()

    # Step 2: Base64 encoding (configured to match the Android client)
    return base64.b64encode(compressed_bytes).decode('utf-8')  # no line breaks by default (same as Android's Base64.NO_WRAP)

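# Round-trip sketch for the encoding above (decoding is the mirror image; the input bytes are
# illustrative):
#
#   encoded = encode_gzip_base64(b"hello world")
#   decoded = gzip.decompress(base64.b64decode(encoded))
#   assert decoded == b"hello world"
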
def clean_audio_cache():
    """Periodically purge expired cache entries."""
    with cache_lock:
        now = time.time()
        expired_keys = [
            k for k, v in audio_text_cache.items()
            if now - v['created_at'] > CACHE_EXPIRE_SECONDS
        ]
        for k in expired_keys:
            entry = audio_text_cache.pop(k, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()


def start_background_cleaner():
    """Start the background cleanup thread."""

    def cleaner_loop():
        while True:
            time.sleep(180)  # clean every 3 minutes
            clean_audio_cache()

    cleaner_thread = Thread(target=cleaner_loop, daemon=True)
    cleaner_thread.start()

def test_qwen_chat():
    from dashscope import Generation  # imported here; the dashscope API key is configured below
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # If no environment variable is configured, replace the next line with your Bailian API key: api_key = "sk-xxx"
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )

    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP status code: {response.status_code}")
        print(f"Error code: {response.code}")
        print(f"Error message: {response.message}")
        print("See the docs: https://help.aliyun.com/zh/model-studio/developer-reference/error-code")


ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import (
    ResultCallback as TTSResultCallback,
    SpeechSynthesizer as TTSSpeechSynthesizer,
    SpeechSynthesisResult as TTSSpeechSynthesisResult,
)
# cyx 2025-01-19: test CosyVoice using the tts_v2 version
from dashscope.audio.tts_v2 import (
    ResultCallback as CosyResultCallback,
    SpeechSynthesizer as CosySpeechSynthesizer,
    AudioFormat,
)

class QwenTTS:
|
||
def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun",
|
||
special_characters: Optional[Dict[str, str]] = None):
|
||
import dashscope
|
||
import ssl
|
||
logging.info(f"---begin--init QwenTTS-- {format} {sample_rate} {model_name} {model_name.split('@')[0]}") # cyx
|
||
self.model_name = model_name.split('@')[0]
|
||
dashscope.api_key = key
|
||
ssl._create_default_https_context = ssl._create_unverified_context # 禁用验证
|
||
self.synthesizer = None
|
||
self.callback = None
|
||
self.is_cosyvoice = False
|
||
self.voice = ""
|
||
self.format = format
|
||
self.sample_rate = sample_rate
|
||
self.first_chunk = True
|
||
if '/' in self.model_name:
|
||
parts = self.model_name.split('/', 1)
|
||
# 返回分离后的两个字符串parts[0], parts[1]
|
||
if parts[0] == 'cosyvoice-v1' or parts[0] == 'cosyvoice-v2':
|
||
self.is_cosyvoice = True
|
||
self.voice = parts[1]
|
||
self.completion_event = None # 新增:用于通知任务完成
|
||
# 特殊字符及其拼音映射
|
||
self.special_characters = special_characters or {
|
||
"㼽": "chuang3",
|
||
"䡇": "yue4"
|
||
# 可以添加更多特殊字符的映射
|
||
}
|
||
|
||
class Callback(TTSResultCallback):
|
||
def __init__(self,data_callback=None,completion_event=None) -> None:
|
||
self.dque = deque()
|
||
self.data_callback = data_callback
|
||
self.completion_event = completion_event # 新增完成事件引用
|
||
def _run(self):
|
||
while True:
|
||
if not self.dque:
|
||
time.sleep(0)
|
||
continue
|
||
val = self.dque.popleft()
|
||
if val:
|
||
yield val
|
||
else:
|
||
break
|
||
|
||
def on_open(self):
|
||
pass
|
||
|
||
def on_complete(self):
|
||
self.dque.append(None)
|
||
if self.data_callback:
|
||
self.data_callback(None) # 发送结束信号
|
||
# 通知任务完成
|
||
if self.completion_event:
|
||
self.completion_event.set()
|
||
|
||
def on_error(self, response: SpeechSynthesisResponse):
|
||
print("Qwen tts error", str(response))
|
||
raise RuntimeError(str(response))
|
||
|
||
def on_close(self):
|
||
pass
|
||
|
||
def on_event(self, result: TTSSpeechSynthesisResult):
|
||
data = result.get_audio_frame()
|
||
if data is not None:
|
||
if len(data) > 0:
|
||
if self.data_callback:
|
||
self.data_callback(data)
|
||
else:
|
||
self.dque.append(data)
|
||
#self.dque.append(result.get_audio_frame())
|
||
|
||
# --------------------------
|
||
|
||
class Callback_Cosy(CosyResultCallback):
|
||
def __init__(self, data_callback=None,completion_event=None) -> None:
|
||
self.dque = deque()
|
||
self.data_callback = data_callback
|
||
self.completion_event = completion_event # 新增完成事件引用
|
||
|
||
def _run(self):
|
||
while True:
|
||
if not self.dque:
|
||
time.sleep(0)
|
||
continue
|
||
val = self.dque.popleft()
|
||
if val:
|
||
yield val
|
||
else:
|
||
break
|
||
|
||
def on_open(self):
|
||
logging.info("Qwen CosyVoice tts open ")
|
||
pass
|
||
|
||
def on_complete(self):
|
||
self.dque.append(None)
|
||
if self.data_callback:
|
||
self.data_callback(None) # 发送结束信号
|
||
# 通知任务完成
|
||
if self.completion_event:
|
||
self.completion_event.set()
|
||
|
||
def on_error(self, response: SpeechSynthesisResponse):
|
||
print("Qwen tts error", str(response))
|
||
if self.data_callback:
|
||
self.data_callback(f"ERROR:{str(response)}".encode())
|
||
raise RuntimeError(str(response))
|
||
|
||
def on_close(self):
|
||
# print("---Qwen call back close") # cyx
|
||
logging.info("Qwen CosyVoice tts close")
|
||
pass
|
||
|
||
""" canceled for test 语音大模型CosyVoice
|
||
def on_event(self, result: SpeechSynthesisResult):
|
||
if result.get_audio_frame() is not None:
|
||
self.dque.append(result.get_audio_frame())
|
||
"""
|
||
|
||
def on_event(self, message):
|
||
# logging.info(f"recv speech synthsis message {message}")
|
||
pass
|
||
|
||
# 以下适合语音大模型CosyVoice
|
||
def on_data(self, data: bytes) -> None:
|
||
if len(data) > 0:
|
||
if self.data_callback:
|
||
self.data_callback(data)
|
||
else:
|
||
self.dque.append(data)
|
||
|
||
# --------------------------
|
||
|
||
def tts(self, text, on_data = None,completion_event=None):
|
||
# logging.info(f"---QwenTTS tts begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
|
||
# text = self.normalize_text(text)
|
||
print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
|
||
# text = self.normalize_text(text)
|
||
|
||
try:
|
||
# if self.model_name != 'cosyvoice-v1':
|
||
if self.is_cosyvoice is False:
|
||
self.callback = self.Callback(
|
||
data_callback=on_data,
|
||
completion_event=completion_event
|
||
)
|
||
TTSSpeechSynthesizer.call(model=self.model_name,
|
||
text=text,
|
||
callback=self.callback,
|
||
format=self.format) # format="mp3")
|
||
else:
|
||
self.callback = self.Callback_Cosy()
|
||
format = self.get_audio_format(self.format, self.sample_rate)
|
||
self.synthesizer = CosySpeechSynthesizer(
|
||
model='cosyvoice-v2',
|
||
# voice="longyuan", #"longfei",
|
||
voice=self.voice,
|
||
callback=self.callback,
|
||
format=format
|
||
)
|
||
self.synthesizer.call(text)
|
||
except Exception as e:
|
||
print(f"---dale---20 error {e}") # cyx
|
||
# -----------------------------------
|
||
try:
|
||
for data in self.callback._run():
|
||
# logging.info(f"dashcope return data {len(data)}")
|
||
yield data
|
||
# print(f"---Qwen return data {num_tokens_from_string(text)}")
|
||
# yield num_tokens_from_string(text)
|
||
|
||
except Exception as e:
|
||
raise RuntimeError(f"**ERROR**: {e}")
|
||
|
||
def setup_tts(self, on_data,completion_event=None):
|
||
|
||
"""设置 TTS 回调,返回配置好的 synthesizer"""
|
||
#if not self.is_cosyvoice:
|
||
# raise NotImplementedError("Only CosyVoice supported")
|
||
|
||
if self.is_cosyvoice:
|
||
# 创建 CosyVoice 回调
|
||
self.callback = self.Callback_Cosy(
|
||
data_callback=on_data,
|
||
completion_event=completion_event)
|
||
else:
|
||
self.callback = self.Callback(
|
||
data_callback=on_data,
|
||
completion_event=completion_event)
|
||
|
||
|
||
if self.is_cosyvoice:
|
||
format_val = self.get_audio_format(self.format, self.sample_rate)
|
||
# logging.info(f"Qwen setup_tts {self.voice} {format_val}")
|
||
self.synthesizer = CosySpeechSynthesizer(
|
||
model='cosyvoice-v1',
|
||
voice=self.voice, # voice="longyuan", #"longfei",
|
||
callback=self.callback,
|
||
format=format_val
|
||
)
|
||
|
||
return self.synthesizer
|
||
|
||
def apply_phoneme_tags(self, text: str) -> str:
|
||
"""
|
||
在文本中查找特殊字符并用<phoneme>标签包裹它们
|
||
"""
|
||
# 如果文本已经是SSML格式,直接返回
|
||
if text.strip().startswith("<speak>") and text.strip().endswith("</speak>"):
|
||
return text
|
||
|
||
# 为特殊字符添加SSML标签
|
||
for char, pinyin in self.special_characters.items():
|
||
# 使用正则表达式确保只替换整个字符(避免部分匹配)
|
||
pattern = r'([^<]|^)' + re.escape(char) + r'([^>]|$)'
|
||
replacement = r'\1<phoneme alphabet="py" ph="' + pinyin + r'">' + char + r'</phoneme>\2'
|
||
text = re.sub(pattern, replacement, text)
|
||
|
||
# 如果文本中已有SSML标签,直接返回
|
||
if "<speak>" in text:
|
||
return text
|
||
|
||
# 否则包裹在<speak>标签中
|
||
return f"<speak>{text}</speak>"
|
||
|
||
def text_tts_call(self, text):
|
||
if self.special_characters and self.is_cosyvoice is False:
|
||
text = self.apply_phoneme_tags(text)
|
||
#logging.info(f"Applied SSML phoneme tags to text: {text}")
|
||
|
||
if self.synthesizer and self.is_cosyvoice:
|
||
logging.info(f"Qwen text_tts_call {text} {self.is_cosyvoice}")
|
||
format_val = self.get_audio_format(self.format, self.sample_rate)
|
||
self.synthesizer = CosySpeechSynthesizer(
|
||
model='cosyvoice-v1',
|
||
voice=self.voice, # voice="longyuan", #"longfei",
|
||
callback=self.callback,
|
||
format=format_val
|
||
)
|
||
self.synthesizer.call(text)
|
||
if self.is_cosyvoice is False:
|
||
logging.info(f"Qwen text_tts_call {text}")
|
||
TTSSpeechSynthesizer.call(model=self.model_name,
|
||
text=text,
|
||
callback=self.callback,
|
||
format=self.format)
|
||
|
||
def streaming_call(self, text):
|
||
if self.synthesizer:
|
||
self.synthesizer.streaming_call(text)
|
||
|
||
def end_streaming_call(self):
|
||
if self.synthesizer:
|
||
# logging.info(f"---dale end_streaming_call")
|
||
self.synthesizer.streaming_complete()
|
||
|
||
def get_audio_format(self, format: str, sample_rate: int):
|
||
"""动态获取音频格式"""
|
||
from dashscope.audio.tts_v2 import AudioFormat
|
||
logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
|
||
format_map = {
|
||
(8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
|
||
(8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
|
||
(8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
|
||
(16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
|
||
(22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
|
||
(22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
|
||
(22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
|
||
(44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
|
||
(44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
|
||
(44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
|
||
(48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
|
||
(48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
|
||
(48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT
|
||
|
||
}
|
||
return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
|
||
|
||
|
||
class DoubaoTTS:
|
||
def __init__(self, key, format="mp3", sample_rate=8000, model_name="doubao-tts"):
|
||
logging.info(f"---begin--init DoubaoTTS-- {format} {sample_rate} {model_name}")
|
||
# 解析豆包认证信息 (appid, token, cluster, voice_type)
|
||
try:
|
||
self.appid = "7282190702"
|
||
self.token = "v64Fj-fwLLKIHBgqH2_fWx5dsBEShXd9"
|
||
self.cluster = "volcano_tts"
|
||
self.voice_type ="zh_female_qingxinnvsheng_mars_bigtts" # "zh_male_jieshuonansheng_mars_bigtts" #"zh_male_ruyaqingnian_mars_bigtts" #"zh_male_jieshuonansheng_mars_bigtts"
|
||
except Exception as e:
|
||
raise ValueError(f"Invalid Doubao key format: {str(e)}")
|
||
|
||
self.format = format
|
||
self.sample_rate = sample_rate
|
||
self.model_name = model_name
|
||
self.callback = None
|
||
self.ws = None
|
||
self.loop = None
|
||
self.task = None
|
||
self.event = threading.Event()
|
||
self.data_queue = deque()
|
||
self.host = "openspeech.bytedance.com"
|
||
self.api_url = f"wss://{self.host}/api/v1/tts/ws_binary"
|
||
self.default_header = bytearray(b'\x11\x10\x11\x00')
|
||
self.total_data_size = 0
|
||
self.completion_event = None # 新增:用于通知任务完成
|
||
|
||
|
||
class Callback:
|
||
def __init__(self, data_callback=None,completion_event=None):
|
||
self.data_callback = data_callback
|
||
self.data_queue = deque()
|
||
self.completion_event = completion_event # 完成事件引用
|
||
|
||
def on_data(self, data):
|
||
if self.data_callback:
|
||
self.data_callback(data)
|
||
else:
|
||
self.data_queue.append(data)
|
||
# 通知任务完成
|
||
if self.completion_event:
|
||
self.completion_event.set()
|
||
|
||
def on_complete(self):
|
||
if self.data_callback:
|
||
self.data_callback(None)
|
||
|
||
def on_error(self, error):
|
||
if self.data_callback:
|
||
self.data_callback(f"ERROR:{error}".encode())
|
||
|
||
def setup_tts(self, on_data,completion_event):
|
||
|
||
"""设置回调,返回自身(因为豆包需要异步启动)"""
|
||
self.callback = self.Callback(
|
||
data_callback=on_data,
|
||
completion_event=completion_event
|
||
)
|
||
return self
|
||
|
||
def text_tts_call(self, text):
|
||
"""同步调用,启动异步任务并等待完成"""
|
||
self.total_data_size = 0
|
||
self.loop = asyncio.new_event_loop()
|
||
asyncio.set_event_loop(self.loop)
|
||
self.task = self.loop.create_task(self._async_tts(text))
|
||
try:
|
||
self.loop.run_until_complete(self.task)
|
||
except Exception as e:
|
||
logging.error(f"DoubaoTTS--0 call error: {e}")
|
||
self.callback.on_error(str(e))
|
||
|
||
async def _async_tts(self, text):
|
||
"""异步执行TTS请求"""
|
||
header = {"Authorization": f"Bearer; {self.token}"}
|
||
request_json = {
|
||
"app": {
|
||
"appid": self.appid,
|
||
"token": "access_token", # 固定值
|
||
"cluster": self.cluster
|
||
},
|
||
"user": {
|
||
"uid": str(uuid.uuid4()) # 随机用户ID
|
||
},
|
||
"audio": {
|
||
"voice_type": self.voice_type,
|
||
"encoding": self.format,
|
||
"speed_ratio": 1.0,
|
||
"volume_ratio": 1.0,
|
||
"pitch_ratio": 1.0,
|
||
},
|
||
"request": {
|
||
"reqid": str(uuid.uuid4()),
|
||
"text": text,
|
||
"text_type": "plain",
|
||
"operation": "submit" # 使用submit模式支持流式
|
||
}
|
||
}
|
||
|
||
# 构建请求数据
|
||
payload_bytes = str.encode(json.dumps(request_json))
|
||
payload_bytes = gzip.compress(payload_bytes)
|
||
full_client_request = bytearray(self.default_header)
|
||
full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
|
||
full_client_request.extend(payload_bytes)
|
||
|
||
try:
|
||
async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws:
|
||
self.ws = ws
|
||
await ws.send(full_client_request)
|
||
|
||
# 接收音频数据
|
||
while True:
|
||
res = await ws.recv()
|
||
done = self._parse_response(res)
|
||
if done:
|
||
self.callback.on_complete()
|
||
break
|
||
except Exception as e:
|
||
logging.error(f"DoubaoTTS--1 WebSocket error: {e}")
|
||
self.callback.on_error(str(e))
|
||
finally:
|
||
# 通知任务完成
|
||
if self.completion_event:
|
||
self.completion_event.set()
|
||
|
||
|
||
def _parse_response(self, res):
|
||
"""解析豆包返回的二进制响应"""
|
||
# 协议头解析 (4字节)
|
||
header_size = res[0] & 0x0f
|
||
message_type = res[1] >> 4
|
||
payload = res[header_size * 4:]
|
||
|
||
# 音频数据响应
|
||
if message_type == 0xb: # audio-only server response
|
||
message_flags = res[1] & 0x0f
|
||
|
||
# ACK消息,忽略
|
||
if message_flags == 0:
|
||
return False
|
||
|
||
# 音频数据消息
|
||
sequence_number = int.from_bytes(payload[:4], "big", signed=True)
|
||
payload_size = int.from_bytes(payload[4:8], "big", signed=False)
|
||
audio_data = payload[8:8 + payload_size]
|
||
|
||
if audio_data:
|
||
self.total_data_size = self.total_data_size + len(audio_data)
|
||
self.callback.on_data(audio_data)
|
||
|
||
#logging.info(f"doubao _parse_response: {sequence_number} {len(audio_data)} {self.total_data_size}")
|
||
# 序列号为负表示结束
|
||
return sequence_number < 0
|
||
|
||
# 错误响应
|
||
elif message_type == 0xf:
|
||
code = int.from_bytes(payload[:4], "big", signed=False)
|
||
msg_size = int.from_bytes(payload[4:8], "big", signed=False)
|
||
error_msg = payload[8:8 + msg_size]
|
||
|
||
try:
|
||
# 尝试解压错误消息
|
||
error_msg = gzip.decompress(error_msg).decode()
|
||
except:
|
||
error_msg = error_msg.decode(errors='ignore')
|
||
|
||
logging.error(f"DoubaoTTS error: {error_msg}")
|
||
self.callback.on_error(error_msg)
|
||
return False
|
||
|
||
return False
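    # Summary of the binary layout assumed by _parse_response above (derived from the parsing
    # code itself, not from separate protocol documentation):
    #   byte 0 low nibble  -> header size in 4-byte words
    #   byte 1 high nibble -> message type (0xb: audio-only response, 0xf: error response)
    #   byte 1 low nibble  -> message flags (0 means an ACK carrying no audio payload)
    #   payload            -> 4-byte sequence number, 4-byte payload size, then audio bytes;
    #                         a negative sequence number marks the final audio chunk.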
|
||
|
||
|
||
class UnifiedTTSEngine:
|
||
def __init__(self):
|
||
self.lock = threading.Lock()
|
||
self.tasks = {}
|
||
self.executor = ThreadPoolExecutor(max_workers=10)
|
||
self.cache_expire = 300 # 5分钟缓存
|
||
# 启动清理过期任务的定时器
|
||
self.cleanup_timer = None
|
||
self.start_cleanup_timer()
|
||
|
||
def _cleanup_old_tasks(self):
|
||
"""清理过期任务"""
|
||
now = time.time()
|
||
with self.lock:
|
||
expired_ids = [task_id for task_id, task in self.tasks.items()
|
||
if now - task['created_at'] > self.cache_expire]
|
||
for task_id in expired_ids:
|
||
self._remove_task(task_id)
|
||
|
||
def _remove_task(self, task_id):
|
||
"""移除任务"""
|
||
if task_id in self.tasks:
|
||
task = self.tasks.pop(task_id)
|
||
# 取消可能的后台任务
|
||
if 'future' in task and not task['future'].done():
|
||
task['future'].cancel()
|
||
# 其他资源在任务被移除后会被垃圾回收
|
||
# 资源释放机制总结:
|
||
# 移除任务引用:self.tasks.pop() 解除任务对象引用,触发垃圾回收。
|
||
# 取消后台线程:future.cancel() 终止未完成线程,释放线程资源。
|
||
# 自动内存回收:Python GC 回收任务对象及其队列、缓冲区占用的内存。
|
||
# 线程池管理:执行器自动回收线程至池中,避免资源泄漏。
|
||
|
||
def create_tts_task(self, text, format, sample_rate, model_name, key, delay_gen_audio=False):
|
||
"""创建TTS任务(同步方法)"""
|
||
self._cleanup_old_tasks()
|
||
audio_stream_id = str(uuid.uuid4())
|
||
|
||
# 创建任务数据结构
|
||
task_data = {
|
||
'id': audio_stream_id,
|
||
'text': text,
|
||
'format': format,
|
||
'sample_rate': sample_rate,
|
||
'model_name': model_name,
|
||
'key': key,
|
||
'delay_gen_audio': delay_gen_audio,
|
||
'created_at': time.time(),
|
||
'status': 'pending',
|
||
'data_queue': deque(),
|
||
'event': threading.Event(),
|
||
'completed': False,
|
||
'error': None
|
||
}
|
||
with self.lock:
|
||
self.tasks[audio_stream_id] = task_data
|
||
|
||
# 如果不是延迟模式,立即启动任务
|
||
if not delay_gen_audio:
|
||
self._start_tts_task(audio_stream_id)
|
||
|
||
return audio_stream_id
|
||
|
||
def _start_tts_task(self, audio_stream_id):
|
||
# 启动TTS任务(后台线程)
|
||
|
||
task = self.tasks.get(audio_stream_id)
|
||
if not task or task['status'] != 'pending':
|
||
return
|
||
logging.info("已经启动 start tts task {audio_stream_id}")
|
||
task['status'] = 'processing'
|
||
|
||
# 在后台线程中执行TTS
|
||
future = self.executor.submit(self._run_tts_sync, audio_stream_id)
|
||
task['future'] = future
|
||
|
||
# 如果需要等待任务完成
|
||
if not task.get('delay_gen_audio', True):
|
||
try:
|
||
# 等待任务完成(最多5分钟)
|
||
future.result(timeout=300)
|
||
logging.info(f"TTS任务 {audio_stream_id} 已完成")
|
||
self._merge_audio_data(audio_stream_id)
|
||
except concurrent.futures.TimeoutError:
|
||
task['error'] = "TTS生成超时"
|
||
task['completed'] = True
|
||
logging.error(f"TTS任务 {audio_stream_id} 超时")
|
||
except Exception as e:
|
||
task['error'] = f"ERROR:{str(e)}"
|
||
task['completed'] = True
|
||
logging.exception(f"TTS任务执行异常: {str(e)}")
|
||
|
||
def _run_tts_sync(self, audio_stream_id):
|
||
# 同步执行TTS生成 在后台线程中执行
|
||
task = self.tasks.get(audio_stream_id)
|
||
if not task:
|
||
return
|
||
|
||
try:
|
||
# 创建完成事件
|
||
completion_event = threading.Event()
|
||
# 创建TTS实例
|
||
# 根据model_name选择TTS引擎
|
||
# 前端传入 cosyvoice-v1/longhua@Tongyi-Qianwen
|
||
model_name_wo_brand = task['model_name'].split('@')[0]
|
||
model_name_version = model_name_wo_brand.split('/')[0]
|
||
if "longhua" in task['model_name'] or "zh_female_qingxinnvsheng_mars_bigtts" in task['model_name']:
|
||
# 豆包TTS
|
||
tts = DoubaoTTS(
|
||
key=task['key'],
|
||
format=task['format'],
|
||
sample_rate=task['sample_rate'],
|
||
model_name=task['model_name']
|
||
)
|
||
else:
|
||
# 通义千问TTS
|
||
tts = QwenTTS(
|
||
key=task['key'],
|
||
format=task['format'],
|
||
sample_rate=task['sample_rate'],
|
||
model_name=task['model_name']
|
||
)
|
||
|
||
# 定义同步数据处理函数
|
||
def data_handler(data):
|
||
if data is None: # 结束信号
|
||
task['completed'] = True
|
||
task['event'].set()
|
||
logging.info(f"--data_handler on_complete")
|
||
elif data.startswith(b"ERROR"): # 错误信号
|
||
task['error'] = data.decode()
|
||
task['completed'] = True
|
||
task['event'].set()
|
||
else: # 音频数据
|
||
task['data_queue'].append(data)
|
||
|
||
|
||
# 设置并执行TTS
|
||
synthesizer = tts.setup_tts(data_handler,completion_event)
|
||
#synthesizer.call(task['text'])
|
||
tts.text_tts_call(task['text'])
|
||
# Wait for completion or timeout
|
||
if not completion_event.wait(timeout=300): # 5分钟超时
|
||
task['error'] = "TTS generation timeout"
|
||
task['completed'] = True
|
||
|
||
logging.info(f"--tts task event set error = {task['error']}")
|
||
|
||
except Exception as e:
|
||
logging.info(f"UnifiedTTSEngine _run_tts_sync ERROR: {str(e)}")
|
||
task['error'] = f"ERROR:{str(e)}"
|
||
task['completed'] = True
|
||
finally:
|
||
# 确保清理TTS资源
|
||
logging.info("UnifiedTTSEngine _run_tts_sync finally")
|
||
if hasattr(tts, 'loop') and tts.loop:
|
||
tts.loop.close()
|
||
|
||
def _merge_audio_data(self, audio_stream_id):
|
||
"""将任务的所有音频数据合并到ByteIO缓冲区"""
|
||
task = self.tasks.get(audio_stream_id)
|
||
if not task or not task.get('completed'):
|
||
return
|
||
|
||
try:
|
||
logging.info(f"开始合并音频数据: {audio_stream_id}")
|
||
|
||
# 创建内存缓冲区
|
||
buffer = io.BytesIO()
|
||
|
||
# 合并所有数据块
|
||
for data_chunk in task['data_queue']:
|
||
buffer.write(data_chunk)
|
||
|
||
# 重置指针位置以便读取
|
||
buffer.seek(0)
|
||
|
||
# 保存到任务对象
|
||
task['buffer'] = buffer
|
||
logging.info(f"音频数据合并完成,总大小: {buffer.getbuffer().nbytes} 字节")
|
||
|
||
# 可选:清理原始数据队列以节省内存
|
||
task['data_queue'].clear()
|
||
|
||
except Exception as e:
|
||
logging.error(f"合并音频数据失败: {str(e)}")
|
||
task['error'] = f"合并错误: {str(e)}"
|
||
|
||
async def get_audio_stream(self, audio_stream_id):
|
||
"""获取音频流(异步生成器)"""
|
||
task = self.tasks.get(audio_stream_id)
|
||
if not task:
|
||
raise RuntimeError("Audio stream not found")
|
||
|
||
# 如果是延迟任务且未启动,现在启动 status 为 pending
|
||
if task['delay_gen_audio'] and task['status'] == 'pending':
|
||
self._start_tts_task(audio_stream_id)
|
||
total_audio_data_size = 0
|
||
# 等待任务启动
|
||
while task['status'] == 'pending':
|
||
await asyncio.sleep(0.1)
|
||
|
||
# 流式返回数据
|
||
while not task['completed'] or task['data_queue']:
|
||
while task['data_queue']:
|
||
data = task['data_queue'].popleft()
|
||
total_audio_data_size += len(data)
|
||
#logging.info(f"yield audio data {len(data)} {total_audio_data_size}")
|
||
yield data
|
||
|
||
# 短暂等待新数据
|
||
await asyncio.sleep(0.05)
|
||
|
||
# 检查错误
|
||
if task['error']:
|
||
raise RuntimeError(task['error'])
|
||
|
||
def start_cleanup_timer(self):
|
||
"""启动定时清理任务"""
|
||
if self.cleanup_timer:
|
||
self.cleanup_timer.cancel()
|
||
|
||
self.cleanup_timer = threading.Timer(30.0, self.cleanup_task) # 每30秒清理一次
|
||
self.cleanup_timer.daemon = True # 设置为守护线程
|
||
self.cleanup_timer.start()
|
||
|
||
def cleanup_task(self):
|
||
"""执行清理任务"""
|
||
try:
|
||
self._cleanup_old_tasks()
|
||
except Exception as e:
|
||
logging.error(f"清理任务时出错: {str(e)}")
|
||
finally:
|
||
self.start_cleanup_timer() # 重新启动定时器
|
||
|
||
|
||
# 全局 TTS 引擎实例
|
||
tts_engine = UnifiedTTSEngine()
|
||
|
||
|
||
def replace_domain(url: str) -> str:
|
||
"""替换URL中的域名为本地地址,不使用urllib.parse"""
|
||
# 定义需要替换的域名列表
|
||
domains_to_replace = [
|
||
"http://1.13.185.116:9380",
|
||
"https://ragflow.szzysztech.com",
|
||
"1.13.185.116:9380",
|
||
"ragflow.szzysztech.com"
|
||
]
|
||
|
||
# 尝试替换每个可能的域名
|
||
for domain in domains_to_replace:
|
||
if domain in url:
|
||
# 直接替换域名部分
|
||
return url.replace(domain, "http://localhost:9380", 1)
|
||
|
||
# 如果未匹配到特定域名,尝试智能替换
|
||
if "://" in url:
|
||
# 分割协议和路径
|
||
protocol, path = url.split("://", 1)
|
||
|
||
# 查找第一个斜杠位置来确定域名结束位置
|
||
slash_pos = path.find("/")
|
||
if slash_pos > 0:
|
||
# 替换域名部分
|
||
return f"http://localhost:9380{path[slash_pos:]}"
|
||
else:
|
||
# 没有路径部分,直接返回本地地址
|
||
return "http://localhost:9380"
|
||
else:
|
||
# 没有协议部分,直接添加本地地址
|
||
return f"http://localhost:9380/{url}"
|
||
|
||
|
||
async def proxy_aichat_audio_stream(client_id: str, audio_url: str):
|
||
"""代理外部音频流请求"""
|
||
try:
|
||
# 替换域名为本地地址
|
||
local_url = audio_url
|
||
logging.info(f"代理音频流: {audio_url} -> {local_url}")
|
||
|
||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||
async with client.stream("GET", local_url) as response:
|
||
# 流式转发音频数据
|
||
async for chunk in response.aiter_bytes():
|
||
if not await manager.send_bytes(client_id, chunk):
|
||
logging.warning(f"Audio proxy interrupted for {client_id}")
|
||
return
|
||
except Exception as e:
|
||
logging.error(f"Audio proxy failed: {str(e)}")
|
||
await manager.send_text(client_id, json.dumps({
|
||
"type": "error",
|
||
"message": f"音频流获取失败: {str(e)}"
|
||
}))
|
||
|
||
|
||
# 代理函数 - 文本流
|
||
# 在微信小程序中,原来APK使用的SSE机制不能正常工作,需要使用WebSocket
|
||
async def proxy_aichat_text_stream(client_id: str, completions_url: str, payload: dict):
|
||
"""代理大模型文本流请求 - 兼容现有Flask实现"""
|
||
try:
|
||
logging.info(f"代理文本流: completions_url={completions_url} {payload}")
|
||
logging.debug(f"请求负载: {json.dumps(payload, ensure_ascii=False)}")
|
||
|
||
headers = {
|
||
"Content-Type": "application/json",
|
||
'Authorization': 'Bearer ragflow-NhZTY5Y2M4YWQ1MzExZWY4Zjc3MDI0Mm'
|
||
}
|
||
tts_model_name = payload.get('tts_model', 'cosyvoice-v1/longyuan@Tongyi-Qianwen')
|
||
#if 'longyuan' in tts_model_name:
|
||
# tts_model_name = "cosyvoice-v2/longyuan_v2@Tongyi-Qianwen"
|
||
# 创建TTS实例
|
||
tts_model = QwenTTS(
|
||
key=ALI_KEY,
|
||
format=payload.get('tts_stream_format', 'mp3'),
|
||
sample_rate=payload.get('tts_sample_rate', 48000),
|
||
model_name=tts_model_name
|
||
)
|
||
streaming_call = False
|
||
if tts_model.is_cosyvoice:
|
||
streaming_call = True
|
||
|
||
# 创建流会话
|
||
tts_stream_session_id = stream_manager.create_session(
|
||
tts_model=tts_model,
|
||
sample_rate=payload.get('tts_sample_rate', 48000),
|
||
stream_format=payload.get('tts_stream_format', 'mp3'),
|
||
session_id=None,
|
||
streaming_call= streaming_call
|
||
)
|
||
# logging.info(f"---tts_stream_session_id = {tts_stream_session_id}")
|
||
tts_stream_session_id_sent = False
|
||
send_sentence_tts_url = False
|
||
# 添加一个事件来标记所有句子已发送
|
||
all_sentences_sent = asyncio.Event()
|
||
|
||
# 任务:监听并发送新生成的句子
|
||
async def send_new_sentences():
|
||
"""监听并发送新生成的句子"""
|
||
try:
|
||
while True:
|
||
# 获取下一个句子
|
||
sentence_info = stream_manager.get_next_sentence(tts_stream_session_id)
|
||
|
||
if sentence_info:
|
||
logging.info(f"--proxy_aichat_text_stream 发送sentence_info\r\n")
|
||
# 发送句子信息
|
||
await manager.send_json(client_id, {
|
||
"type": "tts_sentence",
|
||
"id": sentence_info['id'],
|
||
"text": stream_manager.get_sentence_info(sentence_info['id'])['text'],
|
||
"url": sentence_info['url']
|
||
})
|
||
else:
|
||
# 检查会话是否结束且没有更多句子
|
||
session = stream_manager.get_session(tts_stream_session_id)
|
||
if not session or (not session['active']):
|
||
#and session['current_sentence_index'] >= len(session['sentences'])):
|
||
# 标记所有句子已发送
|
||
all_sentences_sent.set()
|
||
break
|
||
|
||
# 等待新句子生成
|
||
await asyncio.sleep(0.1)
|
||
except asyncio.CancelledError:
|
||
logging.info("句子监听任务被取消")
|
||
except Exception as e:
|
||
logging.error(f"句子监听任务出错: {str(e)}")
|
||
all_sentences_sent.set()
|
||
|
||
|
||
        if send_sentence_tts_url:
            # Start the sentence watcher task
            sentence_task = asyncio.create_task(send_new_sentences())
        # Use a longer timeout (5 minutes)
        timeout = httpx.Timeout(300.0, connect=60.0)
        async with httpx.AsyncClient(timeout=timeout) as client:
            # Key change: use streaming request mode
            async with client.stream(  # <-- use the stream method
                "POST",
                completions_url,
                json=payload,
                headers=headers
            ) as response:
                logging.info(f"Response status: HTTP {response.status_code}")

                if response.status_code != 200:
                    # Read the error message (non-streaming)
                    error_content = await response.aread()
                    error_msg = f"Backend error: HTTP {response.status_code}"
                    error_msg += f" - {error_content[:200].decode()}" if error_content else ""
                    await manager.send_text(client_id, json.dumps({"type": "error", "message": error_msg}))
                    return

                # Verify that this is an SSE stream
                content_type = response.headers.get("content-type", "").lower()
                if "text/event-stream" not in content_type:
                    logging.warning("Non-streaming response, forwarding the full content")
                    full_content = await response.aread()
                    await manager.send_text(client_id, json.dumps({
                        "type": "text",
                        "data": full_content.decode('utf-8')
                    }))
                    return

                logging.info("Start processing the SSE stream")
                event_count = 0
                # Process the stream line by line with the async iterator
                async for line in response.aiter_lines():
                    # Skip empty lines and comment lines
                    if not line or line.startswith(':'):
                        continue

                    # Handle SSE events
                    if line.startswith("data:"):
                        data_str = line[5:].strip()
                        if data_str:  # Filter out empty data
                            try:
                                # Parse and extract the incremental text
                                data_obj = json.loads(data_str)
                                delta_text = None
                                if isinstance(data_obj, dict) and isinstance(data_obj.get('data', None), dict):
                                    delta_text = data_obj.get('data', None).get('delta_ans', "")
                                    if tts_stream_session_id_sent is False:
                                        logging.info(f"--proxy_aichat_text_stream sending audio_stream_url")
                                        data_obj.get('data')['audio_stream_url'] = f"/tts_stream/{tts_stream_session_id}"
                                        data_str = json.dumps(data_obj)
                                        tts_stream_session_id_sent = True
                                # Forward the original data directly
                                await manager.send_text(client_id, json.dumps({
                                    "type": "text",
                                    "data": data_str
                                }))
                                # The {"type": "text", "data": data_str} wrapper is what the frontend WebSocket parses
                                if delta_text:
                                    # Append the delta to the session manager
                                    stream_manager.append_text(tts_stream_session_id, delta_text)
                                # logging.info(f"Text proxy forwarded: {data_str}")
                                event_count += 1
                            except Exception as e:
                                logging.error(f"Failed to send event: {str(e)}")

                    # Keep the connection alive
                    await asyncio.sleep(0.001)  # Avoid spinning the CPU

                logging.info(f"SSE stream finished, events: {event_count}")

        # Send the end-of-text-stream signal
        await manager.send_text(client_id, json.dumps({"type": "end"}))
        # Mark the text input as finished
        if stream_manager.finish_text_input:
            stream_manager.finish_text_input(tts_stream_session_id)

        if send_sentence_tts_url:
            # Wait for all sentences to be generated and sent (at most 300 seconds)
            try:
                await asyncio.wait_for(all_sentences_sent.wait(), timeout=300.0)
                logging.info(f"All TTS sentences have been sent")
            except asyncio.TimeoutError:
                logging.warning("Timed out waiting for TTS sentences to be sent")

            # Cancel the sentence watcher task (if it is still running)
            if not sentence_task.done():
                sentence_task.cancel()
                try:
                    await sentence_task
                except asyncio.CancelledError:
                    pass

    except httpx.ReadTimeout:
        logging.error("Timed out reading from the backend service")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": "Backend service response timed out"
        }))
    except httpx.ConnectError as e:
        logging.error(f"Failed to connect to the backend service: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"Unable to connect to the backend service: {str(e)}"
        }))
    except Exception as e:
        logging.exception(f"Text proxy failed: {str(e)}")
        await manager.send_text(client_id, json.dumps({
            "type": "error",
            "message": f"Failed to fetch text stream: {str(e)}"
        }))


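# Illustrative shape of the SSE traffic that proxy_aichat_text_stream handles (a sketch only;
# field names other than delta_ans and audio_stream_url are assumptions, not taken from this file):
#   backend event line:   data: {"data": {"delta_ans": "partial answer text", ...}}
#   the first forwarded event gets "audio_stream_url": "/tts_stream/<tts_stream_session_id>"
#   injected into its "data" object; every event is then wrapped for the WebSocket client as
#   {"type": "text", "data": "<original or enriched JSON string>"}, followed by {"type": "end"}.
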
@tts_router.get("/audio/pcm_mp3")
|
||
async def stream_mp3():
|
||
def audio_generator():
|
||
path = './test.mp3'
|
||
try:
|
||
with open(path, 'rb') as f:
|
||
while True:
|
||
chunk = f.read(1024)
|
||
if not chunk:
|
||
break
|
||
yield chunk
|
||
except Exception as e:
|
||
logging.error(f"MP3 streaming error: {str(e)}")
|
||
|
||
return StreamingResponse(
|
||
audio_generator(),
|
||
media_type="audio/mpeg",
|
||
headers={
|
||
"Cache-Control": "no-store",
|
||
"Accept-Ranges": "bytes"
|
||
}
|
||
)
|
||
|
||
|
||
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Dynamically generate a WAV header (strictly preserving the original logic)."""
    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # keep the original mono setting
            wav_file.setsampwidth(2)  # keep the original 16-bit setting
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer.read()


def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Generate silence data (used as a pre-buffer for MP3 streaming)."""
    num_samples = int(TTS_SAMPLERATE * duration_ms / 1000)
    return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS


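# Illustrative use of the two helpers above (a sketch; the 100 ms duration is just an example value):
#   pcm = generate_silence_header(duration_ms=100)               # raw 16-bit mono PCM silence
#   wav_bytes = add_wav_header(pcm, sample_rate=TTS_SAMPLERATE)  # same PCM wrapped in a WAV container
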
# ------------------------ API routes ------------------------
@tts_router.get("/tts_sentence/{sentence_id}")
async def get_sentence_audio(sentence_id: str):
    # Fetch the audio data
    audio_data = stream_manager.get_sentence_audio(sentence_id)
    if not audio_data:
        raise HTTPException(status_code=404, detail="Audio not found")

    # Fetch the audio format
    sentence_info = stream_manager.get_sentence_info(sentence_id)
    if not sentence_info:
        raise HTTPException(status_code=404, detail="Sentence info not found")

    # Determine the MIME type
    format = sentence_info['format']
    media_type = "audio/mpeg" if format == "mp3" else "audio/wav"
    logging.info(f"--http get sentence tts audio stream {sentence_id}")
    # Return a streaming response
    return StreamingResponse(
        io.BytesIO(audio_data),
        media_type=media_type,
        headers={
            "Content-Disposition": f"attachment; filename=audio.{format}",
            "Cache-Control": "max-age=3600"  # cache for 1 hour
        }
    )

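# Illustrative request against the route above (the prefix under which tts_router is mounted
# is not shown in this file, so only the relative path is given):
#   GET /tts_sentence/<sentence_id>
#   -> 200 with an audio/mpeg (mp3) or audio/wav body and Cache-Control: max-age=3600
#   -> 404 if the sentence audio or its metadata is not found
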
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    try:
        data = await request.json()
        logging.info(f"Creating TTS request: {data}")

        # Parameter validation
        text = data.get("text", "").strip()
        if not text:
            raise HTTPException(400, detail="Text cannot be empty")

        format = data.get("tts_stream_format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="Unsupported audio format")

        sample_rate = data.get("tts_sample_rate", 48000)
        if sample_rate not in [8000, 16000, 22050, 44100, 48000]:
            raise HTTPException(400, detail="Unsupported sample rate")

        model_name = data.get("model_name", "cosyvoice-v1/longxiaochun")
        delay_gen_audio = data.get('delay_gen_audio', False)

        # Create the TTS task
        audio_stream_id = tts_engine.create_tts_task(
            text=text,
            format=format,
            sample_rate=sample_rate,
            model_name=model_name,
            key=ALI_KEY,
            delay_gen_audio=delay_gen_audio
        )

        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "ws_url": f"/chats/{chat_id}/tts/{audio_stream_id}",  # WebSocket URL, added 2025-06-22
                "expires_at": (datetime.datetime.now() + datetime.timedelta(seconds=300)).isoformat()
            }
        )

    except HTTPException:
        # Let the validation errors raised above propagate unchanged instead of turning them into 500s
        raise
    except Exception as e:
        logging.error(f"Request failed: {str(e)}")
        raise HTTPException(500, detail="Internal server error")


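# Illustrative exchange for the POST route above (all field values are examples only):
#   POST /chats/<chat_id>/tts
#   {"text": "hello", "tts_stream_format": "mp3", "tts_sample_rate": 48000,
#    "model_name": "cosyvoice-v1/longxiaochun", "delay_gen_audio": false}
#   -> 200 {"tts_url": "/chats/<chat_id>/tts/<audio_stream_id>", "url": "...",
#           "ws_url": "...", "expires_at": "<ISO timestamp about 300 s in the future>"}
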
executor = ThreadPoolExecutor()


@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
|
||
async def get_tts_audio(
|
||
chat_id: str,
|
||
audio_stream_id: str,
|
||
range: str = Header(None)
|
||
):
|
||
try:
|
||
# 获取任务信息
|
||
task = tts_engine.tasks.get(audio_stream_id)
|
||
if not task:
|
||
# 返回友好的错误信息而不是抛出异常
|
||
return JSONResponse(
|
||
status_code=404,
|
||
content={
|
||
"error": "Audio stream not found",
|
||
"message": f"The requested audio stream ID '{audio_stream_id}' does not exist or has expired",
|
||
"suggestion": "Please create a new TTS request and try again"
|
||
}
|
||
)
|
||
|
||
# 获取媒体类型
|
||
format = task['format']
|
||
media_type = {
|
||
"mp3": "audio/mpeg",
|
||
"wav": "audio/wav",
|
||
"pcm": f"audio/L16; rate={task['sample_rate']}; channels=1"
|
||
}[format]
|
||
|
||
# 如果任务已完成且有完整缓冲区,处理Range请求
|
||
logging.info(f"get_tts_audio task = {task.get('completed', 'None')} {task.get('buffer', 'None')}")
|
||
|
||
# 创建响应内容生成器
|
||
def buffer_read(buffer):
|
||
content_length = buffer.getbuffer().nbytes
|
||
remaining = content_length
|
||
chunk_size = 4096
|
||
buffer.seek(0)
|
||
while remaining > 0:
|
||
read_size = min(remaining, chunk_size)
|
||
data = buffer.read(read_size)
|
||
if not data:
|
||
break
|
||
yield data
|
||
remaining -= len(data)
|
||
|
||
if task.get('completed') and task.get('buffer') is not None:
|
||
buffer = task['buffer']
|
||
total_size = buffer.getbuffer().nbytes
|
||
# 强制小文件使用流式传输(避免206响应问题)
|
||
|
||
if total_size < 1024 * 120: # 小于300KB
|
||
range = None
|
||
|
||
if range:
|
||
# 处理范围请求
|
||
return handle_range_request(range, buffer, total_size, media_type)
|
||
else:
|
||
return StreamingResponse(
|
||
buffer_read(buffer),
|
||
media_type=media_type,
|
||
headers={
|
||
"Accept-Ranges": "bytes",
|
||
"Cache-Control": "no-store",
|
||
"Transfer-Encoding": "chunked"
|
||
})
|
||
|
||
# 创建流式响应
|
||
logging.info("tts_engine.get_audio_stream--0")
|
||
return StreamingResponse(
|
||
tts_engine.get_audio_stream(audio_stream_id),
|
||
media_type=media_type,
|
||
headers={
|
||
"Accept-Ranges": "bytes",
|
||
"Cache-Control": "no-store",
|
||
"Transfer-Encoding": "chunked"
|
||
}
|
||
)
|
||
|
||
except Exception as e:
|
||
logging.error(f"Audio streaming failed: {str(e)}")
|
||
raise HTTPException(500, detail="Audio generation error")
|
||
|
||
|
||
def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str):
    """Handle an HTTP Range request."""
    try:
        # Parse the Range header (example: "bytes=0-1023")
        range_type, range_spec = range_header.split('=')
        if range_type != 'bytes':
            raise ValueError("Unsupported range type")

        start_str, end_str = range_spec.split('-')
        start = int(start_str)
        end = int(end_str) if end_str else total_size - 1
        logging.info(f"handle_range_request--1 {start_str}-{end_str} {end}")
        # Validate the range
        if start >= total_size or end >= total_size:
            raise HTTPException(status_code=416, headers={
                "Content-Range": f"bytes */{total_size}"
            })

        # Compute the content length
        content_length = end - start + 1

        # Set the status code
        status_code = 206  # Partial Content
        if start == 0 and end == total_size - 1:
            status_code = 200  # Full Content

        # Position the stream for reading
        buffer.seek(start)

        # Response content generator
        def content_generator():
            remaining = content_length
            chunk_size = 4096
            while remaining > 0:
                read_size = min(remaining, chunk_size)
                data = buffer.read(read_size)
                if not data:
                    break
                yield data
                remaining -= len(data)

        # Return the chunked response
        return StreamingResponse(
            content_generator(),
            status_code=status_code,
            media_type=media_type,
            headers={
                "Content-Range": f"bytes {start}-{end}/{total_size}",
                "Content-Length": str(content_length),
                "Accept-Ranges": "bytes",
                "Cache-Control": "public, max-age=3600"
            }
        )

    except HTTPException:
        # Keep the 416 response raised above instead of converting it into a 500
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


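# Illustrative Range exchange handled by the function above (byte counts are example values):
#   GET /chats/<chat_id>/tts/<audio_stream_id>   with header   Range: bytes=0-1023
#   -> 206 Partial Content, Content-Range: bytes 0-1023/<total_size>, Content-Length: 1024
#   A range that covers the whole buffer (bytes=0-<total_size-1>) is answered with 200.
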
@tts_router.websocket("/chats/{chat_id}/tts/{audio_stream_id}")
async def websocket_tts_endpoint(
    websocket: WebSocket,
    chat_id: str,
    audio_stream_id: str
):
    # Read the header parameters
    headers = websocket.headers
    service_type = headers.get("x-tts-type")  # note: header names are lower-cased
    # audio_url = headers.get("x-audio-url")
    """
    Frontend example:
    websocketConnection = uni.connectSocket({
        url: url,
        header: {
            'Authorization': token,
            'X-Tts-Type': 'AiChat', //'Ask' // custom parameter 1
            'X-Device-Type': 'mobile', // custom parameter 2
            'X-User-ID': '12345' // custom parameter 3
        },
        success: () => {
            console.log('WebSocket connected');
        },
        fail: (err) => {
            console.error('WebSocket connection failed:', err);
        }
    });
    """
    # Create a unique connection ID
    connection_id = str(uuid.uuid4())
    # logging.info(f"---dale-- websocket connection_id = {connection_id} chat_id={chat_id}")
    await manager.connect(websocket, connection_id)

    completed_successfully = False
    try:
        # Route to different audio sources based on the tts_type header
        if service_type == "AiChatTts":
            # Audio proxy service
            audio_url = f"http://localhost:9380/api/v1/tts_stream/{audio_stream_id}"
            # await proxy_aichat_audio_stream(connection_id, audio_url)
            sample_rate = stream_manager.get_session(audio_stream_id).get('sample_rate')
            audio_data_size = 0
            await manager.send_json(connection_id, {"command": "sample_rate", "params": sample_rate})
            async for data in stream_manager.get_tts_buffer_data(audio_stream_id):
                if data.get('type') == 'sentence_end':
                    await manager.send_json(connection_id, {"command": "sentence_end"})

                if data.get('type') == 'arraybuffer':
                    audio_data_size += len(data.get('data'))
                    if not await manager.send_bytes(connection_id, data.get('data')):
                        break
            completed_successfully = True
            logging.info(f"--- proxy AiChatTts audio_data_size={audio_data_size}")
        elif service_type == "AiChatText":
            # Text proxy service
            # Wait for the client to send the initial request data; when proxying an LLM chat,
            # the frontend must send the payload right after connecting
            payload = await websocket.receive_json()
            completions_url = f"http://localhost:9380/api/v1/chats/{chat_id}/completions"
            await proxy_aichat_text_stream(connection_id, completions_url, payload)
            completed_successfully = True
        else:
            # Use the engine's generator to fetch the audio stream directly
            async for data in tts_engine.get_audio_stream(audio_stream_id):
                if not await manager.send_bytes(connection_id, data):
                    logging.warning(f"Send failed, connection closed: {connection_id}")
                    break
            await manager.send_json(connection_id, {"command": "sentence_end"})
            completed_successfully = True

        # Check the connection state before sending the completion signal
        if manager.is_connected(connection_id):
            # Send the completion signal
            await manager.send_json(connection_id, {"status": "completed"})

            # Short delay to make sure the message is delivered
            await asyncio.sleep(0.1)

            # Actively close the WebSocket connection
            await manager.disconnect(connection_id, code=1000, reason="Audio stream completed")
    except WebSocketDisconnect:
        logging.info(f"WebSocket disconnected: {connection_id}")
    except Exception as e:
        logging.error(f"WebSocket TTS error: {str(e)}")
        if manager.is_connected(connection_id):
            await manager.send_json(connection_id, {"error": str(e)})
    finally:
        pass
        # await manager.disconnect(connection_id)


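# Illustrative message sequence seen by a WebSocket client of the endpoint above when it
# connects with the header X-Tts-Type: AiChatTts (48000 is just an example sample rate):
#   <- {"command": "sample_rate", "params": 48000}
#   <- <binary audio chunk> ... <binary audio chunk>
#   <- {"command": "sentence_end"}    (emitted whenever the buffer marks a sentence boundary)
#   <- {"status": "completed"}        then the server closes the socket with code 1000
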
def cleanup_cache():
    """Clean up expired cache entries."""
    with cache_lock:
        now = datetime.datetime.now()
        expired = [k for k, v in audio_text_cache.items()
                   if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS]
        for key in expired:
            logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
            del audio_text_cache[key]

# Start the cleanup thread when the application starts
# start_background_cleaner()
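

# A minimal sketch of what start_background_cleaner() might look like, assuming it only needs
# to run cleanup_cache() periodically on a daemon thread. The real implementation is not
# defined in this file, and the 60-second interval below is an arbitrary example value.
def _example_background_cleaner(interval_seconds: int = 60):
    def _loop():
        while True:
            try:
                cleanup_cache()
            except Exception as e:
                logging.error(f"cache cleanup failed: {str(e)}")
            time.sleep(interval_seconds)
    threading.Thread(target=_loop, daemon=True).start()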