# Python source file (1025 lines, 38 KiB)
import logging
|
||
import binascii
|
||
from copy import deepcopy
|
||
from timeit import default_timer as timer
|
||
import datetime
|
||
from datetime import timedelta
|
||
import threading, time,queue,uuid,time,array
|
||
from threading import Lock, Thread
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import base64, gzip
|
||
import os,io, re, json
|
||
from io import BytesIO
|
||
from typing import Optional, Dict, Any
|
||
import asyncio
|
||
|
||
from fastapi import WebSocket, APIRouter,WebSocketDisconnect,Request,Body,Query
|
||
from fastapi import FastAPI, UploadFile, File, Form, Header
|
||
from fastapi.responses import StreamingResponse,JSONResponse,Response
|
||
# Default audio parameters used by the TTS helpers in this module.
TTS_SAMPLERATE: int = 44100  # Hz; 22050 and 16000 were earlier candidates
FORMAT: str = "mp3"
CHANNELS: int = 1  # mono
SAMPLE_WIDTH: int = 2  # bytes per sample (16-bit)
|
||
|
||
tts_router = APIRouter()
|
||
# logger = logging.getLogger(__name__)
|
||
class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose ``%(asctime)s`` renders as ``HH:MM:SS.mmm``."""

    def formatTime(self, record, datefmt=None):
        """Return the record's creation time with a 3-digit millisecond suffix.

        ``datefmt`` is accepted for signature compatibility with
        ``logging.Formatter.formatTime`` but is not consulted.
        """
        # self.converter turns the epoch timestamp into a struct_time.
        clock = time.strftime("%H:%M:%S", self.converter(record.created))
        millis = int(record.msecs)
        return "{}.{:03d}".format(clock, millis)
|
||
|
||
# 配置全局日志格式
|
||
def configure_logging():
    """Reset the root logger to a single console handler at INFO level.

    Handlers already attached to the root logger are discarded. Records are
    rendered as ``<time> - <level> - <message>`` by ``MillisecondsFormatter``.
    """
    root = logging.getLogger()
    # Drop whatever handlers were installed before this call.
    root.handlers = []

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")
    )

    root.setLevel(logging.INFO)
    root.addHandler(stream_handler)


# Executed once when the module is imported.
configure_logging()
|
||
|
||
class StreamSessionManager:
    """Registry of streaming TTS sessions.

    Each session owns a queue of pending text, a bounded queue of produced
    audio chunks, and a daemon thread (``_process_tasks``) that turns text
    into audio through the shared thread pool.
    """

    # Length of the fade-in applied to the first WAV chunk, in samples.
    _FADE_IN_SAMPLES = 1024
    # The fade-in helper treats the audio as 16-bit samples.
    _BYTES_PER_SAMPLE = 2

    def __init__(self):
        # session_id -> dict of per-session state (see create_session)
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size pool
        self.gc_interval = 300  # seconds of inactivity before a session is closed
        self.gc_tts = 3  # seconds of inactivity before a session is marked finished

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        """Register a new session, start its worker thread and return its id.

        :param tts_model: object exposing ``tts(text, sample_rate, stream_format)``
            that yields audio chunks.
        :param sample_rate: sample rate forwarded to ``tts_model.tts``.
        :param stream_format: format forwarded to ``tts_model.tts``; ``'wav'``
            additionally enables the first-chunk fade-in.
        :return: the generated session id (a UUID4 string).
        """
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # produced audio chunks
                'task_queue': queue.Queue(),  # pending text fragments
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False
            }
        # One worker thread per session.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue ``text`` for synthesis; unknown session ids are ignored."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Worker loop for one session.

        Collects up to five queued text fragments (waiting at most 50 ms for
        each), synthesises them as one request, then checks the inactivity
        timers. The loop ends when the session disappears, is deactivated, or
        one of the timers fires.
        """
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                texts = []
                while len(texts) < 5:  # merge at most five fragments
                    try:
                        texts.append(session['task_queue'].get(timeout=0.05))
                    except queue.Empty:
                        break

                if texts:
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts)  # fewer, larger synthesis requests
                    )
                    future.result()  # wait for the synthesis to finish

                idle = time.time() - session['last_active']
                if idle > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        """Synthesise ``text`` and push the resulting chunks into the session buffer.

        For ``'wav'`` sessions the first chunk is faded in. On failure a string
        ``"ERROR:<message>"`` is pushed into the buffer instead of raising.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        try:
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav' and first_chunk:
                    # The fade length is a number of samples, not bytes: a short
                    # first chunk holds len(chunk) // 2 samples, so passing the
                    # byte length made the fade index past the end of the data.
                    fade_samples = min(
                        self._FADE_IN_SAMPLES,
                        len(chunk) // self._BYTES_PER_SAMPLE,
                    )
                    chunk = audio_fade_in(chunk, fade_samples)
                    first_chunk = False
                session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # Added 2025-05-10: the TTS backend has answered, so the
                    # front end may be notified.
                    session['tts_chunk_data_valid'] = True
            logging.info(f"转换结束!!! {session['audio_chunk_count']}")
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def close_session(self, session_id):
        """Deactivate the session and remove it one second later."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                # Delay removal by 1 second so readers can observe the inactive flag.
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Remove the session from the registry if it is still present."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the session's state dict, or ``None`` if it does not exist."""
        return self.sessions.get(session_id)


