准备修改大模型对话输出的文本的tts生成为stream_call,进行备份提交
This commit is contained in:
@@ -11,17 +11,48 @@ import base64, gzip
|
||||
import os,io, re, json
|
||||
from io import BytesIO
|
||||
from typing import Optional, Dict, Any
|
||||
import asyncio
|
||||
|
||||
from fastapi import WebSocket, APIRouter,WebSocketDisconnect,Request,Body,Query
|
||||
from fastapi import FastAPI, UploadFile, File, Form, Header
|
||||
from fastapi.responses import StreamingResponse,JSONResponse
|
||||
from fastapi.responses import StreamingResponse,JSONResponse,Response
|
||||
TTS_SAMPLERATE = 44100 # 22050 # 16000
|
||||
FORMAT = "mp3"
|
||||
CHANNELS = 1 # 单声道
|
||||
SAMPLE_WIDTH = 2 # 16-bit = 2字节
|
||||
|
||||
tts_router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
# logger = logging.getLogger(__name__)
|
||||
class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose ``%(asctime)s`` carries a millisecond suffix.

    With no ``datefmt`` the timestamp is rendered as ``HH:MM:SS.mmm``.
    When a ``datefmt`` is supplied (to the constructor or directly to
    :meth:`formatTime`) it is honoured for the seconds-resolution part and
    the ``.mmm`` suffix is still appended.
    """

    # strftime pattern used when the caller does not supply a datefmt.
    default_clock_format = "%H:%M:%S"

    def formatTime(self, record, datefmt=None):
        """Return the creation time of *record* as text with milliseconds.

        :param record: the ``logging.LogRecord`` being formatted.
        :param datefmt: optional ``time.strftime`` pattern; falls back to
            ``HH:MM:SS`` when omitted or empty.
        :return: the formatted time followed by ``.`` and three digits.
        """
        # self.converter is time.localtime unless the instance overrides it.
        time_tuple = self.converter(record.created)
        clock = time.strftime(datefmt or self.default_clock_format, time_tuple)
        # record.msecs is the sub-second part of record.created, in ms.
        return f"{clock}.{int(record.msecs):03d}"
|
||||
|
||||
# 配置全局日志格式
|
||||
def configure_logging():
    """Reset the root logger to one console handler with millisecond timestamps.

    Every handler currently attached to the root logger is detached and
    closed, then a single ``logging.StreamHandler`` using
    ``MillisecondsFormatter`` is installed and the root level is set to INFO.
    Calling the function again leaves exactly one handler in place.
    """
    formatter = MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")

    root_logger = logging.getLogger()
    # Detach through the logging API and close each handler so that file or
    # socket handlers installed earlier release their resources instead of
    # being dropped while still open.
    for stale_handler in list(root_logger.handlers):
        root_logger.removeHandler(stale_handler)
        stale_handler.close()

    # StreamHandler() without arguments writes to sys.stderr.
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)

    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)
|
||||
|
||||
# 调用配置函数(程序启动时运行一次)
|
||||
configure_logging()
|
||||
|
||||
class StreamSessionManager:
|
||||
def __init__(self):
|
||||
@@ -208,46 +239,185 @@ def tts_stream(session_id):
|
||||
return resp
|
||||
|
||||
"""
|
||||
# Bitrate ladders (kbps) defined by the MPEG audio frame header.  The 4-bit
# header index of a bitrate is its position in the ladder plus one: index
# 0b0000 means "free format" and 0b1111 is forbidden, so neither is listed.
_MP3_HEADER_MPEG1_KBPS = {
    "I": (32, 64, 96, 128, 160, 192, 224, 256, 288, 320, 352, 384, 416, 448),
    "II": (32, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320, 384),
    "III": (32, 40, 48, 56, 64, 80, 96, 112, 128, 160, 192, 224, 256, 320),
}
# MPEG-2 and MPEG-2.5 share one ladder for Layer I and one for Layers II/III.
_MP3_HEADER_MPEG2_KBPS = {
    "I": (32, 48, 56, 64, 80, 96, 112, 128, 144, 160, 176, 192, 224, 256),
    "II": (8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160),
    "III": (8, 16, 24, 32, 40, 48, 56, 64, 80, 96, 112, 128, 144, 160),
}
# sample rate -> (2-bit MPEG version ID, 2-bit sampling-rate index).
# Version IDs: 0b11 = MPEG-1, 0b10 = MPEG-2, 0b00 = MPEG-2.5.
_MP3_HEADER_SAMPLE_RATES = {
    48000: (0b11, 0b01),
    44100: (0b11, 0b00),
    22050: (0b10, 0b00),
    16000: (0b10, 0b10),
    8000: (0b00, 0b10),
}
# Layer name -> 2-bit layer field.
_MP3_HEADER_LAYER_CODES = {"I": 0b11, "II": 0b10, "III": 0b01}


