Prepare to switch TTS generation for the LLM chat output text to stream_call; backup commit

This commit is contained in:
qcloud
2025-05-26 21:38:46 +08:00
parent e29f79b9ac
commit 0665eb2c2d
2786 changed files with 1375 additions and 1041863 deletions


@@ -11,17 +11,48 @@ import base64, gzip
import os,io, re, json
from io import BytesIO
from typing import Optional, Dict, Any
import asyncio
from fastapi import WebSocket, APIRouter,WebSocketDisconnect,Request,Body,Query
from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi.responses import StreamingResponse,JSONResponse
from fastapi.responses import StreamingResponse,JSONResponse,Response
TTS_SAMPLERATE = 44100 # 22050 # 16000
FORMAT = "mp3"
CHANNELS = 1 # mono
SAMPLE_WIDTH = 2 # 16-bit = 2 bytes
tts_router = APIRouter()
logger = logging.getLogger(__name__)
# logger = logging.getLogger(__name__)
class MillisecondsFormatter(logging.Formatter):
"""Custom log formatter that appends a millisecond timestamp"""
def formatTime(self, record, datefmt=None):
# Convert the record timestamp to a local time tuple
ct = self.converter(record.created)
# Format as "hour:minute:second"
t = time.strftime("%H:%M:%S", ct)
# Append milliseconds (3 digits)
return f"{t}.{int(record.msecs):03d}"
# Configure the global log format
def configure_logging():
# Create the Formatter
log_format = "%(asctime)s - %(levelname)s - %(message)s"
formatter = MillisecondsFormatter(log_format)
# Get the root logger and clear any existing configuration
root_logger = logging.getLogger()
root_logger.handlers = []
# Create and configure a handler that writes to the console
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
# Set the log level and attach the handler
root_logger.setLevel(logging.INFO)
root_logger.addHandler(console_handler)
# Run the configuration function once at program startup
configure_logging()
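A quick check of the formatter above (illustrative only, not part of this commit); it assumes configure_logging() has already run, which the module does at import time.

def _log_format_demo():
    # Emits something like "21:38:46.123 - INFO - tts chunk ready";
    # the ".123" comes from record.msecs in MillisecondsFormatter.formatTime.
    logging.info("tts chunk ready")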
class StreamSessionManager:
def __init__(self):
@@ -208,46 +239,185 @@ def tts_stream(session_id):
return resp
"""
def generate_mp3_header(bitrate_kbps=128, padding=0):
# Field definitions
sync = 0b11111111111 # sync word (11 bits)
version = 0b11 # MPEG-1 (2 bits)
layer = 0b01 # Layer III (2 bits)
protection = 0b0 # no CRC (1 bit)
bitrate_index = { # bitrate index table (MPEG-1 Layer III)
32: 0b0001, 40: 0b0010, 48: 0b0011, 56: 0b0100,
64: 0b0101, 80: 0b0110, 96: 0b0111, 112: 0b1000,
128: 0b1001, 160: 0b1010, 192: 0b1011, 224: 0b1100,
256: 0b1101, 320: 0b1110
}[bitrate_kbps]
sampling_rate = 0b00 # 44.1 kHz (2 bits)
padding_bit = padding # padding bit (1 bit)
private = 0b0 # private bit (1 bit)
mode = 0b11 # mono (2 bits)
mode_ext = 0b00 # mode extension (2 bits)
copyright = 0b0 # no copyright (1 bit)
original = 0b0 # not original (1 bit)
emphasis = 0b00 # no emphasis (2 bits)
def generate_mp3_header(
sample_rate: int,
bitrate_kbps: int,
channels: int = 1,
layer: str = "III" # new parameter; supports "I"/"II"/"III"
) -> bytes:
"""
Dynamically generate a 4-byte MP3 frame header; supports Layer I/II/III
# Combine into a 32-bit integer (big-endian)
header = (
(sync << 21) |
(version << 19) |
(layer << 17) |
(protection << 16) |
(bitrate_index << 12) |
(sampling_rate << 10) |
(padding_bit << 9) |
(private << 8) |
(mode << 6) |
(mode_ext << 4) |
(copyright << 3) |
(original << 2) |
emphasis
:param sample_rate: sampling rate (8000, 16000, 22050, 44100)
:param bitrate_kbps: bitrate in kbps
:param channels: number of channels (1: mono, 2: stereo)
:param layer: encoding layer ("I", "II", "III")
:return: 4-byte frame header
"""
# ----------------------------------
# Parameter validation
# ----------------------------------
valid_sample_rates = {8000, 16000, 22050, 44100, 48000}
if sample_rate not in valid_sample_rates:
raise ValueError(f"Unsupported sample rate; valid options: {valid_sample_rates}")
valid_layers = {"I", "II", "III"}
if layer not in valid_layers:
raise ValueError(f"Unsupported layer; valid options: {valid_layers}")
# ----------------------------------
# Determine the MPEG version and sample-rate index
# ----------------------------------
if sample_rate == 44100:
mpeg_version = 0b11 # MPEG-1
sample_rate_index = 0b00
elif sample_rate == 22050:
mpeg_version = 0b10 # MPEG-2
sample_rate_index = 0b00
elif sample_rate == 16000:
mpeg_version = 0b10 # MPEG-2
sample_rate_index = 0b10
elif sample_rate == 8000:
mpeg_version = 0b00 # MPEG-2.5
sample_rate_index = 0b10
else:
raise ValueError("采样率与版本不匹配")
# ----------------------------------
# Select the bitrate table dynamically (the key extension)
# ----------------------------------
# Layer code mapping: I -> 0b11, II -> 0b10, III -> 0b01
layer_code = {
"I": 0b11,
"II": 0b10,
"III": 0b01
}[layer]
# Bitrate index tables (cover all layers; indices per ISO/IEC 11172-3 / 13818-3)
bitrate_tables = {
# -------------------------------
# MPEG-1 (0b11)
# -------------------------------
# Layer I
(0b11, 0b11): {
32: 0b0001, 64: 0b0010, 96: 0b0011, 128: 0b0100,
160: 0b0101, 192: 0b0110, 224: 0b0111, 256: 0b1000,
288: 0b1001, 320: 0b1010, 352: 0b1011, 384: 0b1100,
416: 0b1101, 448: 0b1110
},
# Layer II
(0b11, 0b10): {
32: 0b0001, 48: 0b0010, 56: 0b0011, 64: 0b0100,
80: 0b0101, 96: 0b0110, 112: 0b0111, 128: 0b1000,
160: 0b1001, 192: 0b1010, 224: 0b1011, 256: 0b1100,
320: 0b1101, 384: 0b1110
},
# Layer III
(0b11, 0b01): {
32: 0b0001, 40: 0b0010, 48: 0b0011, 56: 0b0100,
64: 0b0101, 80: 0b0110, 96: 0b0111, 112: 0b1000,
128: 0b1001, 160: 0b1010, 192: 0b1011, 224: 0b1100,
256: 0b1101, 320: 0b1110
},
# -------------------------------
# MPEG-2 (0b10) / MPEG-2.5 (0b00)
# -------------------------------
# Layer I
(0b10, 0b11): {
32: 0b0001, 48: 0b0010, 56: 0b0011, 64: 0b0100,
80: 0b0101, 96: 0b0110, 112: 0b0111, 128: 0b1000,
144: 0b1001, 160: 0b1010, 176: 0b1011, 192: 0b1100,
224: 0b1101, 256: 0b1110
},
(0b00, 0b11): {
32: 0b0001, 48: 0b0010, 56: 0b0011, 64: 0b0100,
80: 0b0101, 96: 0b0110, 112: 0b0111, 128: 0b1000,
144: 0b1001, 160: 0b1010, 176: 0b1011, 192: 0b1100,
224: 0b1101, 256: 0b1110
},
# Layer II
(0b10, 0b10): {
8: 0b0001, 16: 0b0010, 24: 0b0011, 32: 0b0100,
40: 0b0101, 48: 0b0110, 56: 0b0111, 64: 0b1000,
80: 0b1001, 96: 0b1010, 112: 0b1011, 128: 0b1100,
144: 0b1101, 160: 0b1110
},
(0b00, 0b10): {
8: 0b0001, 16: 0b0010, 24: 0b0011, 32: 0b0100,
40: 0b0101, 48: 0b0110, 56: 0b0111, 64: 0b1000,
80: 0b1001, 96: 0b1010, 112: 0b1011, 128: 0b1100,
144: 0b1101, 160: 0b1110
},
# Layer III (same table as Layer II for MPEG-2/2.5)
(0b10, 0b01): {
8: 0b0001, 16: 0b0010, 24: 0b0011, 32: 0b0100,
40: 0b0101, 48: 0b0110, 56: 0b0111, 64: 0b1000,
80: 0b1001, 96: 0b1010, 112: 0b1011, 128: 0b1100,
144: 0b1101, 160: 0b1110
},
(0b00, 0b01): {
8: 0b0001, 16: 0b0010, 24: 0b0011, 32: 0b0100,
40: 0b0101, 48: 0b0110, 56: 0b0111, 64: 0b1000
}
}
# Pick the bitrate table for the current version/layer
key = (mpeg_version, layer_code)
if key not in bitrate_tables:
raise ValueError(f"Unsupported version/layer combination: MPEG={mpeg_version}, Layer={layer}")
bitrate_table = bitrate_tables[key]
if bitrate_kbps not in bitrate_table:
raise ValueError(f"Unsupported bitrate; valid options: {list(bitrate_table.keys())}")
bitrate_index = bitrate_table[bitrate_kbps]
# ----------------------------------
# Determine the channel mode
# ----------------------------------
if channels == 1:
channel_mode = 0b11 # mono
elif channels == 2:
channel_mode = 0b00 # stereo
else:
raise ValueError("channels must be 1 or 2")
# ----------------------------------
# Assemble the frame-header fields (corrected layer code)
# ----------------------------------
sync = 0x7FF << 21 # sync word, 11 bits (0x7FF = 0b11111111111)
version = mpeg_version << 19 # MPEG version, 2 bits
layer_bits = layer_code << 17 # layer code (I: 0b11, II: 0b10, III: 0b01)
protection = 1 << 16 # protection bit = 1: no CRC follows the header
bitrate_bits = bitrate_index << 12
sample_rate_bits = sample_rate_index << 10
padding = 0 << 9 # no padding
private = 0 << 8
mode = channel_mode << 6
mode_ext = 0 << 4 # mode extension (not used for mono)
copyright = 0 << 3
original = 0 << 2
emphasis = 0b00 # no emphasis
frame_header = (
sync |
version |
layer_bits |
protection |
bitrate_bits |
sample_rate_bits |
padding |
private |
mode |
mode_ext |
copyright |
original |
emphasis
)
# Convert to 4 bytes (big-endian)
return header.to_bytes(4, byteorder='big')
return frame_header.to_bytes(4, byteorder='big')
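A minimal sanity check for the header builder above (not part of this commit); the expected bytes assume the field assembly shown here, i.e. MPEG-1 Layer III, 44.1 kHz, 128 kbps, mono, protection bit set (no CRC).

def _mp3_header_demo():
    header = generate_mp3_header(sample_rate=44100, bitrate_kbps=128, channels=1, layer="III")
    # sync 0x7FF | MPEG-1 | Layer III | no CRC -> 0xFF 0xFB; bitrate index 0b1001 @ 44.1 kHz -> 0x90; mono -> 0xC0
    assert header == b"\xff\xfb\x90\xc0", header
    return header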
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
@@ -414,7 +584,7 @@ def test_qwen_chat():
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
class QwenTTS:
def __init__(self, key, model_name="cosyvoice-v1/longxiaochun", base_url=""):
def __init__(self, key,format="mp3",sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", base_url=""):
import dashscope
import ssl
print("---begin--init QwenTTS--") # cyx
@@ -425,6 +595,8 @@ class QwenTTS:
self.callback = None
self.is_cosyvoice = False
self.voice = ""
self.format = format
self.sample_rate = sample_rate
if '/' in model_name:
parts = model_name.split('/', 1)
# returns the two split strings: parts[0], parts[1]
@@ -539,7 +711,7 @@ class QwenTTS:
format="wav") # format="mp3")
else:
self.callback = Callback_v2()
format =self.get_audio_format(FORMAT,TTS_SAMPLERATE)
format =self.get_audio_format(self.format,self.sample_rate)
self.synthesizer = SpeechSynthesizer(
model='cosyvoice-v1',
# voice="longyuan", #"longfei",
@@ -553,6 +725,7 @@ class QwenTTS:
# -----------------------------------
try:
for data in self.callback._run():
#logging.info(f"dashcope return data {len(data)}")
yield data
# print(f"---Qwen return data {num_tokens_from_string(text)}")
# yield num_tokens_from_string(text)
@@ -563,7 +736,7 @@ class QwenTTS:
def get_audio_format(self, format: str, sample_rate: int):
"""动态获取音频格式"""
from dashscope.audio.tts_v2 import AudioFormat
logging.info(f"--get_audio_format-- {format} {sample_rate}")
logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
format_map = {
(8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
(8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
@@ -575,6 +748,10 @@ class QwenTTS:
(44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
(44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
(44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
(48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
(48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
(48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT
}
return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
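For reference (not part of this commit), the lookup above resolves like this, including the hard-coded fallback when a (sample_rate, format) pair is missing from the map:

# self.get_audio_format("mp3", 44100) -> AudioFormat.MP3_44100HZ_MONO_256KBPS
# self.get_audio_format("pcm", 8000)  -> AudioFormat.PCM_8000HZ_MONO_16BIT
# self.get_audio_format("ogg", 44100) -> AudioFormat.MP3_16000HZ_MONO_128KBPS (default fallback)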
@@ -599,7 +776,10 @@ async def stream_mp3():
return StreamingResponse(
audio_generator(),
media_type="audio/mpeg",
headers={"Cache-Control": "no-store"}
headers={
"Cache-Control": "no-store",
"Accept-Ranges": "bytes"
}
)
@@ -621,207 +801,9 @@ def generate_silence_header(duration_ms: int = 500) -> bytes:
return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
chat_id: str,
audio_stream_id: str,
request: Request,
range: str = Header(None) # added: Range header parsing
):
with cache_lock:
# tts_info = audio_text_cache.pop(audio_stream_id, None)
tts_info = audio_text_cache.get(audio_stream_id, None)
if not tts_info:
raise HTTPException(404, detail="Audio stream has expired")
audio_stream_len = tts_info.get('audio_stream_len', 0)
audio_stream = tts_info['audio_stream']
text = tts_info['text']
model_name = tts_info.get('model_name', "cosyvoice-v1/longxiaochun")
format = tts_info['format'] # new field
sample_rate = tts_info['sample_rate'] # new field
def stream_audio():
total = 0
try:
for chunk in tts_mdl.tts(text):
# print(f"data_length={total} {chunk}")
hex_data = binascii.hexlify(chunk).decode("utf-8") + "\r\n" # <--- append newline
# yield f"{hex_data}" # <--- SSE-style framing
print(f"yield {len(chunk)}")
# yield chunk
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
total = total + len(gzip_base64_data)
yield gzip_base64_data
print(f"tts gen end {total}")
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')
def read_buffer():
print(f"get audio_stream {audio_stream} {audio_stream.closed}")
audio_stream.seek(0) # critical: reset the stream position
data = audio_stream.read(1024)
while data:
yield data
data = audio_stream.read(1024)
def generate_silence_header():
"""Generate ~500 ms of silent MP3 frames (about 44100 Hz)"""
# Example silence data; a real implementation should emit standard silent MP3 frames
return b'\xff\xfb\xd4\x00' * 100 # about 400 bytes
# Keep the original generator structure
def generate_audio():
# Keep the original QwenTTS invocation
tts = QwenTTS(ALI_KEY, model_name)
# Temporarily override the global config (keeps the original logic)
global FORMAT, TTS_SAMPLERATE
original_format = FORMAT
original_sample_rate = TTS_SAMPLERATE
FORMAT = format
TTS_SAMPLERATE = sample_rate
try:
for chunk in tts.tts(text):
if format == 'wav':
yield add_wav_header(chunk, sample_rate)
else:
yield chunk
finally:
# restore the global variables
FORMAT = original_format
TTS_SAMPLERATE = original_sample_rate
# Keep the original streaming-response logic
media_type_map = {
'mp3': 'audio/mpeg',
'wav': 'audio/wav',
'pcm': f'audio/L16; rate={sample_rate}; channels=1'
}
# Handle Range-request logic
if range:
start =0
end =audio_stream_len
total_length = audio_stream_len
return StreamingResponse(
read_buffer(),
status_code=206,
headers={
"Accept-Ranges": "bytes",
"Content-Range": f"bytes {start}-{end}/{total_length}",
"Cache-Control": "no-store" # 确保不缓存动态内容
}
)
else:
try:
# Fetch the original parameters from the cache
if audio_stream is None:
tts_mdl = QwenTTS(ALI_KEY, "cosyvoice-v1/longxiaochun")
if audio_stream:
logging.info("return audio stream buffer")
# Make sure the stream position is at the start
return StreamingResponse(
read_buffer(),
media_type=media_type_map[format],
headers={
"Transfer-Encoding": "chunked",
"Cache-Control": "no-store",
"Content-Disposition": "inline",
"Access-Control-Allow-Origin": "*"
}
)
else:
return StreamingResponse(
generate_audio(),
media_type=media_type_map[format],
headers={
"Transfer-Encoding": "chunked",
"Cache-Control": "no-store",
"Content-Disposition": "inline"
}
)
except Exception as e:
logging.error(f"音频流错误: {str(e)}")
raise HTTPException(500, detail="音频生成失败")
# ------------------------ API routes ------------------------
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
try:
request_data = await request.json()
# Fetch and validate the new parameters
format = request_data.get('format', FORMAT)
if format not in ['mp3', 'wav', 'pcm']:
raise HTTPException(400, detail="Unsupported audio format")
sample_rate = request_data.get('sample_rate', TTS_SAMPLERATE)
if sample_rate not in [8000, 16000, 22050, 44100]:
raise HTTPException(400, detail="Unsupported sample rate")
# Keep the original parameter handling
text = request_data.get("text")
if not text or not text.strip():
raise HTTPException(400, detail="Text content must not be empty")
# Store the new parameters in the cache
audio_stream_id = str(uuid.uuid4())
tts_info = {
'text': text,
'chat_id': chat_id,
'created_at': time.time(),
'audio_stream': None,
'model_name': request_data.get('model_name'),
'format': format,
'sample_rate': sample_rate
}
audio_stream_len = 0
# Keep the original delayed-generation logic
if request_data.get('delay_gen_audio', False) is False:
try:
# Temporarily set the globals (keeps the original logic)
original_format = FORMAT
original_sample_rate = TTS_SAMPLERATE
tts = QwenTTS(ALI_KEY, tts_info.get('model_name', "cosyvoice-v1/longxiaochun"))
buffer = io.BytesIO()
for chunk in tts.tts(text):
audio_stream_len = audio_stream_len + len(chunk)
buffer.write(chunk)
buffer.seek(0)
tts_info['audio_stream'] = buffer
tts_info['audio_stream_len'] = audio_stream_len
finally:
# restore the global variables
pass
# Keep the original caching logic
with cache_lock:
audio_text_cache[audio_stream_id] = tts_info
logging.info(f"create tts stream return {audio_stream_id} {audio_text_cache[audio_stream_id]}")
# Keep the original response structure
return {
"tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
"audio_stream_id": audio_stream_id,
"sample_rate": sample_rate
}
except Exception as e:
logging.error(f"请求处理失败: {str(e)}")
raise HTTPException(500, detail="内部服务器错误")
# Helper functions
# ------------------------ API routes ------------------------
@tts_router.post("/chats1/{chat_id}/tts")
async def create_tts_request1(chat_id: str, request: Request):
"""创建TTS音频流"""
try:
data = await request.json()
@@ -831,43 +813,54 @@ async def create_tts_request1(chat_id: str, request: Request):
if not text:
raise HTTPException(400, detail="Text content must not be empty")
format = data.get("format", "mp3")
format = data.get("tts_stream_format", "mp3")
if format not in ["mp3", "wav", "pcm"]:
raise HTTPException(400, detail="Unsupported audio format")
sample_rate = data.get("sample_rate", 16000)
sample_rate = data.get("tts_sample_rate", 48000)
if sample_rate not in [8000, 16000, 22050, 44100, 48000]:
raise HTTPException(400, detail="Unsupported sample rate")
model_name = data.get("model_name","cosyvoice-v1/longxiaochun")
delay_gen_audio = data.get('delay_gen_audio',False)
format ="mp3"
sample_rate = 44100
#sample_rate = 48000
# Generate the audio stream
audio_stream_id = str(uuid.uuid4())
buffer = io.BytesIO()
tts = QwenTTS(ALI_KEY)
try:
for chunk in tts.tts(text):
buffer.write(chunk)
buffer.seek(0)
except Exception as e:
logging.error(f"TTS生成失败: {str(e)}")
raise HTTPException(500, detail="音频生成失败")
# 存储到缓存
with cache_lock:
audio_text_cache[audio_stream_id] = {
buffer = None
tts_info = {
"text" :text,
"buffer": buffer,
"format": format,
"sample_rate": sample_rate,
"created_at": datetime.datetime.now(),
"size": buffer.getbuffer().nbytes
"model_name" : model_name,
"created_at": datetime.datetime.now()
}
if delay_gen_audio is False:
logging.info("--begin generate tts --")
buffer = io.BytesIO()
tts = QwenTTS(ALI_KEY,format,sample_rate,model_name.split("@")[0])
try:
for chunk in tts.tts(text):
buffer.write(chunk)
buffer.seek(0)
except Exception as e:
logging.error(f"TTS生成失败: {str(e)}")
raise HTTPException(500, detail="音频生成失败")
tts_info['buffer'] =buffer
tts_info['size'] = buffer.getbuffer().nbytes
# 存储到缓存
with cache_lock:
audio_text_cache[audio_stream_id] = tts_info
logging.info(f" tts return {audio_text_cache[audio_stream_id]}")
return JSONResponse(
status_code=200,
content={
"tts_url":f"/chats1/{chat_id}/tts/{audio_stream_id}",
"url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
"tts_url":f"/chats/{chat_id}/tts/{audio_stream_id}",
"url": f"/chats/{chat_id}/tts/{audio_stream_id}",
"expires_at": (datetime.datetime.now() + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat()
}
)
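For orientation (illustrative, not part of the commit), a round-trip against this endpoint uses only the fields read and returned above; the chat_id and stream id values here are made up.

# POST /chats1/abc123/tts
# {"text": "hello", "tts_stream_format": "mp3", "tts_sample_rate": 48000,
#  "model_name": "cosyvoice-v1/longxiaochun", "delay_gen_audio": true}
#
# 200 OK
# {"tts_url": "/chats/abc123/tts/<audio_stream_id>",
#  "url": "/chats/abc123/tts/<audio_stream_id>",
#  "expires_at": "<ISO timestamp, now + CACHE_EXPIRE_SECONDS>"}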
@@ -876,9 +869,9 @@ async def create_tts_request1(chat_id: str, request: Request):
logging.error(f"请求处理失败: {str(e)}")
raise HTTPException(500, detail="服务器内部错误")
@tts_router.get("/chats1/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio1(
executor = ThreadPoolExecutor()
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
chat_id: str,
audio_stream_id: str,
range: str = Header(None)
@@ -889,80 +882,134 @@ async def get_tts_audio1(
# 获取缓存
with cache_lock:
item = audio_text_cache.get(audio_stream_id)
if not item:
raise HTTPException(404, detail="Audio stream does not exist or has expired")
tts_info = audio_text_cache.get(audio_stream_id)
if not tts_info:
#raise HTTPException(404, detail="Audio stream does not exist or has expired")
logging.warning("Audio stream does not exist or has expired")
return
# Prepare response parameters
buffer = item["buffer"]
format = item["format"]
total_size = item["size"]
buffer = tts_info.get("buffer",None)
format = tts_info.get("format","mp3")
sample_rate = tts_info.get("sample_rate", 44100)
total_size = tts_info.get('size',0)
model_name = tts_info['model_name']
text = tts_info['text']
media_type = {
"mp3": "audio/mpeg",
"wav": "audio/wav",
"pcm": f"audio/L16; rate={item['sample_rate']}; channels=1"
"pcm": f"audio/L16; rate={tts_info['sample_rate']}; channels=1"
}[format]
logging.info(f"enter get_tts_audio1 buffer={buffer} {format} {sample_rate} range {range}")
def generate_audio():
logging.info(f"get_tts_audio1 generate_audio {format} {sample_rate} {model_name} range{range}")
tts = QwenTTS(ALI_KEY, format , sample_rate,model_name.split("@")[0])
try:
for chunk in tts.tts(text):
if format == 'wav':
yield add_wav_header(chunk, sample_rate)
else:
yield chunk
finally:
# restore the global variables
pass
def read_buffer():
print(f"read_buffer {audio_stream} {audio_stream.closed}")
audio_stream.seek(0) # critical: reset the stream position
data = audio_stream.read(1024)
while data:
yield data
data = audio_stream.read(1024)
async def generate_audio_async():
logging.info(f"get_tts_audio1 generate_audio_async {format} {sample_rate} {model_name}")
tts = QwenTTS(ALI_KEY, format , sample_rate,model_name.split("@")[0])
# The frontend may pass "cosyvoice-v1/longyuan@Tongyi-Qianwen"; .split("@")[0] keeps the part before the "@"
loop = asyncio.get_event_loop()
yield_len = 0
try:
# Run the synchronous call in the thread pool
sync_generator = await loop.run_in_executor(executor, lambda: tts.tts(text))
for chunk in sync_generator:
yield_len = yield_len + len(chunk)
yield chunk
except (BrokenPipeError, ConnectionResetError):
logging.warning("Client disconnected")
finally:
# restore the global variables
logging.info(f"yield len={yield_len}")
pass
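A sketch (not part of this commit) of one way to keep the event loop free while consuming the blocking dashscope generator: pull each chunk through the same executor instead of iterating the synchronous generator directly inside the coroutine.

async def _iterate_in_executor(sync_iterable):
    # Fetch items one at a time in the thread pool; None is used as the exhaustion sentinel.
    loop = asyncio.get_event_loop()
    it = iter(sync_iterable)
    while True:
        chunk = await loop.run_in_executor(executor, next, it, None)
        if chunk is None:
            break
        yield chunk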
# Handle range requests
if range:
return handle_range_request(range, buffer, total_size, media_type)
# Full-file response
return StreamingResponse(
iter(lambda: buffer.read(4096), b""),
media_type=media_type,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total_size),
"Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
}
)
if buffer is not None:
if range:
return handle_range_request(range, buffer, total_size, media_type)
buffer.seek(0)
# Full-file response
return StreamingResponse(
iter(lambda: buffer.read(4096), b""),
media_type=media_type,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total_size),
"Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
}
)
else:
return StreamingResponse(
#generate_audio(),
generate_audio_async(),
media_type=media_type,
headers={
"Transfer-Encoding": "chunked",
"Cache-Control": "no-store",
#"Content-Disposition": "inline",
#"Content-Length": str(1067890),
"Accept-Ranges": "bytes",
"ETag":audio_stream_id,
}
)
def handle_range_request(range: str, buffer: BytesIO, total_size: int, media_type: str):
"""处理HTTP范围请求"""
def handle_range_request(range_header: str, buffer:BytesIO, total_size: int, media_type: str):
"""处理 HTTP Range 请求"""
try:
unit, ranges = range.split("=")
if unit != "bytes":
raise ValueError
# Parse the Range header (example: "bytes=0-1023")
range_type, range_spec = range_header.split('=')
if range_type != 'bytes':
raise ValueError("Unsupported range type")
start_str, end_str = ranges.split("-")
start = int(start_str) if start_str else 0
start_str, end_str = range_spec.split('-')
start = int(start_str)
end = int(end_str) if end_str else total_size - 1
# Validate the range
if start >= total_size or end >= total_size:
raise HTTPException(416, detail="Requested range is invalid", headers={
raise HTTPException(status_code=416, headers={
"Content-Range": f"bytes */{total_size}"
})
content_length = end - start + 1
status_code = 206
if start>0 and end == total_size-1:
status_code = 200
# Set the stream read position
buffer.seek(start)
content_length = end - start + 1
def chunk_generator():
remaining = content_length
while remaining > 0:
chunk = buffer.read(min(4096, remaining))
if not chunk:
break
yield chunk
remaining -= len(chunk)
# Return the chunked response
return StreamingResponse(
chunk_generator(),
status_code=206,
iter(lambda: buffer.read(4096), b''), # use iter directly to avoid generator-state issues
status_code=status_code,
media_type=media_type,
headers={
"Content-Range": f"bytes {start}-{end}/{total_size}",
"Content-Length": str(content_length),
"Content-Type": media_type,
"Accept-Ranges": "bytes"
"Accept-Ranges": "bytes",
"Cache-Control": "public, max-age=3600"
}
)
except ValueError:
raise HTTPException(400, detail="Invalid Range header")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
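An illustrative exchange for the handler above (not part of this commit), assuming a cached buffer whose total_size is 102400 bytes; the headers shown are exactly the ones the handler sets.

# Request : GET /chats/<chat_id>/tts/<audio_stream_id>   Range: bytes=0-1023
# Response: 206 Partial Content
#           Content-Range: bytes 0-1023/102400
#           Content-Length: 1024
#           Accept-Ranges: bytes
#           Cache-Control: public, max-age=3600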
def cleanup_cache():
"""清理过期缓存"""
@@ -971,6 +1018,7 @@ def cleanup_cache():
expired = [k for k, v in audio_text_cache.items()
if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS]
for key in expired:
logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
del audio_text_cache[key]
# Start the cleanup thread when the application starts