stream_manager = StreamSessionManager()
|
||
|
||
|
||
def allowed_file(filename):
    """Return True when ``filename`` ends in .png, .jpg, .jpeg or .gif (any case)."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {'png', 'jpg', 'jpeg', 'gif'}
|
||
|
||
|
||
# In-memory cache of TTS requests keyed by stream id; access it under cache_lock.
audio_text_cache: Dict[str, Any] = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes
|
||
|
||
"""
|
||
@manager.route('/tts_stream/<session_id>', methods=['GET'])
|
||
def tts_stream(session_id):
|
||
session = stream_manager.sessions.get(session_id)
|
||
logging.info(f"--tts_stream {session}")
|
||
if session is None:
|
||
return get_error_data_result(message="Audio stream not found or expired.")
|
||
|
||
def generate():
|
||
count = 0;
|
||
finished_event = session['finished']
|
||
try:
|
||
while not finished_event.is_set():
|
||
if not session or not session['active']:
|
||
break
|
||
try:
|
||
chunk = session['buffer'].get_nowait() #
|
||
count = count + 1
|
||
if isinstance(chunk, str) and chunk.startswith("ERROR"):
|
||
logging.info(f"---tts stream error!!!! {chunk}")
|
||
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
|
||
break
|
||
if session['stream_format'] == "wav":
|
||
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
|
||
yield gzip_base64_data
|
||
else:
|
||
yield chunk
|
||
retry_count = 0 # 成功收到数据重置重试计数器
|
||
except queue.Empty:
|
||
if session['stream_format'] == "wav":
|
||
pass
|
||
else:
|
||
pass
|
||
except Exception as e:
|
||
logging.info(f"tts streag get error2 {e} ")
|
||
|
||
|
||
finally:
|
||
# 确保流结束后关闭会话
|
||
if session:
|
||
# 延迟关闭会话,确保所有数据已发送
|
||
stream_manager.close_session(session_id)
|
||
logging.info(f"Session {session_id} closed.")
|
||
# 关键响应头设置
|
||
|
||
if session['stream_format'] == "wav":
|
||
resp = Response(stream_with_context(generate()), mimetype="audio/wav")
|
||
else:
|
||
resp = Response(stream_with_context(generate()), mimetype="audio/mpeg")
|
||
resp.headers.add_header("Cache-Control", "no-cache")
|
||
resp.headers.add_header("Connection", "keep-alive")
|
||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||
return resp
|
||
|
||
"""
|
||
def generate_mp3_header(
    sample_rate: int,
    bitrate_kbps: int,
    channels: int = 1,
    layer: str = "III"
) -> bytes:
    """Build the 4-byte MPEG audio frame header for the given stream parameters.

    The header is emitted without CRC, without padding, and with the private,
    mode-extension, copyright, original and emphasis fields all zero.

    :param sample_rate: sampling rate in Hz. MPEG-1: 32000/44100/48000,
        MPEG-2: 16000/22050/24000, MPEG-2.5: 8000/11025/12000.
    :param bitrate_kbps: bitrate in kbit/s; must be one of the standard values
        for the selected MPEG version and layer.
    :param channels: 1 for mono, 2 for stereo.
    :param layer: ``"I"``, ``"II"`` or ``"III"``.
    :return: the 4 header bytes, big-endian.
    :raises ValueError: if any argument is outside the supported set.
    """
    # sample rate -> (version bits, sample-rate index bits)
    sample_rate_fields = {
        44100: (0b11, 0b00), 48000: (0b11, 0b01), 32000: (0b11, 0b10),  # MPEG-1
        22050: (0b10, 0b00), 24000: (0b10, 0b01), 16000: (0b10, 0b10),  # MPEG-2
        11025: (0b00, 0b00), 12000: (0b00, 0b01), 8000: (0b00, 0b10),   # MPEG-2.5
    }
    valid_sample_rates = set(sample_rate_fields)
    if sample_rate not in valid_sample_rates:
        raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}")

    # Layer field encoding: I -> 0b11, II -> 0b10, III -> 0b01.
    layer_codes = {"I": 0b11, "II": 0b10, "III": 0b01}
    valid_layers = set(layer_codes)
    if layer not in valid_layers:
        raise ValueError(f"不支持的层,可选:{valid_layers}")

    mpeg_version, sample_rate_index = sample_rate_fields[sample_rate]
    layer_code = layer_codes[layer]

    # Standard bitrates in header order. The 4-bit bitrate index is the
    # 1-based position in the tuple: index 0 means "free format" and index 15
    # is forbidden, so neither is ever produced here.
    mpeg1_bitrates = {
        "I": (32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448),
        "II": (32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384),
        "III": (32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320),
    }
    # MPEG-2 and MPEG-2.5 share one set of tables; Layers II and III are identical.
    low_rate = (8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160)
    mpeg2_bitrates = {
        "I": (32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256),
        "II": low_rate,
        "III": low_rate,
    }
    bitrate_table = (mpeg1_bitrates if mpeg_version == 0b11 else mpeg2_bitrates)[layer]
    if bitrate_kbps not in bitrate_table:
        raise ValueError(f"不支持的比特率,可选:{list(bitrate_table)}")
    bitrate_index = bitrate_table.index(bitrate_kbps) + 1

    if channels == 1:
        channel_mode = 0b11  # single channel
    elif channels == 2:
        channel_mode = 0b00  # stereo
    else:
        raise ValueError("声道数必须为1或2")

    # Field layout, most significant bit first:
    # sync(11) version(2) layer(2) protection(1) bitrate(4) sample-rate(2)
    # padding(1) private(1) mode(2) mode-ext(2) copyright(1) original(1) emphasis(2)
    frame_header = (
        (0x7FF << 21)
        | (mpeg_version << 19)
        | (layer_code << 17)
        | (1 << 16)  # protection bit: 1 = no CRC follows the header
        | (bitrate_index << 12)
        | (sample_rate_index << 10)
        | (channel_mode << 6)
    )
    return frame_header.to_bytes(4, byteorder='big')
|
||
|
||
# ------------------------------------------------
|
||
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit PCM audio.

    The data is interpreted as native-endian signed 16-bit samples. Sample
    ``i`` of the ramp is scaled by ``i / ramp_length``, so the first sample
    becomes 0 and the gain rises linearly towards 1.

    :param audio_data: raw PCM bytes; the length must be a multiple of 2.
    :param fade_length: requested ramp length in samples. It is capped at the
        number of samples available, so short inputs are faded over their
        whole length instead of raising IndexError.
    :return: the processed audio as bytes, same length as the input.
    """
    samples = array.array('h', audio_data)

    # Never ramp past the end of the data.
    ramp_length = min(fade_length, len(samples))
    for i in range(ramp_length):
        samples[i] = int(samples[i] * (i / ramp_length))

    return samples.tobytes()
