Files
ragflow_python/asr-monitor-test/app/tts_service.py

977 lines
36 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import binascii
from copy import deepcopy
from timeit import default_timer as timer
import datetime
from datetime import timedelta
import threading, time,queue,uuid,time,array
from threading import Lock, Thread
from concurrent.futures import ThreadPoolExecutor
import base64, gzip
import os,io, re, json
from io import BytesIO
from typing import Optional, Dict, Any
from fastapi import WebSocket, APIRouter,WebSocketDisconnect,Request,Body,Query
from fastapi import FastAPI, UploadFile, File, Form, Header
from fastapi.responses import StreamingResponse,JSONResponse
# Module-wide synthesis defaults. Other code in this file reads FORMAT and
# TTS_SAMPLERATE at call time and temporarily rebinds them per request.
TTS_SAMPLERATE = 44100 # alternatives tried previously: 22050, 16000
FORMAT = "mp3"
CHANNELS = 1 # mono
SAMPLE_WIDTH = 2 # 16-bit = 2 bytes per sample
# Router that every endpoint in this file registers itself on.
tts_router = APIRouter()
# Module logger; the code visible in this file calls the root `logging`
# functions directly rather than this object.
logger = logging.getLogger(__name__)
class StreamSessionManager:
    """Registry of streaming TTS sessions, each served by its own worker thread.

    A session is a plain dict holding the TTS model, a bounded output
    ``buffer`` queue of audio chunks, an input ``task_queue`` of text
    fragments and some bookkeeping fields (see ``create_session``).
    """

    def __init__(self):
        # {session_id: session dict}; mutated under self.lock.
        self.sessions = {}
        self.lock = threading.Lock()
        # Fixed-size pool that runs the actual synthesis calls.
        self.executor = ThreadPoolExecutor(max_workers=30)
        # Seconds of inactivity after which a session is closed outright.
        self.gc_interval = 300
        # Seconds of inactivity after which the stream is flagged finished.
        self.gc_tts = 3

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        """Register a new session, start its worker thread and return its id.

        Args:
            tts_model: object exposing ``tts(text, sample_rate, stream_format)``
                that yields audio chunks.
            sample_rate: sample rate forwarded to ``tts_model.tts``.
            stream_format: audio format forwarded to ``tts_model.tts``;
                ``'wav'`` additionally enables the first-chunk fade-in.
        """
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe output queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set once the stream went idle
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False
            }
        # One daemon worker per session drains its task queue.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue ``text`` for synthesis; silently ignore unknown sessions."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            # Non-blocking put so the caller never stalls while holding the lock.
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Worker loop of one session: batch queued text and synthesize it."""
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge up to 5 fragments, waiting at most 50 ms for each one.
                texts = []
                while len(texts) < 5:
                    try:
                        text = session['task_queue'].get(timeout=0.05)
                        texts.append(text)
                    except queue.Empty:
                        break
                if texts:
                    # Joining the fragments reduces the number of TTS requests.
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts)
                    )
                    future.result()  # wait until this batch is fully converted
                # Idle checks, both measured from the last produced chunk.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        """Synthesize ``text`` and push the resulting chunks into the buffer.

        Runs on the thread pool. On failure an ``"ERROR:<message>"`` string is
        pushed into the buffer instead of raising.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        try:
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav' and first_chunk:
                    # audio_fade_in counts 16-bit samples, not bytes: fade at
                    # most 1024 samples and never more than the chunk holds.
                    # Passing the byte length (as before) overran the sample
                    # array for chunks of 2048 bytes or fewer.
                    fade_samples = min(1024, len(chunk) // 2)
                    session['buffer'].put(audio_fade_in(chunk, fade_samples))
                    first_chunk = False
                else:
                    session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # Added 2025-05-10: marks that the TTS backend has started
                    # returning data, so the front end can be notified.
                    session['tts_chunk_data_valid'] = True
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def close_session(self, session_id):
        """Mark the session inactive and schedule its removal."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                # Remove the entry after a 1 second grace period.
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Drop the session entry if it still exists."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the session dict for ``session_id`` or ``None``."""
        return self.sessions.get(session_id)
# Module-level session manager instance, created at import time.
stream_manager = StreamSessionManager()
def allowed_file(filename):
    """Return True when ``filename`` ends in a permitted image extension.

    The extension is whatever follows the last dot, compared
    case-insensitively against png / jpg / jpeg / gif.
    """
    _stem, dot, extension = filename.rpartition('.')
    if dot != '.':
        # No dot at all means there is no extension to check.
        return False
    return extension.lower() in {'png', 'jpg', 'jpeg', 'gif'}
# In-memory cache of TTS requests keyed by audio stream id; guarded by cache_lock.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600 # entries expire after 10 minutes
"""
@manager.route('/tts_stream/<session_id>', methods=['GET'])
def tts_stream(session_id):
session = stream_manager.sessions.get(session_id)
logging.info(f"--tts_stream {session}")
if session is None:
return get_error_data_result(message="Audio stream not found or expired.")
def generate():
count = 0;
finished_event = session['finished']
try:
while not finished_event.is_set():
if not session or not session['active']:
break
try:
chunk = session['buffer'].get_nowait() #
count = count + 1
if isinstance(chunk, str) and chunk.startswith("ERROR"):
logging.info(f"---tts stream error!!!! {chunk}")
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
break
if session['stream_format'] == "wav":
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
yield gzip_base64_data
else:
yield chunk
retry_count = 0 # 成功收到数据重置重试计数器
except queue.Empty:
if session['stream_format'] == "wav":
pass
else:
pass
except Exception as e:
logging.info(f"tts streag get error2 {e} ")
finally:
# 确保流结束后关闭会话
if session:
# 延迟关闭会话,确保所有数据已发送
stream_manager.close_session(session_id)
logging.info(f"Session {session_id} closed.")
# 关键响应头设置
if session['stream_format'] == "wav":
resp = Response(stream_with_context(generate()), mimetype="audio/wav")
else:
resp = Response(stream_with_context(generate()), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
"""
def generate_mp3_header(bitrate_kbps=128, padding=0):
    """Build the 4-byte frame header of an MPEG-1 Layer III, 44.1 kHz mono frame.

    Args:
        bitrate_kbps: one of the MPEG-1 Layer III bitrates (32 ... 320).
        padding: truthy to set the padding bit.

    Returns:
        The header as 4 big-endian bytes, e.g. ``FF FB 90 C0`` for the defaults.

    Raises:
        KeyError: if ``bitrate_kbps`` is not a valid Layer III bitrate.
    """
    # Bitrate index table for MPEG-1 Layer III.
    bitrate_table = {
        32: 0b0001, 40: 0b0010, 48: 0b0011, 56: 0b0100,
        64: 0b0101, 80: 0b0110, 96: 0b0111, 112: 0b1000,
        128: 0b1001, 160: 0b1010, 192: 0b1011, 224: 0b1100,
        256: 0b1101, 320: 0b1110
    }
    bitrate_index = bitrate_table[bitrate_kbps]

    sync = 0b11111111111   # frame sync, 11 bits
    version = 0b11         # MPEG-1, 2 bits
    layer = 0b01           # Layer III, 2 bits
    # The protection bit is inverted: 1 means "no CRC", 0 announces a 16-bit
    # CRC after the header. The previous value 0 contradicted the intended
    # "no CRC" header and would make decoders expect a CRC that is not there.
    protection = 0b1
    sampling_rate = 0b00   # 44.1 kHz, 2 bits
    # Clamp to a single bit so a larger value cannot spill into other fields.
    padding_bit = 1 if padding else 0
    private = 0b0          # private bit
    mode = 0b11            # single channel (mono), 2 bits
    mode_ext = 0b00        # mode extension, 2 bits
    copyright = 0b0        # not copyrighted
    original = 0b0         # not an original
    emphasis = 0b00        # no emphasis, 2 bits

    # Assemble the 32-bit header, most significant field first.
    header = (
        (sync << 21) |
        (version << 19) |
        (layer << 17) |
        (protection << 16) |
        (bitrate_index << 12) |
        (sampling_rate << 10) |
        (padding_bit << 9) |
        (private << 8) |
        (mode << 6) |
        (mode_ext << 4) |
        (copyright << 3) |
        (original << 2) |
        emphasis
    )
    return header.to_bytes(4, byteorder='big')
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit PCM audio.

    Args:
        audio_data: bytes-like PCM data, interpreted as native-endian signed
            16-bit samples.
        fade_length: number of leading *samples* (not bytes) to ramp from
            silence up to full volume.

    Returns:
        The processed audio as bytes, same length as the input.
    """
    # Only whole 16-bit samples can be faded; an odd trailing byte is passed
    # through untouched instead of making array() raise ValueError.
    usable = len(audio_data) - (len(audio_data) % 2)
    samples = array.array('h', audio_data[:usable])
    # Never ramp over more samples than exist; callers may pass a byte count
    # or a length larger than the chunk, which used to raise IndexError.
    ramp = min(fade_length, len(samples))
    for i in range(ramp):
        samples[i] = int(samples[i] * (i / ramp))
    return samples.tobytes() + bytes(audio_data[usable:])
def parse_markdown_json(json_string):
    """Extract and decode the first ```json fenced block of a Markdown string.

    Returns:
        ``{'success': True, 'data': <decoded JSON>}`` on success, otherwise
        ``{'success': False, 'data': <error description>}``.
    """
    fenced = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if fenced is None:
        # No fenced JSON block present at all.
        return {'success': False, 'data': 'not a valid markdown json string'}
    try:
        payload = json.loads(fenced.group(1))
    except json.JSONDecodeError as e:
        # The block exists but its content is not valid JSON.
        return {'success': False, 'data': str(e)}
    return {'success': True, 'data': payload}
# NOTE(review): these three assignments repeat the identical definitions that
# appear earlier in this file; running them again rebinds the names to a fresh
# dict and a fresh Lock at import time. Consider removing one copy.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600 # entries expire after 10 minutes
# 全角字符到半角字符的映射
# Fullwidth forms U+FF01..U+FF5E sit exactly 0xFEE0 above their ASCII
# counterparts, so the table is derived from code points instead of literal
# characters (the literal keys had degraded to empty strings, which made the
# lookup a no-op).
_FULLWIDTH_OFFSET = 0xFEE0
_FULLWIDTH_PUNCT_TABLE = {
    ord(ch) + _FULLWIDTH_OFFSET: ch
    for ch in "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
}


def fullwidth_to_halfwidth(s):
    """Replace fullwidth punctuation in ``s`` with its ASCII equivalent.

    Covers the fullwidth forms of the 32 ASCII punctuation marks; every other
    character (letters, digits, CJK text) is returned unchanged.

    NOTE(review): the previous table also listed a few extra entries whose
    keys are not recoverable from the file (values ',', '.', '-', ':' and some
    empty strings); re-add them here once the intended characters are known.
    """
    return s.translate(_FULLWIDTH_PUNCT_TABLE)
def split_text_at_punctuation(text, chunk_size=100):
    """Split ``text`` at punctuation/whitespace and regroup it into chunks.

    The separators themselves are dropped; the remaining tokens are joined
    with single spaces into chunks of roughly ``chunk_size`` characters. A
    token longer than ``chunk_size`` becomes a chunk of its own.

    Returns:
        A list of non-empty chunk strings (empty list for empty input).
    """
    # Runs of whitespace and common punctuation act as token separators.
    punctuation_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    tokens = [token for token in re.split(punctuation_pattern, text) if token]

    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += (token + ' ')
            continue
        # Flush the chunk built so far. It is empty when the very first token
        # already exceeds chunk_size; emitting it would yield a blank chunk.
        if current_chunk:
            chunks.append(current_chunk.strip())
        current_chunk = token + ' '
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    """Strip Markdown/HTML markup from ``markdown_text`` and return plain text.

    Headings, code, emphasised spans, links and images are removed together
    with their content; HTML tags are removed; fullwidth punctuation is
    normalised and whitespace collapsed.
    """
    # Fenced code blocks go first: running the inline-code pattern before them
    # ate part of the fences and left stray backticks behind.
    text = re.sub(r'```[\s\S]*?```', '', markdown_text)
    # Inline code spans.
    text = re.sub(r'`[^`]+`', '', text)
    # Headings: from the '#' marks to the end of that line only. The previous
    # pattern also consumed newlines, deleting all body text up to the next
    # '#', and left one '#' behind for multi-level headings.
    text = re.sub(r'#+[ \t]*[^#\n]*', '', text)
    # Bold / italic spans.
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Images before links: the link pattern also matches the "[alt](url)" part
    # of an image and would leave a dangling '!'.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalise punctuation.
    text = fullwidth_to_halfwidth(text)
    # Collapse runs of whitespace.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress ``original_data`` and return it as a Base64 string.

    The Base64 text contains no line breaks (the source comments describe
    this as matching Android's ``Base64.NO_WRAP``).
    """
    sink = BytesIO()
    # Closing the GzipFile flushes the gzip trailer into the sink.
    with gzip.GzipFile(fileobj=sink, mode='wb') as gz_writer:
        gz_writer.write(original_data)
    encoded = base64.b64encode(sink.getvalue())
    return encoded.decode('utf-8')
def clean_audio_cache():
    """Remove expired entries from ``audio_text_cache``.

    Entries older than ``CACHE_EXPIRE_SECONDS`` are popped and their
    ``audio_stream`` buffer, if any, is closed.
    """
    with cache_lock:
        now = time.time()
        expired_keys = []
        for key, entry in audio_text_cache.items():
            created = entry['created_at']
            # The cache is shared by two endpoints: one stores time.time()
            # floats, the other datetime objects. Normalise to epoch seconds so
            # the subtraction below cannot raise TypeError.
            if isinstance(created, datetime.datetime):
                created = created.timestamp()
            if now - created > CACHE_EXPIRE_SECONDS:
                expired_keys.append(key)
        for key in expired_keys:
            entry = audio_text_cache.pop(key, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()
def start_background_cleaner():
    """Start a daemon thread that purges the audio cache every 3 minutes."""
    def _sweep_forever():
        # Sleep first, then clean, repeating for the life of the process.
        while True:
            time.sleep(180)
            clean_audio_cache()

    Thread(target=_sweep_forever, daemon=True).start()
def test_qwen_chat():
    """Manual smoke test: send one chat request to qwen-plus and print the reply."""
    # `Generation` was never imported in this file, so the call below raised
    # NameError. Imported locally, like the other dashscope imports here.
    from dashscope import Generation

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # API key for the Bailian service; see ALI_KEY.
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )
    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
# SECURITY: an API secret is committed in source. Prefer supplying it through
# the DASHSCOPE_API_KEY environment variable; the literal is kept only as a
# backward-compatible fallback and should be rotated and removed.
ALI_KEY = os.environ.get("DASHSCOPE_API_KEY", "sk-a47a3fb5f4a94f66bbaf713779101c75")
class QwenTTS:
    """Wrapper around DashScope speech synthesis.

    ``model_name`` of the form ``"cosyvoice-v1/<voice>"`` selects the
    ``tts_v2`` (CosyVoice) API with the given voice; any other name is passed
    to the older ``dashscope.audio.tts`` API as the model.
    """

    def __init__(self, key, model_name="cosyvoice-v1/longxiaochun", base_url=""):
        import dashscope
        import ssl
        print("---begin--init QwenTTS--")  # cyx
        # Callers may forward an unset (None) model name; fall back to the
        # default instead of failing on the '/' test below.
        model_name = model_name or "cosyvoice-v1/longxiaochun"
        self.model_name = model_name
        dashscope.api_key = key
        # SECURITY NOTE(review): this disables TLS certificate verification for
        # the whole process, not just for this client. Kept because removing it
        # may break existing deployments; revisit.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        if '/' in model_name:
            family, voice = model_name.split('/', 1)
            if family == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = voice

    def tts(self, text, sample_rate=None, stream_format=None):
        """Synthesize ``text`` and yield the audio as a stream of byte chunks.

        Args:
            text: text to speak.
            sample_rate: output sample rate for the CosyVoice API; defaults to
                the module-level ``TTS_SAMPLERATE``.
            stream_format: ``'mp3'``, ``'wav'`` or ``'pcm'`` for the CosyVoice
                API; defaults to the module-level ``FORMAT``.

        Raises:
            RuntimeError: if the request cannot be started or the service
                reports an error.
        """
        if self.is_cosyvoice:
            # cyx 2025-01-19: CosyVoice is served by the tts_v2 API.
            from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer
        else:
            from dashscope.audio.tts import ResultCallback, SpeechSynthesizer
        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")  # cyx

        end_of_stream = object()  # sentinel, so an empty frame cannot end the stream

        class _FrameSink(ResultCallback):
            """Collects audio frames from SDK callbacks into a blocking queue."""

            def __init__(self) -> None:
                self.frames = queue.Queue()
                self.error = None

            def _run(self):
                # Blocks on the queue instead of spinning on time.sleep(0).
                while True:
                    item = self.frames.get()
                    if item is end_of_stream:
                        break
                    yield item
                if self.error is not None:
                    raise RuntimeError(self.error)

            def on_open(self):
                logging.info("Qwen tts open")

            def on_complete(self):
                self.frames.put(end_of_stream)

            def on_error(self, response):
                print("Qwen tts error", str(response))
                # Record the failure and wake the consumer. Raising here only
                # aborted the SDK callback and left the consumer waiting forever.
                self.error = str(response)
                self.frames.put(end_of_stream)

            def on_close(self):
                logging.info("Qwen tts close")
                # Also ends the stream if the connection closes without a
                # complete/error event; an extra sentinel is harmless.
                self.frames.put(end_of_stream)

        class Callback(_FrameSink):
            """Callback for the older API: audio arrives through on_event."""

            def on_event(self, result):
                frame = result.get_audio_frame()
                if frame is not None:
                    self.frames.put(frame)

        class Callback_v2(_FrameSink):
            """Callback for CosyVoice (tts_v2): audio arrives through on_data."""

            def on_event(self, message):
                pass

            def on_data(self, data: bytes) -> None:
                if len(data) > 0:
                    self.frames.put(data)

        try:
            if self.is_cosyvoice is False:
                self.callback = Callback()
                SpeechSynthesizer.call(model=self.model_name,
                                       text=text,
                                       callback=self.callback,
                                       format="wav")  # format="mp3")
            else:
                self.callback = Callback_v2()
                audio_format = self.get_audio_format(
                    FORMAT if stream_format is None else stream_format,
                    TTS_SAMPLERATE if sample_rate is None else sample_rate,
                )
                self.synthesizer = SpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,
                    callback=self.callback,
                    format=audio_format
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}")  # cyx
            # Previously swallowed, after which the consumer loop below waited
            # forever for frames that would never arrive.
            raise RuntimeError(f"**ERROR**: {e}") from e

        try:
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def get_audio_format(self, format: str, sample_rate: int):
        """Map ``(sample_rate, format)`` to a ``tts_v2`` ``AudioFormat`` member.

        Unknown combinations fall back to 16 kHz mono MP3.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            # 16 kHz mp3/wav were missing, so a 16 kHz WAV request silently
            # came back as MP3 through the fallback.
            (16000, 'mp3'): AudioFormat.MP3_16000HZ_MONO_128KBPS,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (16000, 'wav'): AudioFormat.WAV_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)

    def end_tts(self):
        """Signal end of input to the CosyVoice synthesizer, if one was created."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Stream the local file ``./test.mp3`` as ``audio/mpeg`` in 1 KiB chunks."""
    def audio_generator():
        mp3_path = './test.mp3'
        try:
            with open(mp3_path, 'rb') as mp3_file:
                # iter() with a sentinel stops at the first empty read (EOF).
                for block in iter(lambda: mp3_file.read(1024), b''):
                    yield block
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")

    response_headers = {"Cache-Control": "no-store"}
    return StreamingResponse(
        audio_generator(),
        media_type="audio/mpeg",
        headers=response_headers
    )
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Wrap raw 16-bit mono PCM data in a WAV container.

    Args:
        pcm_data: raw PCM frames (mono, 2 bytes per sample).
        sample_rate: sample rate written into the header.

    Returns:
        The complete WAV file (header + data) as bytes.
    """
    # `wave` is not imported at module level, so the original raised NameError.
    import wave

    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono
            wav_file.setsampwidth(2)   # 16-bit
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        # Read after the writer is closed so the header sizes are final.
        return wav_buffer.getvalue()
def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return ``duration_ms`` of silence as zero-valued bytes.

    The size is derived from the module-level TTS_SAMPLERATE, SAMPLE_WIDTH and
    CHANNELS, i.e. it is raw PCM-sized silence. The original docstring
    describes it as pre-buffer data for MP3 streaming.
    """
    sample_count = int(TTS_SAMPLERATE * duration_ms / 1000)
    byte_count = sample_count * SAMPLE_WIDTH * CHANNELS
    return b'\x00' * byte_count
def _parse_byte_range(range_header, total_size):
    """Resolve a single ``bytes=`` Range header against ``total_size``.

    Returns an inclusive ``(start, end)`` tuple, or ``None`` when the header
    is malformed, multi-range or cannot be satisfied.
    """
    try:
        unit, _, spec = range_header.partition("=")
        if unit.strip().lower() != "bytes" or "," in spec:
            return None
        start_str, _, end_str = spec.strip().partition("-")
        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        elif end_str:
            # Suffix form "bytes=-N": the last N bytes.
            start = max(total_size - int(end_str), 0)
            end = total_size - 1
        else:
            return None
    except ValueError:
        return None
    end = min(end, total_size - 1)
    if start < 0 or start > end:
        return None
    return start, end


@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
        chat_id: str,
        audio_stream_id: str,
        request: Request,
        range: str = Header(None)  # optional HTTP Range header
):
    """Serve the audio of a previously created TTS request.

    If the audio was generated when the request was created it is served from
    the cached buffer (with Range support); otherwise it is synthesized on the
    fly and streamed.

    Raises:
        HTTPException: 404 if the id is unknown or expired, 416 if the
            requested range cannot be satisfied.
    """
    # HTTPException is not imported at module level in this file.
    from fastapi import HTTPException

    with cache_lock:
        tts_info = audio_text_cache.get(audio_stream_id, None)
    if not tts_info:
        raise HTTPException(404, detail="音频流已过期")
    audio_stream = tts_info['audio_stream']
    text = tts_info['text']
    # The creating endpoint may store None when the client sent no model
    # name; dict.get's default does not cover that case.
    model_name = tts_info.get('model_name') or "cosyvoice-v1/longxiaochun"
    format = tts_info['format']
    sample_rate = tts_info['sample_rate']

    media_type_map = {
        'mp3': 'audio/mpeg',
        'wav': 'audio/wav',
        'pcm': f'audio/L16; rate={sample_rate}; channels=1'
    }
    media_type = media_type_map[format]

    if audio_stream is not None:
        logging.info("return audio stream buffer")
        try:
            # Snapshot the bytes instead of seek/read on the shared BytesIO:
            # concurrent requests would otherwise fight over one cursor.
            payload = audio_stream.getvalue()
        except ValueError:
            # The buffer was closed by the cache cleaner in the meantime.
            raise HTTPException(404, detail="音频流已过期")
        total_length = len(payload)
        start, end, status_code = 0, total_length - 1, 200
        headers = {
            "Accept-Ranges": "bytes",
            "Cache-Control": "no-store",  # never cache dynamic content
            "Content-Disposition": "inline",
            "Access-Control-Allow-Origin": "*"
        }
        if range:
            byte_range = _parse_byte_range(range, total_length)
            if byte_range is None:
                raise HTTPException(416, detail="请求范围无效", headers={
                    "Content-Range": f"bytes */{total_length}"
                })
            start, end = byte_range
            status_code = 206
            # Content-Range positions are inclusive, hence total_length - 1 at most.
            headers["Content-Range"] = f"bytes {start}-{end}/{total_length}"
        body = memoryview(payload)[start:end + 1]
        headers["Content-Length"] = str(len(body))

        def read_buffer():
            # The builtin range() is shadowed by the `range` parameter here.
            offset = 0
            while offset < len(body):
                yield bytes(body[offset:offset + 1024])
                offset += 1024

        return StreamingResponse(
            read_buffer(),
            status_code=status_code,
            media_type=media_type,
            headers=headers
        )

    def generate_audio():
        global FORMAT, TTS_SAMPLERATE
        tts = QwenTTS(ALI_KEY, model_name)
        # NOTE(review): rebinding module globals for the duration of the
        # stream is not safe when several requests overlap.
        original_format = FORMAT
        original_sample_rate = TTS_SAMPLERATE
        FORMAT = format
        TTS_SAMPLERATE = sample_rate
        try:
            for chunk in tts.tts(text):
                if format == 'wav':
                    # NOTE(review): this wraps every chunk in its own WAV
                    # header; confirm that is what clients expect.
                    yield add_wav_header(chunk, sample_rate)
                else:
                    yield chunk
        finally:
            FORMAT = original_format
            TTS_SAMPLERATE = original_sample_rate

    # Nothing cached: synthesize on demand (a Range header cannot be honoured).
    return StreamingResponse(
        generate_audio(),
        media_type=media_type,
        headers={
            "Transfer-Encoding": "chunked",
            "Cache-Control": "no-store",
            "Content-Disposition": "inline"
        }
    )
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Register a TTS request and, unless delayed, synthesize it immediately.

    JSON body fields: ``text`` (required), ``format`` (mp3/wav/pcm),
    ``sample_rate`` (8000/16000/22050/44100), ``model_name`` and
    ``delay_gen_audio``.

    Returns:
        A dict with ``tts_url``, ``audio_stream_id`` and ``sample_rate``.

    Raises:
        HTTPException: 400 for invalid parameters, 500 for any other failure.
    """
    # HTTPException is not imported at module level in this file.
    from fastapi import HTTPException

    global FORMAT, TTS_SAMPLERATE
    try:
        request_data = await request.json()
        format = request_data.get('format', FORMAT)
        if format not in ['mp3', 'wav', 'pcm']:
            raise HTTPException(400, detail="不支持的音频格式")
        sample_rate = request_data.get('sample_rate', TTS_SAMPLERATE)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")
        text = request_data.get("text")
        if not isinstance(text, str) or not text.strip():
            raise HTTPException(400, detail="文本内容不能为空")
        # An absent model_name must fall back to the default voice; storing
        # None made QwenTTS fail on every request without an explicit model.
        model_name = request_data.get('model_name') or "cosyvoice-v1/longxiaochun"

        audio_stream_id = str(uuid.uuid4())
        tts_info = {
            'text': text,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': None,
            'model_name': model_name,
            'format': format,
            'sample_rate': sample_rate
        }
        audio_stream_len = 0
        # Generate the audio now unless the client asked to defer it.
        if request_data.get('delay_gen_audio', False) is False:
            # The synthesizer reads the module-level FORMAT/TTS_SAMPLERATE, so
            # apply the requested values for this call. Previously the globals
            # were saved but never set, so the cached audio ignored the
            # requested format and sample rate.
            original_format = FORMAT
            original_sample_rate = TTS_SAMPLERATE
            FORMAT = format
            TTS_SAMPLERATE = sample_rate
            try:
                tts = QwenTTS(ALI_KEY, model_name)
                buffer = io.BytesIO()
                for chunk in tts.tts(text):
                    audio_stream_len = audio_stream_len + len(chunk)
                    buffer.write(chunk)
                buffer.seek(0)
                tts_info['audio_stream'] = buffer
                tts_info['audio_stream_len'] = audio_stream_len
            finally:
                FORMAT = original_format
                TTS_SAMPLERATE = original_sample_rate

        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info
        logging.info(f"create tts stream return {audio_stream_id} {audio_text_cache[audio_stream_id]}")
        return {
            "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
            "audio_stream_id": audio_stream_id,
            "sample_rate": sample_rate
        }
    except HTTPException:
        # Let deliberate 4xx responses through; they used to be caught below
        # and turned into a 500.
        raise
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="内部服务器错误")
# 辅助函数
# ------------------------ API路由 ------------------------
@tts_router.post("/chats1/{chat_id}/tts")
async def create_tts_request1(chat_id: str, request: Request):
    """Create a TTS audio buffer and return the URL it can be fetched from.

    JSON body fields: ``text`` (required), ``format`` and ``sample_rate``
    (validated, see the note in the body).

    Raises:
        HTTPException: 400 for invalid parameters, 500 for any other failure.
    """
    # HTTPException is not imported at module level in this file.
    from fastapi import HTTPException

    try:
        data = await request.json()
        logging.info(f"API--create_tts_request1-- {data}")
        # Parameter validation.
        text = data.get("text", "")
        if not isinstance(text, str) or not text.strip():
            raise HTTPException(400, detail="文本内容不能为空")
        text = text.strip()
        format = data.get("format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="不支持的音频格式")
        sample_rate = data.get("sample_rate", 16000)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")
        # The validated values are deliberately overridden: the metadata
        # stored with the buffer is always mp3 / 44100 on this route.
        format = "mp3"
        sample_rate = 44100
        # Generate the audio.
        audio_stream_id = str(uuid.uuid4())
        buffer = io.BytesIO()
        tts = QwenTTS(ALI_KEY)
        try:
            for chunk in tts.tts(text):
                buffer.write(chunk)
            buffer.seek(0)
        except Exception as e:
            logging.error(f"TTS生成失败: {str(e)}")
            raise HTTPException(500, detail="音频生成失败")
        # Store in the cache.
        with cache_lock:
            audio_text_cache[audio_stream_id] = {
                "buffer": buffer,
                "format": format,
                "sample_rate": sample_rate,
                "created_at": datetime.datetime.now(),
                "size": buffer.getbuffer().nbytes
            }
        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
                "expires_at": (datetime.datetime.now() + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat()
            }
        )
    except HTTPException:
        # Keep the specific status/detail raised above; the generic handler
        # below used to replace every one of them with a plain 500.
        raise
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="服务器内部错误")
@tts_router.get("/chats1/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio1(
        chat_id: str,
        audio_stream_id: str,
        range: str = Header(None)
):
    """Return the cached audio for ``audio_stream_id``, honouring Range requests.

    Raises:
        HTTPException: 404 if the id is unknown, expired or holds no buffer.
    """
    # HTTPException is not imported at module level in this file.
    from fastapi import HTTPException

    # Drop expired entries first.
    cleanup_cache()
    with cache_lock:
        item = audio_text_cache.get(audio_stream_id)
    # Entries written by the other TTS route carry no "buffer" key.
    if not item or "buffer" not in item:
        raise HTTPException(404, detail="音频流不存在或已过期")

    format = item["format"]
    total_size = item["size"]
    media_type = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "pcm": f"audio/L16; rate={item['sample_rate']}; channels=1"
    }[format]
    # Each response reads from its own copy. The cached BytesIO keeps its read
    # position between requests, so a second download used to return an empty
    # body and concurrent downloads corrupted each other.
    private_buffer = BytesIO(item["buffer"].getvalue())

    if range:
        return handle_range_request(range, private_buffer, total_size, media_type)
    # Full-content response.
    return StreamingResponse(
        iter(lambda: private_buffer.read(4096), b""),
        media_type=media_type,
        headers={
            "Accept-Ranges": "bytes",
            "Content-Length": str(total_size),
            "Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
        }
    )
def _resolve_byte_range(range_header, total_size):
    """Turn a single ``bytes=`` Range header into inclusive ``(start, end)``.

    Returns ``None`` when the range lies entirely outside the content.

    Raises:
        ValueError: if the header is malformed, uses another unit, lists
            several ranges or has ``start > end``.
    """
    unit, sep, spec = range_header.partition("=")
    if not sep or unit.strip() != "bytes" or "," in spec:
        raise ValueError(range_header)
    start_str, dash, end_str = spec.strip().partition("-")
    if not dash or (not start_str and not end_str):
        raise ValueError(range_header)
    if start_str:
        start = int(start_str)
        end = int(end_str) if end_str else total_size - 1
        if end_str and end < start:
            raise ValueError(range_header)
    else:
        # Suffix form "bytes=-N" selects the last N bytes, not bytes 0..N.
        suffix_length = int(end_str)
        if suffix_length <= 0:
            return None
        start = max(total_size - suffix_length, 0)
        end = total_size - 1
    if start < 0 or start >= total_size:
        return None
    # An end past the last byte is clamped rather than rejected.
    return start, min(end, total_size - 1)


def handle_range_request(range: str, buffer: BytesIO, total_size: int, media_type: str):
    """Build a 206 Partial Content response for an HTTP Range request.

    Args:
        range: raw value of the Range header.
        buffer: seekable buffer holding the complete content.
        total_size: total content length in bytes.
        media_type: value for the Content-Type header.

    Raises:
        HTTPException: 400 for a malformed header, 416 when the range cannot
            be satisfied.
    """
    # HTTPException is not imported at module level in this file.
    from fastapi import HTTPException

    try:
        resolved = _resolve_byte_range(range, total_size)
    except ValueError:
        raise HTTPException(400, detail="无效的Range头")
    if resolved is None:
        raise HTTPException(416, detail="请求范围无效", headers={
            "Content-Range": f"bytes */{total_size}"
        })
    start, end = resolved
    content_length = end - start + 1
    buffer.seek(start)

    def chunk_generator():
        remaining = content_length
        while remaining > 0:
            chunk = buffer.read(min(4096, remaining))
            if not chunk:
                break
            yield chunk
            remaining -= len(chunk)

    return StreamingResponse(
        chunk_generator(),
        status_code=206,
        headers={
            "Content-Range": f"bytes {start}-{end}/{total_size}",
            "Content-Length": str(content_length),
            "Content-Type": media_type,
            "Accept-Ranges": "bytes"
        }
    )
def cleanup_cache():
    """Delete every entry of ``audio_text_cache`` older than CACHE_EXPIRE_SECONDS."""
    with cache_lock:
        now = datetime.datetime.now()
        expired = []
        for key, entry in audio_text_cache.items():
            created = entry["created_at"]
            # The cache is shared by two endpoints: one stores datetime
            # objects, the other time.time() floats. Convert the floats so the
            # subtraction below cannot raise TypeError.
            if not isinstance(created, datetime.datetime):
                created = datetime.datetime.fromtimestamp(created)
            if (now - created).total_seconds() > CACHE_EXPIRE_SECONDS:
                expired.append(key)
        for key in expired:
            del audio_text_cache[key]
# 应用启动时启动清理线程
# start_background_cleaner()