def generate_mp3_header(
    sample_rate: int,
    bitrate_kbps: int,
    channels: int = 1,
    layer: str = "III",
) -> bytes:
    """Build the 4-byte MPEG audio frame header for the given stream settings.

    The header is emitted without CRC, without padding, with the private,
    copyright and original bits cleared and with no emphasis.

    :param sample_rate: sampling rate in Hz (8000, 16000, 22050, 44100, 48000).
    :param bitrate_kbps: bitrate in kbps; must exist in the ladder selected by
        the MPEG version (implied by ``sample_rate``) and ``layer``.
    :param channels: 1 for mono, 2 for stereo.
    :param layer: "I", "II" or "III".
    :return: the frame header as 4 big-endian bytes.
    :raises ValueError: on an unsupported sample rate, layer, bitrate or
        channel count.
    """
    # ---- argument validation (same order and messages as before) ----
    valid_sample_rates = {8000, 16000, 22050, 44100, 48000}
    if sample_rate not in valid_sample_rates:
        raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}")

    valid_layers = {"I", "II", "III"}
    if layer not in valid_layers:
        raise ValueError(f"不支持的层,可选:{valid_layers}")

    # ---- MPEG version / sampling-rate index / layer field ----
    mpeg_version, sample_rate_index = _MP3_HEADER_SAMPLE_RATES[sample_rate]
    layer_code = _MP3_HEADER_LAYER_CODES[layer]

    # ---- bitrate index ----
    ladders = _MP3_HEADER_MPEG1_KBPS if mpeg_version == 0b11 else _MP3_HEADER_MPEG2_KBPS
    bitrate_table = {kbps: index for index, kbps in enumerate(ladders[layer], start=1)}
    if bitrate_kbps not in bitrate_table:
        raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}")
    bitrate_index = bitrate_table[bitrate_kbps]

    # ---- channel mode ----
    if channels == 1:
        channel_mode = 0b11  # single channel (mono)
    elif channels == 2:
        channel_mode = 0b00  # stereo
    else:
        raise ValueError("声道数必须为1或2")

    # ---- assemble the 32-bit header ----
    frame_header = (
        (0x7FF << 21)                 # 11-bit frame sync, all ones
        | (mpeg_version << 19)        # MPEG version ID
        | (layer_code << 17)          # layer
        | (1 << 16)                   # protection bit: 1 = no CRC follows
        | (bitrate_index << 12)       # bitrate index
        | (sample_rate_index << 10)   # sampling-rate index
        | (0 << 9)                    # padding: none
        | (0 << 8)                    # private bit
        | (channel_mode << 6)         # channel mode
        | (0 << 4)                    # mode extension (joint stereo only)
        | (0 << 3)                    # copyright
        | (0 << 2)                    # original
        | 0b00                        # emphasis: none
    )
    return frame_header.to_bytes(4, byteorder="big")