|
||
|
||
|
||
def parse_markdown_json(json_string):
    """Extract and parse the first ```json fenced block found in ``json_string``.

    :return: ``{'success': True, 'data': <parsed value>}`` on success, otherwise
        ``{'success': False, 'data': <error text>}``, where the error text is
        either the JSON decoder's message or a fixed "not found" message.
    """
    fenced = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if fenced is None:
        return {'success': False, 'data': 'not a valid markdown json string'}

    payload = fenced.group(1)
    try:
        return {'success': True, 'data': json.loads(payload)}
    except json.JSONDecodeError as err:
        # Report the decoder's own message to the caller.
        return {'success': False, 'data': str(err)}
|
||
|
||
|
||
|
||
|
||
# NOTE: these three names are also assigned earlier in this module; the
# objects created here are the ones bound once the module has been imported.
audio_text_cache: Dict[str, Any] = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes
|
||
|
||
|
||
# 全角字符到半角字符的映射
|
||
def fullwidth_to_halfwidth(s):
    """Replace full-width punctuation in ``s`` with its ASCII equivalent.

    Only punctuation is converted; full-width letters and digits are left
    untouched, as are characters that are not in the table.

    NOTE(review): the mapping literal in this file had its full-width keys
    normalised to plain ASCII (every entry mapped a character to itself and
    the ``'`` and ``\\`` entries were no longer valid Python). The table is
    rebuilt from code points so it survives such normalisation. The entries
    for the half-width CJK forms (U+FF61..U+FF65, U+FF70) could not be
    recovered reliably and are not included - confirm against the original.
    """
    ascii_punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
    # U+FF01..U+FF5E mirror ASCII U+0021..U+007E at a fixed offset of 0xFEE0.
    full_to_half_map = {chr(ord(ch) + 0xFEE0): ch for ch in ascii_punctuation}
    full_to_half_map['\u30fc'] = '-'  # KATAKANA-HIRAGANA PROLONGED SOUND MARK
    full_to_half_map['\u3002'] = '.'  # IDEOGRAPHIC FULL STOP
    return ''.join(full_to_half_map.get(char, char) for char in s)
|
||
|
||
|
||
def split_text_at_punctuation(text, chunk_size=100):
    """Split ``text`` at punctuation/whitespace and regroup it into chunks.

    The text is tokenised on runs of separators, then tokens are joined with
    single spaces into chunks. A token is added to the current chunk while the
    chunk's running length (including the separator space after each token)
    plus the token length does not exceed ``chunk_size``. A single token
    longer than ``chunk_size`` becomes a chunk of its own.

    :param text: the text to split.
    :param chunk_size: soft upper bound on chunk length.
    :return: list of non-empty chunk strings, in order.
    """
    punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
    # re.split yields empty strings at the edges; drop them.
    tokens = [token for token in re.split(punctuation_pattern, text) if token]

    chunks = []
    current_chunk = ''

    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += (token + ' ')
        else:
            # Flush only a non-empty chunk: when the very first token is longer
            # than chunk_size there is nothing to flush yet.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
|
||
|
||
|
||
def extract_text_from_markdown(markdown_text):
    """Strip Markdown constructs from ``markdown_text`` and return plain text.

    Removed entirely: fenced code blocks, inline code spans, heading lines,
    bold/italic spans, images, links and HTML tags. The remainder is passed
    through ``fullwidth_to_halfwidth`` and has whitespace runs collapsed to a
    single space.

    NOTE(review): emphasised text and link text are dropped together with
    their markup, as before - confirm that this is intended for TTS input.
    """
    # Fenced blocks go first; otherwise the inline-code pattern eats part of
    # the fence and leaves stray backticks behind.
    text = re.sub(r'```[\s\S]*?```', '', markdown_text)
    # Inline code spans.
    text = re.sub(r'`[^`]+`', '', text)
    # Heading lines. The match is confined to the heading's own line so the
    # text that follows a heading is preserved.
    text = re.sub(r'^[ \t]*#{1,6}[ \t]*[^\n]*', '', text, flags=re.MULTILINE)
    # Bold and italic spans.
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Images go before links: the link pattern also matches the tail of an
    # image and would leave a stray "!" behind.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Links.
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalise full-width punctuation.
    text = fullwidth_to_halfwidth(text)
    # Collapse whitespace runs.
    text = re.sub(r'\s+', ' ', text).strip()

    return text