|
||||
|
||||
# ------------------------------------------------
|
||||
def audio_fade_in(audio_data, fade_length):
|
||||
@@ -414,7 +584,7 @@ def test_qwen_chat():
|
||||
# SECURITY(review): this API key is hard-coded and committed to version
# control.  It is passed as the `key` argument of QwenTTS below, so it is
# presumably a live DashScope credential -- confirm, rotate it, and load it
# from the environment or a secret store instead of the source tree.
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
|
||||
|
||||
class QwenTTS:
|
||||
def __init__(self, key, model_name="cosyvoice-v1/longxiaochun", base_url=""):
|
||||
def __init__(self, key,format="mp3",sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", base_url=""):
|
||||
import dashscope
|
||||
import ssl
|
||||
print("---begin--init QwenTTS--") # cyx
|
||||
@@ -425,6 +595,8 @@ class QwenTTS:
|
||||
self.callback = None
|
||||
self.is_cosyvoice = False
|
||||
self.voice = ""
|
||||
self.format = format
|
||||
self.sample_rate = sample_rate
|
||||
if '/' in model_name:
|
||||
parts = model_name.split('/', 1)
|
||||
# 返回分离后的两个字符串parts[0], parts[1]
|
||||
@@ -539,7 +711,7 @@ class QwenTTS:
|
||||
format="wav") # format="mp3")
|
||||
else:
|
||||
self.callback = Callback_v2()
|
||||
format =self.get_audio_format(FORMAT,TTS_SAMPLERATE)
|
||||
format =self.get_audio_format(self.format,self.sample_rate)
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model='cosyvoice-v1',
|
||||
# voice="longyuan", #"longfei",
|
||||
@@ -553,6 +725,7 @@ class QwenTTS:
|
||||
# -----------------------------------
|
||||
try:
|
||||
for data in self.callback._run():
|
||||
#logging.info(f"dashcope return data {len(data)}")
|
||||
yield data
|
||||
# print(f"---Qwen return data {num_tokens_from_string(text)}")
|
||||
# yield num_tokens_from_string(text)
|
||||
@@ -563,7 +736,7 @@ class QwenTTS:
|
||||
def get_audio_format(self, format: str, sample_rate: int):
|
||||
"""动态获取音频格式"""
|
||||
from dashscope.audio.tts_v2 import AudioFormat
|
||||
logging.info(f"--get_audio_format-- {format} {sample_rate}")
|
||||
logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
|
||||
format_map = {
|
||||
(8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
|
||||
(8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
|
||||
@@ -575,6 +748,10 @@ class QwenTTS:
|
||||
(44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
|
||||
(44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
|
||||
(44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
|
||||
(48800, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
|
||||
(48800, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
|
||||
(48800, 'wav'):AudioFormat.WAV_48000HZ_MONO_16BIT
|
||||
|
||||
}
|
||||
return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
|
||||
|
||||
@@ -599,7 +776,10 @@ async def stream_mp3():
|
||||
return StreamingResponse(
|
||||
audio_generator(),
|
||||
media_type="audio/mpeg",
|
||||
headers={"Cache-Control": "no-store"}
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Accept-Ranges": "bytes"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -621,207 +801,9 @@ def generate_silence_header(duration_ms: int = 500) -> bytes:
|
||||
return b'\x00' * num_samples * SAMPLE_WIDTH * CHANNELS
|
||||
|
||||
|
||||
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
|
||||
async def get_tts_audio(
|
||||
chat_id: str,
|
||||
audio_stream_id: str,
|
||||
request: Request,
|
||||
range: str = Header(None) # 新增Range头解析
|
||||
):
|
||||
with cache_lock:
|
||||
# tts_info = audio_text_cache.pop(audio_stream_id, None)
|
||||
tts_info = audio_text_cache.get(audio_stream_id, None)
|
||||
if not tts_info:
|
||||
raise HTTPException(404, detail="音频流已过期")
|
||||
audio_stream_len = tts_info.get('audio_stream_len', 0)
|
||||
audio_stream = tts_info['audio_stream']
|
||||
text = tts_info['text']
|
||||
|
||||
model_name = tts_info.get('model_name', "cosyvoice-v1/longxiaochun")
|
||||
format = tts_info['format'] # 新增字段
|
||||
sample_rate = tts_info['sample_rate'] # 新增字段
|
||||
|
||||
def stream_audio():
|
||||
total = 0
|
||||
try:
|
||||
|
||||
for chunk in tts_mdl.tts(text):
|
||||
# print(f"data_length={total} {chunk}")
|
||||
hex_data = binascii.hexlify(chunk).decode("utf-8") + "\r\n" # <--- 添加\n
|
||||
# yield f"{hex_data}" # <--- SSE格式封装
|
||||
print(f"yield {len(chunk)}")
|
||||
# yield chunk
|
||||
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
|
||||
total = total + len(gzip_base64_data)
|
||||
yield gzip_base64_data
|
||||
print(f"tts gen end {total}")
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
||||
ensure_ascii=False)).encode('utf-8')
|
||||
|
||||
def read_buffer():
|
||||
print(f"get audio_stream {audio_stream} {audio_stream.closed}")
|
||||
audio_stream.seek(0) # 关键重置操作
|
||||
data = audio_stream.read(1024)
|
||||
while data:
|
||||
yield data
|
||||
data = audio_stream.read(1024)
|
||||
|
||||
def generate_silence_header():
|
||||
"""生成500ms的静音MP3帧(约44100Hz)"""
|
||||
# 示例静音数据(实际需生成标准MP3静音帧)
|
||||
return b'\xff\xfb\xd4\x00' * 100 # 约400字节
|
||||
|
||||
# 保持原有生成器结构
|
||||
def generate_audio():
|
||||
# 保持原有QwenTTS调用方式
|
||||
tts = QwenTTS(ALI_KEY, model_name)
|
||||
|
||||
# 临时修改全局配置(保持原有逻辑)
|
||||
global FORMAT, TTS_SAMPLERATE
|
||||
original_format = FORMAT
|
||||
original_sample_rate = TTS_SAMPLERATE
|
||||
FORMAT = format
|
||||
TTS_SAMPLERATE = sample_rate
|
||||
|
||||
try:
|
||||
for chunk in tts.tts(text):
|
||||
if format == 'wav':
|
||||
yield add_wav_header(chunk, sample_rate)
|
||||
else:
|
||||
yield chunk
|
||||
finally:
|
||||
# 恢复全局变量
|
||||
FORMAT = original_format
|
||||
TTS_SAMPLERATE = original_sample_rate
|
||||
|
||||
# 保持原有流响应逻辑
|
||||
media_type_map = {
|
||||
'mp3': 'audio/mpeg',
|
||||
'wav': 'audio/wav',
|
||||
'pcm': f'audio/L16; rate={sample_rate}; channels=1'
|
||||
}
|
||||
|
||||
# 处理Range请求逻辑
|
||||
if range:
|
||||
start =0
|
||||
end =audio_stream_len
|
||||
total_length = audio_stream_len
|
||||
return StreamingResponse(
|
||||
read_buffer(),
|
||||
status_code=206,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Range": f"bytes {start}-{end}/{total_length}",
|
||||
"Cache-Control": "no-store" # 确保不缓存动态内容
|
||||
}
|
||||
)
|
||||
else:
|
||||
try:
|
||||
# 从缓存获取原始参数
|
||||
|
||||
if audio_stream is None:
|
||||
tts_mdl = QwenTTS(ALI_KEY, "cosyvoice-v1/longxiaochun")
|
||||
|
||||
|
||||
if audio_stream:
|
||||
logging.info("return audio stream buffer")
|
||||
# 确保流的位置在开始处
|
||||
return StreamingResponse(
|
||||
read_buffer(),
|
||||
media_type=media_type_map[format],
|
||||
headers={
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Cache-Control": "no-store",
|
||||
"Content-Disposition": "inline",
|
||||
"Access-Control-Allow-Origin": "*"
|
||||
}
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
generate_audio(),
|
||||
media_type=media_type_map[format],
|
||||
headers={
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Cache-Control": "no-store",
|
||||
"Content-Disposition": "inline"
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"音频流错误: {str(e)}")
|
||||
raise HTTPException(500, detail="音频生成失败")
|
||||
|
||||
|
||||
# ------------------------ API路由 ------------------------
|
||||
@tts_router.post("/chats/{chat_id}/tts")
|
||||
async def create_tts_request(chat_id: str, request: Request):
|
||||
try:
|
||||
request_data = await request.json()
|
||||
# 获取并验证新参数
|
||||
format = request_data.get('format', FORMAT)
|
||||
if format not in ['mp3', 'wav', 'pcm']:
|
||||
raise HTTPException(400, detail="不支持的音频格式")
|
||||
|
||||
sample_rate = request_data.get('sample_rate', TTS_SAMPLERATE)
|
||||
if sample_rate not in [8000, 16000, 22050, 44100]:
|
||||
raise HTTPException(400, detail="不支持的采样率")
|
||||
|
||||
# 保持原有参数处理
|
||||
text = request_data.get("text")
|
||||
if not text or not text.strip():
|
||||
raise HTTPException(400, detail="文本内容不能为空")
|
||||
|
||||
# 存储新增参数到缓存
|
||||
audio_stream_id = str(uuid.uuid4())
|
||||
tts_info = {
|
||||
'text': text,
|
||||
'chat_id': chat_id,
|
||||
'created_at': time.time(),
|
||||
'audio_stream': None,
|
||||
'model_name': request_data.get('model_name'),
|
||||
'format': format,
|
||||
'sample_rate': sample_rate
|
||||
}
|
||||
|
||||
audio_stream_len = 0
|
||||
# 保持原有延迟生成逻辑
|
||||
if request_data.get('delay_gen_audio', False) is False:
|
||||
try:
|
||||
# 临时设置全局变量(保持原有逻辑)
|
||||
original_format = FORMAT
|
||||
original_sample_rate = TTS_SAMPLERATE
|
||||
tts = QwenTTS(ALI_KEY, tts_info.get('model_name', "cosyvoice-v1/longxiaochun"))
|
||||
buffer = io.BytesIO()
|
||||
for chunk in tts.tts(text):
|
||||
audio_stream_len = audio_stream_len + len(chunk)
|
||||
buffer.write(chunk)
|
||||
buffer.seek(0)
|
||||
tts_info['audio_stream'] = buffer
|
||||
tts_info['audio_stream_len'] = audio_stream_len
|
||||
finally:
|
||||
# 恢复全局变量
|
||||
pass
|
||||
# 保持原有缓存逻辑
|
||||
with cache_lock:
|
||||
audio_text_cache[audio_stream_id] = tts_info
|
||||
logging.info(f"create tts stream return {audio_stream_id} {audio_text_cache[audio_stream_id]}")
|
||||
# 保持原响应结构
|
||||
return {
|
||||
"tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
|
||||
"audio_stream_id": audio_stream_id,
|
||||
"sample_rate": sample_rate
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"请求处理失败: {str(e)}")
|
||||
raise HTTPException(500, detail="内部服务器错误")
|
||||
|
||||
# 辅助函数
|
||||
|
||||
# ------------------------ API路由 ------------------------
|
||||
@tts_router.post("/chats1/{chat_id}/tts")
|
||||
async def create_tts_request1(chat_id: str, request: Request):
|
||||
"""创建TTS音频流"""
|
||||
try:
|
||||
data = await request.json()
|
||||
@@ -831,43 +813,54 @@ async def create_tts_request1(chat_id: str, request: Request):
|
||||
if not text:
|
||||
raise HTTPException(400, detail="文本内容不能为空")
|
||||
|
||||
format = data.get("format", "mp3")
|
||||
format = data.get("tts_stream_format", "mp3")
|
||||
if format not in ["mp3", "wav", "pcm"]:
|
||||
raise HTTPException(400, detail="不支持的音频格式")
|
||||
|
||||
sample_rate = data.get("sample_rate", 16000)
|
||||
sample_rate = data.get("tts_sample_rate", 48000)
|
||||
|
||||
if sample_rate not in [8000, 16000, 22050, 44100]:
|
||||
raise HTTPException(400, detail="不支持的采样率")
|
||||
model_name = data.get("model_name","cosyvoice-v1/longxiaochun")
|
||||
delay_gen_audio = data.get('delay_gen_audio',False)
|
||||
format ="mp3"
|
||||
sample_rate = 44100
|
||||
|
||||
#sample_rate = 48000
|
||||
# 生成音频流
|
||||
audio_stream_id = str(uuid.uuid4())
|
||||
buffer = io.BytesIO()
|
||||
tts = QwenTTS(ALI_KEY)
|
||||
|
||||
try:
|
||||
for chunk in tts.tts(text):
|
||||
buffer.write(chunk)
|
||||
buffer.seek(0)
|
||||
except Exception as e:
|
||||
logging.error(f"TTS生成失败: {str(e)}")
|
||||
raise HTTPException(500, detail="音频生成失败")
|
||||
|
||||
# 存储到缓存
|
||||
with cache_lock:
|
||||
audio_text_cache[audio_stream_id] = {
|
||||
buffer = None
|
||||
tts_info = {
|
||||
"text" :text,
|
||||
"buffer": buffer,
|
||||
"format": format,
|
||||
"sample_rate": sample_rate,
|
||||
"created_at": datetime.datetime.now(),
|
||||
"size": buffer.getbuffer().nbytes
|
||||
"model_name" : model_name,
|
||||
"created_at": datetime.datetime.now()
|
||||
}
|
||||
|
||||
if delay_gen_audio is False:
|
||||
logging.info("--begin generate tts --")
|
||||
buffer = io.BytesIO()
|
||||
tts = QwenTTS(ALI_KEY,format,sample_rate,model_name.split("@")[0])
|
||||
try:
|
||||
for chunk in tts.tts(text):
|
||||
buffer.write(chunk)
|
||||
buffer.seek(0)
|
||||
except Exception as e:
|
||||
logging.error(f"TTS生成失败: {str(e)}")
|
||||
raise HTTPException(500, detail="音频生成失败")
|
||||
tts_info['buffer'] =buffer
|
||||
tts_info['size'] = buffer.getbuffer().nbytes
|
||||
|
||||
# 存储到缓存
|
||||
with cache_lock:
|
||||
audio_text_cache[audio_stream_id] = tts_info
|
||||
logging.info(f" tts return {audio_text_cache[audio_stream_id]}")
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content={
|
||||
"tts_url":f"/chats1/{chat_id}/tts/{audio_stream_id}",
|
||||
"url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
|
||||
"tts_url":f"/chats/{chat_id}/tts/{audio_stream_id}",
|
||||
"url": f"/chats/{chat_id}/tts/{audio_stream_id}",
|
||||
"expires_at": (datetime.datetime.now() + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat()
|
||||
}
|
||||
)
|
||||
@@ -876,9 +869,9 @@ async def create_tts_request1(chat_id: str, request: Request):
|
||||
logging.error(f"请求处理失败: {str(e)}")
|
||||
raise HTTPException(500, detail="服务器内部错误")
|
||||
|
||||
|
||||
@tts_router.get("/chats1/{chat_id}/tts/{audio_stream_id}")
|
||||
async def get_tts_audio1(
|
||||
executor = ThreadPoolExecutor()
|
||||
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
|
||||
async def get_tts_audio(
|
||||
chat_id: str,
|
||||
audio_stream_id: str,
|
||||
range: str = Header(None)
|
||||
@@ -889,80 +882,134 @@ async def get_tts_audio1(
|
||||
|
||||
# 获取缓存
|
||||
with cache_lock:
|
||||
item = audio_text_cache.get(audio_stream_id)
|
||||
|
||||
if not item:
|
||||
raise HTTPException(404, detail="音频流不存在或已过期")
|
||||
tts_info = audio_text_cache.get(audio_stream_id)
|
||||
|
||||
if not tts_info:
|
||||
#raise HTTPException(404, detail="音频流不存在或已过期")
|
||||
logging.warning("音频流不存在或已过期")
|
||||
return
|
||||
# 准备响应参数
|
||||
buffer = item["buffer"]
|
||||
format = item["format"]
|
||||
total_size = item["size"]
|
||||
buffer = tts_info.get("buffer",None)
|
||||
format = tts_info.get("format","mp3")
|
||||
sample_rate = tts_info.get("sample_rate", 44100)
|
||||
total_size = tts_info.get('size',0)
|
||||
model_name = tts_info['model_name']
|
||||
text = tts_info['text']
|
||||
media_type = {
|
||||
"mp3": "audio/mpeg",
|
||||
"wav": "audio/wav",
|
||||
"pcm": f"audio/L16; rate={item['sample_rate']}; channels=1"
|
||||
"pcm": f"audio/L16; rate={tts_info['sample_rate']}; channels=1"
|
||||
}[format]
|
||||
logging.info(f"enter get_tts_audio1 buffer={buffer} {format} {sample_rate} range {range}")
|
||||
|
||||
def generate_audio():
|
||||
logging.info(f"get_tts_audio1 generate_audio {format} {sample_rate} {model_name} range{range}")
|
||||
tts = QwenTTS(ALI_KEY, format , sample_rate,model_name.split("@")[0])
|
||||
try:
|
||||
for chunk in tts.tts(text):
|
||||
if format == 'wav':
|
||||
yield add_wav_header(chunk, sample_rate)
|
||||
else:
|
||||
yield chunk
|
||||
finally:
|
||||
# 恢复全局变量
|
||||
pass
|
||||
def read_buffer():
|
||||
print(f"read_buffer {audio_stream} {audio_stream.closed}")
|
||||
audio_stream.seek(0) # 关键重置操作
|
||||
data = audio_stream.read(1024)
|
||||
while data:
|
||||
yield data
|
||||
data = audio_stream.read(1024)
|
||||
|
||||
async def generate_audio_async():
|
||||
logging.info(f"get_tts_audio1 generate_audio_async {format} {sample_rate} {model_name}")
|
||||
tts = QwenTTS(ALI_KEY, format , sample_rate,model_name.split("@")[0])
|
||||
#前端有可能传入 cosyvoice-v1/longyuan@Tongyi-Qianwen 通过.split("@")[0] 取@之前部分
|
||||
loop = asyncio.get_event_loop()
|
||||
yield_len = 0
|
||||
try:
|
||||
# 将同步函数放入线程池执行
|
||||
sync_generator = await loop.run_in_executor(executor, lambda: tts.tts(text))
|
||||
for chunk in sync_generator:
|
||||
yield_len = yield_len + len(chunk)
|
||||
yield chunk
|
||||
except (BrokenPipeError, ConnectionResetError):
|
||||
logging.warning("客户端主动断开连接")
|
||||
finally:
|
||||
# 恢复全局变量
|
||||
logging.info(f"yield len={yield_len}")
|
||||
pass
|
||||
# 处理范围请求
|
||||
if range:
|
||||
return handle_range_request(range, buffer, total_size, media_type)
|
||||
|
||||
# 完整文件响应
|
||||
return StreamingResponse(
|
||||
iter(lambda: buffer.read(4096), b""),
|
||||
media_type=media_type,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(total_size),
|
||||
"Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
|
||||
}
|
||||
)
|
||||
if buffer is not None:
|
||||
if range:
|
||||
return handle_range_request(range, buffer, total_size, media_type)
|
||||
buffer.seek(0)
|
||||
# 完整文件响应
|
||||
return StreamingResponse(
|
||||
iter(lambda: buffer.read(4096), b""),
|
||||
media_type=media_type,
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(total_size),
|
||||
"Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
|
||||
}
|
||||
)
|
||||
else:
|
||||
return StreamingResponse(
|
||||
#generate_audio(),
|
||||
generate_audio_async(),
|
||||
media_type=media_type,
|
||||
headers={
|
||||
"Transfer-Encoding": "chunked",
|
||||
"Cache-Control": "no-store",
|
||||
#"Content-Disposition": "inline",
|
||||
#"Content-Length": str(1067890),
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag":audio_stream_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def handle_range_request(range_header: str, buffer: BytesIO, total_size: int, media_type: str):
    """Serve a single ``bytes=`` range of an in-memory audio buffer.

    Accepted forms are ``bytes=START-END``, ``bytes=START-`` (to the end) and
    the suffix form ``bytes=-N`` (last N bytes).  An END beyond the last byte
    is clamped to the last byte.  The response body contains exactly the
    requested bytes, so it always matches the ``Content-Length`` and
    ``Content-Range`` headers, and the status is always 206.

    :param range_header: raw value of the request's ``Range`` header.
    :param buffer: the complete audio data.  It is read through
        ``getvalue()`` only, so its file position is never moved and the
        same cached buffer can back several responses at once.
    :param total_size: number of bytes in ``buffer``.
    :param media_type: value for the response ``Content-Type``.
    :return: a ``StreamingResponse`` with status 206.
    :raises HTTPException: 400 when the header is malformed or uses a unit
        other than ``bytes``; 416 (with ``Content-Range: bytes */size``) when
        the range cannot be satisfied.
    """
    # Only the parsing lives inside the try block: the 416 raised further
    # down must reach the client as 416 and not be rewritten by a handler.
    try:
        range_type, equals, range_spec = range_header.partition("=")
        if not equals or range_type.strip() != "bytes":
            raise ValueError("Unsupported range type")

        start_str, dash, end_str = range_spec.strip().partition("-")
        if not dash:
            raise ValueError("Malformed range spec")

        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        else:
            # Suffix form "bytes=-N": the final N bytes of the resource.
            suffix_length = int(end_str)
            start = max(total_size - suffix_length, 0)
            end = total_size - 1
    except ValueError:
        raise HTTPException(400, detail="无效的Range头") from None

    # A last-byte position past the end means "up to the last byte".
    end = min(end, total_size - 1)

    if start < 0 or start >= total_size or start > end:
        raise HTTPException(status_code=416, headers={
            "Content-Range": f"bytes */{total_size}"
        })

    content_length = end - start + 1
    # Slice a snapshot instead of seeking the shared BytesIO.
    payload = memoryview(buffer.getvalue())[start:end + 1]

    def chunk_generator():
        """Yield the selected byte window in chunks of at most 4096 bytes."""
        for offset in range(0, content_length, 4096):
            yield bytes(payload[offset:offset + 4096])

    return StreamingResponse(
        chunk_generator(),
        status_code=206,
        media_type=media_type,
        headers={
            "Content-Range": f"bytes {start}-{end}/{total_size}",
            "Content-Length": str(content_length),
            "Content-Type": media_type,
            "Accept-Ranges": "bytes",
            "Cache-Control": "public, max-age=3600"
        }
    )
|
||||
|
||||
def cleanup_cache():
|
||||
"""清理过期缓存"""
|
||||
@@ -971,6 +1018,7 @@ def cleanup_cache():
|
||||
expired = [k for k, v in audio_text_cache.items()
|
||||
if (now - v["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS]
|
||||
for key in expired:
|
||||
logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
|
||||
del audio_text_cache[key]
|
||||
|
||||
# 应用启动时启动清理线程
|
||||
|
||||
Reference in New Issue
Block a user