|
||
|
||
|
||
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress ``original_data`` and return the result Base64-encoded.

    The returned string contains no line breaks (``base64.b64encode`` does not
    wrap its output).
    """
    sink = BytesIO()
    try:
        compressor = gzip.GzipFile(fileobj=sink, mode='wb')
        try:
            compressor.write(original_data)
        finally:
            # Closing flushes the gzip trailer into the sink.
            compressor.close()
        packed = sink.getvalue()
    finally:
        sink.close()

    encoded = base64.b64encode(packed)
    return encoded.decode('utf-8')
|
||
|
||
|
||
def clean_audio_cache():
    """Remove cache entries older than ``CACHE_EXPIRE_SECONDS``.

    ``created_at`` may be either an epoch float or a ``datetime.datetime``;
    the entries written elsewhere in this module use ``datetime.datetime.now()``,
    which cannot be subtracted from ``time.time()`` directly.
    If a removed entry carries an ``audio_stream`` object, it is closed.
    """
    with cache_lock:
        now = time.time()
        expired_keys = []
        for key, entry in audio_text_cache.items():
            created_at = entry['created_at']
            if isinstance(created_at, datetime.datetime):
                created_at = created_at.timestamp()
            if now - created_at > CACHE_EXPIRE_SECONDS:
                expired_keys.append(key)

        for key in expired_keys:
            entry = audio_text_cache.pop(key, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()
|
||
|
||
|
||
def start_background_cleaner():
    """Start a daemon thread that calls ``clean_audio_cache`` every 180 seconds."""

    def _sweep_forever():
        while True:
            time.sleep(180)  # one pass every 3 minutes
            clean_audio_cache()

    Thread(target=_sweep_forever, daemon=True).start()
|
||
|
||
|
||
def test_qwen_chat():
    """Send one fixed chat prompt to the ``qwen-plus`` model and print the outcome.

    Prints the reply text on HTTP 200, otherwise the status code, error code
    and error message returned by the service.
    """
    # Generation is not imported at module level; bring it in here so the call
    # below does not fail with NameError.
    from dashscope import Generation

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # Uses the module-level key; replace with api_key="sk-xxx" if needed.
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )

    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码:{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
|
||
|
||
# SECURITY: an API key is committed to source here. Prefer supplying it through
# the DASHSCOPE_API_KEY environment variable; the literal below is kept only as
# a fallback for existing deployments and should be rotated and removed.
ALI_KEY = os.environ.get("DASHSCOPE_API_KEY") or "sk-a47a3fb5f4a94f66bbaf713779101c75"
|
||
|
||
class QwenTTS:
    """Wrapper around the DashScope speech-synthesis SDK.

    A model name of the form ``cosyvoice-v1/<voice>`` selects the ``tts_v2``
    API with that voice; any other model name uses the older ``tts`` API.
    """

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", base_url=""):
        """Store the synthesis settings and configure the SDK's API key.

        :param key: DashScope API key (assigned to ``dashscope.api_key``).
        :param format: default output format name, e.g. ``"mp3"``.
        :param sample_rate: default output sample rate in Hz.
        :param model_name: model, optionally ``cosyvoice-v1/<voice>``.
        :param base_url: accepted but not used.
        """
        import dashscope
        import ssl
        print("---begin--init QwenTTS--")  # cyx
        self.model_name = model_name
        dashscope.api_key = key
        # SECURITY NOTE(review): this disables TLS certificate verification for
        # the default HTTPS context of the whole process, not only for this
        # client. Left in place because removing it may break existing
        # deployments - confirm and remove if possible.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        if '/' in model_name:
            # "<family>/<voice>": only the cosyvoice-v1 family is recognised.
            parts = model_name.split('/', 1)
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]

    def tts(self, text, sample_rate=None, stream_format=None):
        """Synthesise ``text`` and yield the audio as byte chunks.

        :param text: text to synthesise.
        :param sample_rate: optional override of the instance sample rate
            (only used on the CosyVoice path).
        :param stream_format: optional override of the instance format name
            (only used on the CosyVoice path; the older API always requests
            ``"wav"``).
        :raises RuntimeError: if the SDK reports an error or the call fails
            before the stream completed.
        """
        use_v2 = self.is_cosyvoice is not False
        if use_v2:
            # cyx 2025-01-19: CosyVoice goes through the tts_v2 API.
            from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer
        else:
            from dashscope.audio.tts import ResultCallback, SpeechSynthesizer

        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")  # cyx

        class _ChunkFeed(ResultCallback):
            """SDK callback that hands audio to the consumer through a queue.

            ``None`` marks the end of the stream and an exception instance
            marks a failure, so the consumer can block instead of polling and
            always wakes up when the stream ends, errors or is closed.
            """

            def __init__(self) -> None:
                self._chunks = queue.Queue()

            def _run(self):
                """Yield queued audio until the end marker; re-raise queued errors."""
                while True:
                    item = self._chunks.get()
                    if item is None:
                        break
                    if isinstance(item, BaseException):
                        raise item
                    yield item

            def fail(self, error):
                """Queue ``error`` so the consumer stops waiting and raises it."""
                self._chunks.put(error)

            def on_open(self):
                if use_v2:
                    logging.info("Qwen tts open")

            def on_complete(self):
                self._chunks.put(None)

            def on_error(self, response):
                print("Qwen tts error", str(response))
                self.fail(RuntimeError(str(response)))

            def on_close(self):
                if use_v2:
                    logging.info("Qwen tts close")
                # Guarantees termination even if no completion event arrived.
                self._chunks.put(None)

            def on_event(self, result):
                # The older API delivers audio frames through on_event; on the
                # tts_v2 path this callback is ignored.
                if use_v2:
                    return
                frame = result.get_audio_frame()
                if frame:
                    self._chunks.put(frame)

            def on_data(self, data: bytes) -> None:
                # tts_v2 (CosyVoice) delivers raw audio bytes here.
                if len(data) > 0:
                    self._chunks.put(data)

        callback = _ChunkFeed()
        self.callback = callback
        try:
            if not use_v2:
                SpeechSynthesizer.call(model=self.model_name,
                                       text=text,
                                       callback=callback,
                                       format="wav")
            else:
                format_name = self.format if stream_format is None else stream_format
                rate = self.sample_rate if sample_rate is None else sample_rate
                audio_format = self.get_audio_format(format_name, rate)
                self.synthesizer = SpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,
                    callback=callback,
                    format=audio_format
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}")  # cyx
            # Audio already queued is still delivered; the error is raised
            # after it unless the stream had already completed.
            callback.fail(e)

        try:
            for data in callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}") from e

    def get_audio_format(self, format: str, sample_rate: int):
        """Map ``(sample_rate, format)`` to a ``tts_v2`` ``AudioFormat`` member.

        Unknown combinations fall back to ``MP3_16000HZ_MONO_128KBPS``.
        NOTE(review): 16 kHz has no explicit ``mp3``/``wav`` entries, so a
        16 kHz WAV request falls back to MP3 - confirm whether that is intended.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            # These keys were 48800, which never matched a 48000 Hz request.
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT,
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)

    def end_tts(self):
        """Signal the end of streaming input to the active synthesizer, if any."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()
|
||
|
||
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Stream the local file ``./test.mp3`` as ``audio/mpeg`` in 1 KiB chunks."""

    def audio_generator():
        path = './test.mp3'
        try:
            with open(path, 'rb') as mp3_file:
                # iter() with a sentinel stops at the first empty read (EOF).
                for block in iter(lambda: mp3_file.read(1024), b''):
                    yield block
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")

    response_headers = {
        "Cache-Control": "no-store",
        "Accept-Ranges": "bytes"
    }
    return StreamingResponse(
        audio_generator(),
        media_type="audio/mpeg",
        headers=response_headers
    )
|
||
|
||
|
||
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Wrap raw PCM in a WAV container (mono, 16-bit) and return the bytes.

    :param pcm_data: raw 16-bit mono PCM samples.
    :param sample_rate: sample rate in Hz written into the header.
    :return: a complete WAV file image: header followed by ``pcm_data``.
    """
    # The wave module is not imported at module level; import it here so the
    # function does not fail with NameError.
    import wave

    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # mono
            wav_file.setsampwidth(2)  # 16-bit samples
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        # Read back only after the writer is closed so the header is final.
        return wav_buffer.getvalue()
|
||
|
||
|
||
def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return ``duration_ms`` milliseconds of zero bytes.

    The length is derived from the module constants ``TTS_SAMPLERATE``,
    ``SAMPLE_WIDTH`` and ``CHANNELS``.
    """
    sample_count = int(TTS_SAMPLERATE * duration_ms / 1000)
    silence = b'\x00' * sample_count
    return silence * SAMPLE_WIDTH * CHANNELS
|
||
|
||
|
||
# ------------------------ API路由 ------------------------
|
||
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Register a TTS request and return the URL from which the audio can be fetched.

    JSON body fields: ``text`` (required), ``tts_stream_format``,
    ``tts_sample_rate``, ``model_name`` and ``delay_gen_audio``. Unless
    ``delay_gen_audio`` is set, the audio is synthesised immediately and stored
    in the cache; otherwise only the request parameters are stored and the
    audio is produced when it is fetched.

    :raises HTTPException: 400 for invalid input, 500 for synthesis or
        unexpected failures.
    """
    # HTTPException is not among the module-level fastapi imports.
    from fastapi import HTTPException

    try:
        data = await request.json()
        logging.info(f"API--create_tts_request1-- {data}")

        text = data.get("text", "").strip()
        if not text:
            raise HTTPException(400, detail="文本内容不能为空")

        audio_format = data.get("tts_stream_format", "mp3")
        if audio_format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="不支持的音频格式")

        # The default must itself pass the validation below; the previous
        # default of 48000 was rejected on every request that omitted the field.
        sample_rate = data.get("tts_sample_rate", TTS_SAMPLERATE)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")

        model_name = data.get("model_name", "cosyvoice-v1/longxiaochun")
        delay_gen_audio = data.get('delay_gen_audio', False)
        # NOTE(review): the requested format is validated above but then forced
        # to mp3, as in the previous code - confirm this is still intended.
        audio_format = "mp3"

        audio_stream_id = str(uuid.uuid4())
        tts_info = {
            "text": text,
            "buffer": None,
            "format": audio_format,
            "sample_rate": sample_rate,
            "model_name": model_name,
            "created_at": datetime.datetime.now()
        }

        if delay_gen_audio is False:
            logging.info("--begin generate tts --")
            # The front end may send "<model>@<provider>"; keep the part before "@".
            tts = QwenTTS(ALI_KEY, audio_format, sample_rate, model_name.split("@")[0])

            def _synthesize():
                """Collect the whole synthesis result into one rewound buffer."""
                collected = io.BytesIO()
                for chunk in tts.tts(text):
                    collected.write(chunk)
                collected.seek(0)
                return collected

            try:
                # Synthesis blocks on network I/O; run it in a worker thread so
                # the event loop keeps serving other requests.
                loop = asyncio.get_running_loop()
                buffer = await loop.run_in_executor(None, _synthesize)
            except Exception as e:
                logging.error(f"TTS生成失败: {str(e)}")
                raise HTTPException(500, detail="音频生成失败")
            tts_info['buffer'] = buffer
            tts_info['size'] = buffer.getbuffer().nbytes

        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info
            logging.info(f" tts return {audio_text_cache[audio_stream_id]}")

        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats/{chat_id}/tts/{audio_stream_id}",
                "expires_at": (datetime.datetime.now() + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat()
            }
        )

    except HTTPException:
        # Keep the intended status code (400/500) instead of rewriting it below.
        raise
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="服务器内部错误")
|
||
|
||
executor = ThreadPoolExecutor()


@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
    chat_id: str,
    audio_stream_id: str,
    range: str = Header(None)
):
    """Return the audio for a previously registered TTS request.

    If the audio was synthesised when the request was created, it is served
    from the cached buffer (honouring a ``Range`` header). Otherwise it is
    synthesised now and streamed as it is produced.

    The parameter is named ``range`` because FastAPI derives the header name
    from the parameter name.
    """
    cleanup_cache()

    with cache_lock:
        tts_info = audio_text_cache.get(audio_stream_id)

    if not tts_info:
        logging.warning("音频流不存在或已过期")
        # Answer with 404 rather than an empty 200 so clients can tell a
        # missing or expired stream apart from a successful response.
        return JSONResponse(status_code=404, content={"detail": "音频流不存在或已过期"})

    buffer = tts_info.get("buffer", None)
    audio_format = tts_info.get("format", "mp3")
    sample_rate = tts_info.get("sample_rate", 44100)
    total_size = tts_info.get('size', 0)
    model_name = tts_info['model_name']
    text = tts_info['text']
    media_type = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "pcm": f"audio/L16; rate={tts_info['sample_rate']}; channels=1"
    }[audio_format]
    logging.info(f"enter get_tts_audio1 buffer={buffer} {audio_format} {sample_rate} range {range}")

    async def generate_audio_async():
        """Yield synthesised chunks, pulling each one in a worker thread."""
        logging.info(f"get_tts_audio1 generate_audio_async {audio_format} {sample_rate} {model_name}")
        # The front end may send e.g. "cosyvoice-v1/longyuan@Tongyi-Qianwen";
        # keep the part before "@".
        tts = QwenTTS(ALI_KEY, audio_format, sample_rate, model_name.split("@")[0])
        loop = asyncio.get_event_loop()
        yield_len = 0
        exhausted = object()
        try:
            chunks = tts.tts(text)
            while True:
                # Calling tts.tts() only creates the generator; the blocking
                # work happens on each next(), so that is what must run in the
                # executor. A default value is passed to next() because
                # StopIteration cannot travel through a Future.
                chunk = await loop.run_in_executor(executor, next, chunks, exhausted)
                if chunk is exhausted:
                    break
                yield_len = yield_len + len(chunk)
                yield chunk
        except (BrokenPipeError, ConnectionResetError):
            logging.warning("客户端主动断开连接")
        finally:
            logging.info(f"yield len={yield_len}")

    if buffer is not None:
        if range:
            return handle_range_request(range, buffer, total_size, media_type)
        # Serve an immutable snapshot: the cached BytesIO is shared by every
        # request for this id, so reading through its cursor would interleave
        # concurrent responses.
        payload = buffer.getvalue()
        return Response(
            content=payload,
            media_type=media_type,
            headers={
                "Accept-Ranges": "bytes",
                "Content-Length": str(len(payload)),
                "Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
            }
        )
    else:
        return StreamingResponse(
            generate_audio_async(),
            media_type=media_type,
            headers={
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-store",
                "Accept-Ranges": "bytes",
                "ETag": audio_stream_id,
            }
        )
|
||
|
||
|
||
def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str):
    """Serve a single HTTP byte range out of ``buffer``.

    Supported forms: ``bytes=a-b``, ``bytes=a-`` and the suffix form
    ``bytes=-n`` (the last ``n`` bytes). An end position past the end of the
    data is clamped to the last byte.

    :param range_header: raw value of the ``Range`` request header.
    :param buffer: in-memory audio data; its read cursor is not used.
    :param total_size: total number of bytes available.
    :param media_type: value for the ``Content-Type`` response header.
    :raises HTTPException: 416 when the range is malformed or cannot be
        satisfied, 500 on unexpected errors.
    """
    # HTTPException is not among the module-level fastapi imports.
    from fastapi import HTTPException

    def _unsatisfiable():
        return HTTPException(status_code=416, headers={
            "Content-Range": f"bytes */{total_size}"
        })

    try:
        # Example header: "bytes=0-1023".
        range_type, range_spec = range_header.split('=')
        if range_type.strip() != 'bytes':
            raise ValueError("Unsupported range type")

        # More than one "-" (e.g. a multi-range request) fails the unpacking.
        start_str, end_str = (part.strip() for part in range_spec.split('-'))
        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        else:
            # Suffix form: the final N bytes of the data.
            suffix_length = int(end_str)
            if suffix_length <= 0:
                raise ValueError("Invalid suffix range")
            start = max(total_size - suffix_length, 0)
            end = total_size - 1

        end = min(end, total_size - 1)
        if start < 0 or start >= total_size or start > end:
            raise _unsatisfiable()

        status_code = 206
        # NOTE(review): a range that runs to the end of the data but does not
        # start at 0 is answered with 200, as before. HTTP defines 206 for
        # every satisfied range - confirm whether a client depends on this.
        if start > 0 and end == total_size - 1:
            status_code = 200

        # Copy exactly the requested bytes. Reading through the shared cursor
        # until EOF sent more than Content-Length announced and let concurrent
        # requests disturb each other.
        with buffer.getbuffer() as view:
            payload = bytes(view[start:end + 1])

        return Response(
            content=payload,
            status_code=status_code,
            media_type=media_type,
            headers={
                "Content-Range": f"bytes {start}-{end}/{total_size}",
                "Content-Length": str(len(payload)),
                "Accept-Ranges": "bytes",
                "Cache-Control": "public, max-age=3600"
            }
        )

    except HTTPException:
        # Do not rewrite a deliberate 416 into a 500.
        raise
    except ValueError:
        raise _unsatisfiable()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
def cleanup_cache():
    """Delete every cache entry whose age exceeds ``CACHE_EXPIRE_SECONDS``.

    Entries are expected to carry a ``created_at`` ``datetime``; each removed
    entry is logged before deletion.
    """
    with cache_lock:
        current = datetime.datetime.now()
        stale = []
        for key, info in audio_text_cache.items():
            age_seconds = (current - info["created_at"]).total_seconds()
            if age_seconds > CACHE_EXPIRE_SECONDS:
                stale.append(key)
        for key in stale:
            logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
            del audio_text_cache[key]
|
||
|
||
# 应用启动时启动清理线程
|
||
# start_background_cleaner() |