# NOTE(review): the following lines were repository-viewer chrome ("Files",
# "Raw Blame History", Unicode warning) pasted into the source; kept here as
# comments so the module remains importable.
#   ragflow_python/asr-monitor-test/app/tts_service.py — 1025 lines, 38 KiB, Python
#   Viewer warning: "This file contains ambiguous Unicode characters … Use the
#   Escape button to reveal them."

import array
import asyncio
import base64
import binascii
import datetime
import gzip
import io
import json
import logging
import os
import queue
import re
import threading
import time
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import timedelta
from io import BytesIO
from threading import Lock, Thread
from timeit import default_timer as timer
from typing import Optional, Dict, Any

from fastapi import (
    APIRouter,
    Body,
    FastAPI,
    File,
    Form,
    Header,
    HTTPException,
    Query,
    Request,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import StreamingResponse, JSONResponse, Response
# Default audio output parameters for the TTS endpoints.
TTS_SAMPLERATE = 44100 # 22050 # 16000
FORMAT = "mp3"
CHANNELS = 1 # mono
SAMPLE_WIDTH = 2 # 16-bit = 2 bytes per sample
tts_router = APIRouter()
# logger = logging.getLogger(__name__)
class MillisecondsFormatter(logging.Formatter):
    """Log formatter whose timestamps render as HH:MM:SS.mmm."""

    def formatTime(self, record, datefmt=None):
        # Convert the record's creation time to a local-time tuple,
        # format hours:minutes:seconds, then append 3-digit milliseconds.
        local_time = self.converter(record.created)
        base = time.strftime("%H:%M:%S", local_time)
        millis = int(record.msecs)
        return f"{base}.{millis:03d}"
# Global logging configuration.
def configure_logging():
    """Install a console handler with millisecond timestamps on the root logger."""
    formatter = MillisecondsFormatter("%(asctime)s - %(levelname)s - %(message)s")
    root_logger = logging.getLogger()
    # Drop any handlers installed earlier (e.g. by basicConfig or libraries).
    root_logger.handlers = []
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console)
# Run once at import time.
configure_logging()
class StreamSessionManager:
    """Per-session TTS streaming state.

    Each session owns a thread-safe audio buffer, a text task queue, and a
    dedicated worker thread that batches queued text and submits synthesis
    jobs to a shared fixed-size thread pool.
    """
    def __init__(self):
        self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30) # fixed-size thread pool shared by all sessions
        self.gc_interval = 300 # reap an idle session after 5 minutes
        self.gc_tts = 3 # 3s of TTS inactivity marks the stream finished
    def create_session(self, tts_model,sample_rate =8000, stream_format='mp3'):
        """Register a new session, start its worker thread, return the session id."""
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300), # thread-safe audio chunk queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count':0,
                'finished': threading.Event(), # set when the audio stream is complete
                'sample_rate':sample_rate,
                'stream_format':stream_format,
                "tts_chunk_data_valid":False
            }
        # Start the per-session task-processing thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id
    def append_text(self, session_id, text):
        """Queue text for synthesis; silently drops it when the queue is full."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session: return
        # Enqueue the text without blocking.
        try:
            session['task_queue'].put(text, block=False)
        except queue.Full:
            logging.warning(f"Session {session_id} task queue full")
    def _process_tasks(self, session_id):
        """Worker loop (one thread per session): batch queued text and synthesize it."""
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge several text chunks, waiting at most 50ms for each.
                texts = []
                while len(texts) < 5: # merge at most 5 text chunks
                    try:
                        text = session['task_queue'].get(timeout=0.05)
                        texts.append(text)
                    except queue.Empty:
                        break
                if texts:
                    # Submit to the thread pool; joining texts reduces request count.
                    future=self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts)
                    )
                    future.result() # wait for the synthesis task to finish
                # Session timeout checks.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")
    def _generate_audio(self, session_id, text):
        """Run TTS for `text` and push audio chunks into the session buffer (thread-pool job)."""
        session = self.sessions.get(session_id)
        if not session: return
        # logging.info(f"_generate_audio:{text}")
        first_chunk = True
        # logging.info(f"转换开始!!! {text}")
        try:
            for chunk in session['tts_model'].tts(text,session['sample_rate'],session['stream_format']):
                if session['stream_format'] == 'wav':
                    if first_chunk:
                        chunk_len = len(chunk)
                        # Fade in the first chunk to avoid an audible click.
                        if chunk_len > 2048:
                            session['buffer'].put(audio_fade_in(chunk,1024))
                        else:
                            session['buffer'].put(audio_fade_in(chunk, chunk_len))
                        first_chunk = False
                    else:
                        session['buffer'].put(chunk)
                else:
                    session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    session['tts_chunk_data_valid'] = True # 2025-05-10: backend has started returning audio; frontend can be notified
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            # Errors are delivered in-band as an "ERROR:..." string chunk.
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")
    def close_session(self, session_id):
        """Mark a session inactive and schedule its removal."""
        with self.lock:
            if session_id in self.sessions:
                # Mark the session inactive first...
                self.sessions[session_id]['active'] = False
                # ...then clean up shortly afterwards (original comment said 2s; the timer uses 1s).
                threading.Timer(1, self._clean_session, args=[session_id]).start()
    def _clean_session(self, session_id):
        # Remove the session entry entirely.
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]
    def get_session(self, session_id):
        # NOTE(review): unlocked read; callers treat the result as a best-effort snapshot.
        return self.sessions.get(session_id)
stream_manager = StreamSessionManager()
def allowed_file(filename):
    """Return True when `filename` carries an allowed image extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {'png', 'jpg', 'jpeg', 'gif'}
# In-memory cache of generated audio keyed by stream id; guarded by cache_lock.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600 # entries expire after 10 minutes
"""
@manager.route('/tts_stream/<session_id>', methods=['GET'])
def tts_stream(session_id):
session = stream_manager.sessions.get(session_id)
logging.info(f"--tts_stream {session}")
if session is None:
return get_error_data_result(message="Audio stream not found or expired.")
def generate():
count = 0;
finished_event = session['finished']
try:
while not finished_event.is_set():
if not session or not session['active']:
break
try:
chunk = session['buffer'].get_nowait() #
count = count + 1
if isinstance(chunk, str) and chunk.startswith("ERROR"):
logging.info(f"---tts stream error!!!! {chunk}")
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
break
if session['stream_format'] == "wav":
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
yield gzip_base64_data
else:
yield chunk
retry_count = 0 # 成功收到数据重置重试计数器
except queue.Empty:
if session['stream_format'] == "wav":
pass
else:
pass
except Exception as e:
logging.info(f"tts streag get error2 {e} ")
finally:
# 确保流结束后关闭会话
if session:
# 延迟关闭会话,确保所有数据已发送
stream_manager.close_session(session_id)
logging.info(f"Session {session_id} closed.")
# 关键响应头设置
if session['stream_format'] == "wav":
resp = Response(stream_with_context(generate()), mimetype="audio/wav")
else:
resp = Response(stream_with_context(generate()), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
"""
def generate_mp3_header(
sample_rate: int,
bitrate_kbps: int,
channels: int = 1,
layer: str = "III" # 新增参数,支持 "I"/"II"/"III"
) -> bytes:
"""
动态生成 MP3 帧头4字节支持 Layer I/II/III
:param sample_rate: 采样率 (8000, 16000, 22050, 44100)
:param bitrate_kbps: 比特率(单位 kbps
:param channels: 声道数 (1: 单声道, 2: 立体声)
:param layer: 编码层 ("I", "II", "III")
:return: 4字节的帧头数据
"""
# ----------------------------------
# 参数校验
# ----------------------------------
valid_sample_rates = {8000, 16000, 22050, 44100, 48000}
if sample_rate not in valid_sample_rates:
raise ValueError(f"不支持的采样率,可选:{valid_sample_rates}")
valid_layers = {"I", "II", "III"}
if layer not in valid_layers:
raise ValueError(f"不支持的层,可选:{valid_layers}")
# ----------------------------------
# 确定 MPEG 版本和采样率索引
# ----------------------------------
if sample_rate == 44100:
mpeg_version = 0b11 # MPEG-1
sample_rate_index = 0b00
elif sample_rate == 22050:
mpeg_version = 0b10 # MPEG-2
sample_rate_index = 0b00
elif sample_rate == 16000:
mpeg_version = 0b10 # MPEG-2
sample_rate_index = 0b10
elif sample_rate == 8000:
mpeg_version = 0b00 # MPEG-2.5
sample_rate_index = 0b10
else:
raise ValueError("采样率与版本不匹配")
# ----------------------------------
# 动态选择比特率表(关键扩展)
# ----------------------------------
# Layer 编码映射I:0b11, II:0b10, III:0b01
layer_code = {
"I": 0b11,
"II": 0b10,
"III": 0b01
}[layer]
# 比特率表(覆盖所有 Layer
bitrate_tables = {
# -------------------------------
# MPEG-1 (0b11)
# -------------------------------
# Layer I
(0b11, 0b11): {
32: 0b0000, 64: 0b0001, 96: 0b0010, 128: 0b0011,
160: 0b0100, 192: 0b0101, 224: 0b0110, 256: 0b0111,
288: 0b1000, 320: 0b1001, 352: 0b1010, 384: 0b1011,
416: 0b1100, 448: 0b1101
},
# Layer II
(0b11, 0b10): {
32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
160: 0b1000, 192: 0b1001, 224: 0b1010, 256: 0b1011,
320: 0b1100, 384: 0b1101
},
# Layer III
(0b11, 0b01): {
32: 0b1000, 40: 0b1001, 48: 0b1010, 56: 0b1011,
64: 0b1100, 80: 0b1101, 96: 0b1110, 112: 0b1111,
128: 0b0000, 160: 0b0001, 192: 0b0010, 224: 0b0011,
256: 0b0100, 320: 0b0101
},
# -------------------------------
# MPEG-2 (0b10) / MPEG-2.5 (0b00)
# -------------------------------
# Layer I
(0b10, 0b11): {
32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
224: 0b1100, 256: 0b1101
},
(0b00, 0b11): {
32: 0b0000, 48: 0b0001, 56: 0b0010, 64: 0b0011,
80: 0b0100, 96: 0b0101, 112: 0b0110, 128: 0b0111,
144: 0b1000, 160: 0b1001, 176: 0b1010, 192: 0b1011,
224: 0b1100, 256: 0b1101
},
# Layer II
(0b10, 0b10): {
8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
144: 0b1100, 160: 0b1101
},
(0b00, 0b10): {
8: 0b0000, 16: 0b0001, 24: 0b0010, 32: 0b0011,
40: 0b0100, 48: 0b0101, 56: 0b0110, 64: 0b0111,
80: 0b1000, 96: 0b1001, 112: 0b1010, 128: 0b1011,
144: 0b1100, 160: 0b1101
},
# Layer III
(0b10, 0b01): {
8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111,
80: 0b0000, 96: 0b0001, 112: 0b0010, 128: 0b0011,
144: 0b0100, 160: 0b0101
},
(0b00, 0b01): {
8: 0b1000, 16: 0b1001, 24: 0b1010, 32: 0b1011,
40: 0b1100, 48: 0b1101, 56: 0b1110, 64: 0b1111
}
}
# 获取当前版本的比特率表
key = (mpeg_version, layer_code)
if key not in bitrate_tables:
raise ValueError(f"不支持的版本和层组合: MPEG={mpeg_version}, Layer={layer}")
bitrate_table = bitrate_tables[key]
if bitrate_kbps not in bitrate_table:
raise ValueError(f"不支持的比特率,可选:{list(bitrate_table.keys())}")
bitrate_index = bitrate_table[bitrate_kbps]
# ----------------------------------
# 确定声道模式
# ----------------------------------
if channels == 1:
channel_mode = 0b11 # 单声道
elif channels == 2:
channel_mode = 0b00 # 立体声
else:
raise ValueError("声道数必须为1或2")
# ----------------------------------
# 组合帧头字段(修正层编码)
# ----------------------------------
sync = 0x7FF << 21 # 同步字 11位 (0x7FF = 0b11111111111)
version = mpeg_version << 19 # MPEG 版本 2位
layer_bits = layer_code << 17 # Layer 编码I:0b11, II:0b10, III:0b01
protection = 0 << 16 # 无 CRC
bitrate_bits = bitrate_index << 12
sample_rate_bits = sample_rate_index << 10
padding = 0 << 9 # 无填充
private = 0 << 8
mode = channel_mode << 6
mode_ext = 0 << 4 # 扩展模式(单声道无需设置)
copyright = 0 << 3
original = 0 << 2
emphasis = 0b00 # 无强调
frame_header = (
sync |
version |
layer_bits |
protection |
bitrate_bits |
sample_rate_bits |
padding |
private |
mode |
mode_ext |
copyright |
original |
emphasis
)
return frame_header.to_bytes(4, byteorder='big')
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit mono PCM audio.

    :param audio_data: raw little-endian signed 16-bit PCM bytes
    :param fade_length: number of leading samples to ramp up from silence
    :return: the processed audio as bytes

    Fix: fade_length is clamped to the actual sample count — the original
    raised IndexError when a chunk held fewer samples than fade_length.
    """
    # Interpret the bytes as an array of signed 16-bit samples.
    samples = array.array('h', audio_data)
    fade_length = min(fade_length, len(samples))
    if fade_length <= 0:
        return samples.tobytes()
    for i in range(fade_length):
        # Linear ramp: factor goes 0 -> (fade_length-1)/fade_length.
        samples[i] = int(samples[i] * (i / fade_length))
    return samples.tobytes()
def parse_markdown_json(json_string):
    """Extract and parse a ```json fenced block from markdown text.

    Returns {'success': True, 'data': <parsed object>} on success, otherwise
    {'success': False, 'data': <error description>}.
    """
    fenced = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if not fenced:
        return {'success': False, 'data': 'not a valid markdown json string'}
    try:
        return {'success': True, 'data': json.loads(fenced[1])}
    except json.JSONDecodeError as err:
        # Parsing failed: surface the decoder's message to the caller.
        return {'success': False, 'data': str(err)}
# NOTE(review): duplicate definitions — audio_text_cache, cache_lock and
# CACHE_EXPIRE_SECONDS were already created earlier in this file; these lines
# rebind them to fresh (identical) values. Harmless at import time since
# nothing is stored in between, but one copy should be removed.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600 # entries expire after 10 minutes
# Fullwidth-to-halfwidth character conversion.
def fullwidth_to_halfwidth(s):
    """Convert fullwidth (zenkaku) characters in `s` to halfwidth equivalents.

    NOTE(review): the fullwidth keys of the original hand-written mapping were
    garbled in the file — they all rendered as empty strings, collapsing the
    dict to a single useless entry, so the function effectively returned its
    input unchanged. Rebuilt here from the Unicode layout: U+FF01..U+FF5E map
    directly onto ASCII 0x21..0x7E, the ideographic space U+3000 maps to ' ',
    plus the CJK punctuation the original comments referenced (。、—… and
    curly quotes). Confirm the exact punctuation set against the author's
    original intent.
    """
    table = {chr(code): chr(code - 0xFEE0) for code in range(0xFF01, 0xFF5F)}
    table[chr(0x3000)] = ' '  # ideographic space
    table.update({
        '\u3002': '.',   # 。ideographic full stop
        '\u3001': ',',   # 、ideographic comma
        '\u2014': '-',   # — em dash
        '\u2026': '.',   # … horizontal ellipsis
        '\u201c': '"', '\u201d': '"',   # curly double quotes
        '\u2018': "'", '\u2019': "'",   # curly single quotes
        '\u300a': '<', '\u300b': '>',   # 《 》
        '\u3010': '[', '\u3011': ']',   # 【 】
    })
    return ''.join(table.get(char, char) for char in s)
def split_text_at_punctuation(text, chunk_size=100):
    """Split `text` into chunks of roughly `chunk_size` characters.

    The text is tokenized at runs of whitespace/punctuation, then tokens are
    packed greedily into space-joined chunks.
    """
    separator_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    tokens = [tok for tok in re.split(separator_pattern, text) if tok]
    chunks = []
    current = ''
    for tok in tokens:
        if len(current) + len(tok) <= chunk_size:
            # Token fits: keep packing the current chunk.
            current += (tok + ' ')
        else:
            # Overflow: flush the current chunk and start a fresh one.
            chunks.append(current.strip())
            current = tok + ' '
    # Flush whatever remains.
    if current:
        chunks.append(current.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    """Strip markdown/HTML syntax from text, leaving plain prose for TTS.

    NOTE(review): the heading regex removes the heading *text* along with the
    marker, and the bold/italic regex removes the emphasised text itself, not
    just the markers — confirm that dropping that content is intended.
    """
    # Remove markdown headings (including their text — see note above).
    text = re.sub(r'#\s*[^#]+', '', markdown_text)
    # Remove inline code spans.
    text = re.sub(r'`[^`]+`', '', text)
    # Remove fenced code blocks.
    text = re.sub(r'```[\s\S]*?```', '', text)
    # Remove bold/italic emphasis (drops the emphasised text too — see note).
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Remove links.
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # Remove images.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Remove HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalize punctuation to halfwidth.
    # text = re.sub(r'[^\w\s]', '', text)
    text = fullwidth_to_halfwidth(text)
    # Collapse whitespace runs.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress `original_data`, then Base64-encode the result.

    The Base64 output contains no line breaks, which matches Android's
    Base64.NO_WRAP decoding on the receiving side.
    """
    # Step 1: gzip compression.
    compressed_bytes = gzip.compress(original_data)
    # Step 2: Base64 encoding (stdlib default emits no newlines).
    return base64.b64encode(compressed_bytes).decode('utf-8')
def clean_audio_cache():
    """Drop cache entries older than CACHE_EXPIRE_SECONDS, closing their streams.

    NOTE(review): this compares 'created_at' against time.time() (an epoch
    float), while create_tts_request stores datetime objects under the same
    key — reconcile the two before enabling the background cleaner.
    """
    with cache_lock:
        cutoff = time.time() - CACHE_EXPIRE_SECONDS
        stale_keys = [key for key, entry in audio_text_cache.items()
                      if entry['created_at'] < cutoff]
        for key in stale_keys:
            entry = audio_text_cache.pop(key, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()
def start_background_cleaner():
    """Spawn a daemon thread that prunes the audio cache every 3 minutes."""
    def _cleaner_loop():
        while True:
            time.sleep(180)  # prune every 3 minutes
            clean_audio_cache()
    Thread(target=_cleaner_loop, daemon=True).start()
def test_qwen_chat():
    """Smoke-test the DashScope Generation API with a fixed prompt.

    Fix: `Generation` was never imported anywhere in this file, so calling
    this function raised NameError. It is imported locally here, matching
    the file's pattern of function-local dashscope imports.
    """
    from dashscope import Generation
    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # Bailian/DashScope API key (see ALI_KEY below).
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message"
    )
    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
class QwenTTS:
    """Streaming TTS client for DashScope (Qwen sambert / CosyVoice voices).

    `model_name` may be "cosyvoice-v1/<voice>" to select the v2 streaming
    API with the given voice, or a plain model name for the v1 API.

    Fix: in get_audio_format the 48 kHz entries were keyed with the typo
    48800, so requests for 48000 Hz silently fell back to the 16 kHz default.
    """
    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun", base_url=""):
        import dashscope
        import ssl
        print("---begin--init QwenTTS--") # cyx
        self.model_name = model_name
        dashscope.api_key = key
        # NOTE(review): this disables TLS certificate verification process-wide.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        if '/' in model_name:
            # "family/voice" form: parts[0] = model family, parts[1] = voice name.
            parts = model_name.split('/', 1)
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]
    def tts(self, text):
        """Synthesize `text`; yields audio chunks (bytes) as they arrive."""
        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        if self.is_cosyvoice is False:
            from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
            from collections import deque
        else:
            # 2025-01-19: CosyVoice uses the tts_v2 API.
            from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat
            from dashscope.audio.tts import SpeechSynthesisResult
            from collections import deque
        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}") # cyx
        class Callback(ResultCallback):
            """v1 callback: buffers audio frames; a None sentinel marks completion."""
            def __init__(self) -> None:
                self.dque = deque()
            def _run(self):
                # Drain the deque; sleep(0) yields the GIL while waiting for data.
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break
            def on_open(self):
                pass
            def on_complete(self):
                self.dque.append(None)
            def on_error(self, response: SpeechSynthesisResponse):
                print("Qwen tts error", str(response))
                raise RuntimeError(str(response))
            def on_close(self):
                pass
            def on_event(self, result: SpeechSynthesisResult):
                if result.get_audio_frame() is not None:
                    self.dque.append(result.get_audio_frame())
        class Callback_v2(ResultCallback):
            """v2 (CosyVoice) callback: audio arrives via on_data, not on_event."""
            def __init__(self) -> None:
                self.dque = deque()
            def _run(self):
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break
            def on_open(self):
                logging.info("Qwen tts open")
            def on_complete(self):
                self.dque.append(None)
            def on_error(self, response: SpeechSynthesisResponse):
                print("Qwen tts error", str(response))
                raise RuntimeError(str(response))
            def on_close(self):
                logging.info("Qwen tts close")
            def on_event(self, message):
                # v2 delivers audio through on_data; events are ignored.
                pass
            def on_data(self, data: bytes) -> None:
                if len(data) > 0:
                    self.dque.append(data)
        try:
            if self.is_cosyvoice is False:
                self.callback = Callback()
                SpeechSynthesizer.call(model=self.model_name,
                                       text=text,
                                       callback=self.callback,
                                       format="wav") # format="mp3")
            else:
                self.callback = Callback_v2()
                format = self.get_audio_format(self.format, self.sample_rate)
                self.synthesizer = SpeechSynthesizer(
                    model='cosyvoice-v1',
                    # voice="longyuan", #"longfei",
                    voice=self.voice,
                    callback=self.callback,
                    format=format
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}") # cyx
        # Stream the buffered audio back to the caller.
        try:
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")
    def get_audio_format(self, format: str, sample_rate: int):
        """Map (sample_rate, format) to a dashscope AudioFormat constant.

        Falls back to MP3_16000HZ_MONO_128KBPS for unknown combinations.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
        # Fix: the 48 kHz keys were previously the typo 48800.
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
    def end_tts(self):
        """Signal the streaming synthesizer that no more text will be sent."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
def audio_generator():
path = './test.mp3'
try:
with open(path, 'rb') as f:
while True:
chunk = f.read(1024)
if not chunk:
break
yield chunk
except Exception as e:
logging.error(f"MP3 streaming error: {str(e)}")
return StreamingResponse(
audio_generator(),
media_type="audio/mpeg",
headers={
"Cache-Control": "no-store",
"Accept-Ranges": "bytes"
}
)
def add_wav_header(pcm_data: bytes, sample_rate: int) -> bytes:
    """Wrap raw 16-bit mono PCM bytes in a WAV container.

    :param pcm_data: raw little-endian signed 16-bit mono PCM
    :param sample_rate: sampling rate to record in the header
    :return: complete WAV file bytes

    Fix: the `wave` module was never imported at module level, so this
    function raised NameError on first call; imported locally here.
    """
    import wave  # stdlib; local import keeps the fix self-contained
    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)   # mono
            wav_file.setsampwidth(2)   # 16-bit samples
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        wav_buffer.seek(0)
        return wav_buffer.read()
def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return `duration_ms` of PCM silence, used to pre-buffer MP3 streaming."""
    sample_count = int(TTS_SAMPLERATE * duration_ms / 1000)
    frame_bytes = SAMPLE_WIDTH * CHANNELS
    # bytes(n) yields n zero bytes — identical to b'\x00' * n.
    return bytes(sample_count * frame_bytes)
# ------------------------ API路由 ------------------------
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
"""创建TTS音频流"""
try:
data = await request.json()
logging.info(f"API--create_tts_request1-- {data}")
# 参数校验
text = data.get("text", "").strip()
if not text:
raise HTTPException(400, detail="文本内容不能为空")
format = data.get("tts_stream_format", "mp3")
if format not in ["mp3", "wav", "pcm"]:
raise HTTPException(400, detail="不支持的音频格式")
sample_rate = data.get("tts_sample_rate", 48000)
if sample_rate not in [8000, 16000, 22050, 44100]:
raise HTTPException(400, detail="不支持的采样率")
model_name = data.get("model_name","cosyvoice-v1/longxiaochun")
delay_gen_audio = data.get('delay_gen_audio',False)
format ="mp3"
#sample_rate = 48000
# 生成音频流
audio_stream_id = str(uuid.uuid4())
buffer = None
tts_info = {
"text" :text,
"buffer": buffer,
"format": format,
"sample_rate": sample_rate,
"model_name" : model_name,
"created_at": datetime.datetime.now()
}
if delay_gen_audio is False:
logging.info("--begin generate tts --")
buffer = io.BytesIO()
tts = QwenTTS(ALI_KEY,format,sample_rate,model_name.split("@")[0])
try:
for chunk in tts.tts(text):
buffer.write(chunk)
buffer.seek(0)
except Exception as e:
logging.error(f"TTS生成失败: {str(e)}")
raise HTTPException(500, detail="音频生成失败")
tts_info['buffer'] =buffer
tts_info['size'] = buffer.getbuffer().nbytes
# 存储到缓存
with cache_lock:
audio_text_cache[audio_stream_id] = tts_info
logging.info(f" tts return {audio_text_cache[audio_stream_id]}")
return JSONResponse(
status_code=200,
content={
"tts_url":f"/chats/{chat_id}/tts/{audio_stream_id}",
"url": f"/chats/{chat_id}/tts/{audio_stream_id}",
"expires_at": (datetime.datetime.now() + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat()
}
)
except Exception as e:
logging.error(f"请求处理失败: {str(e)}")
raise HTTPException(500, detail="服务器内部错误")
executor = ThreadPoolExecutor()
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
chat_id: str,
audio_stream_id: str,
range: str = Header(None)
):
"""获取音频流"""
# 清理过期缓存
cleanup_cache()
# 获取缓存
with cache_lock:
tts_info = audio_text_cache.get(audio_stream_id)
if not tts_info:
#raise HTTPException(404, detail="音频流不存在或已过期")
logging.warning("音频流不存在或已过期")
return
# 准备响应参数
buffer = tts_info.get("buffer",None)
format = tts_info.get("format","mp3")
sample_rate = tts_info.get("sample_rate", 44100)
total_size = tts_info.get('size',0)
model_name = tts_info['model_name']
text = tts_info['text']
media_type = {
"mp3": "audio/mpeg",
"wav": "audio/wav",
"pcm": f"audio/L16; rate={tts_info['sample_rate']}; channels=1"
}[format]
logging.info(f"enter get_tts_audio1 buffer={buffer} {format} {sample_rate} range {range}")
def generate_audio():
logging.info(f"get_tts_audio1 generate_audio {format} {sample_rate} {model_name} range{range}")
tts = QwenTTS(ALI_KEY, format , sample_rate,model_name.split("@")[0])
try:
for chunk in tts.tts(text):
if format == 'wav':
yield add_wav_header(chunk, sample_rate)
else:
yield chunk
finally:
# 恢复全局变量
pass
def read_buffer():
print(f"read_buffer {audio_stream} {audio_stream.closed}")
audio_stream.seek(0) # 关键重置操作
data = audio_stream.read(1024)
while data:
yield data
data = audio_stream.read(1024)
async def generate_audio_async():
logging.info(f"get_tts_audio1 generate_audio_async {format} {sample_rate} {model_name}")
tts = QwenTTS(ALI_KEY, format , sample_rate,model_name.split("@")[0])
#前端有可能传入 cosyvoice-v1/longyuan@Tongyi-Qianwen 通过.split("@")[0] 取@之前部分
loop = asyncio.get_event_loop()
yield_len = 0
try:
# 将同步函数放入线程池执行
sync_generator = await loop.run_in_executor(executor, lambda: tts.tts(text))
for chunk in sync_generator:
yield_len = yield_len + len(chunk)
yield chunk
except (BrokenPipeError, ConnectionResetError):
logging.warning("客户端主动断开连接")
finally:
# 恢复全局变量
logging.info(f"yield len={yield_len}")
pass
# 处理范围请求
if buffer is not None:
if range:
return handle_range_request(range, buffer, total_size, media_type)
buffer.seek(0)
# 完整文件响应
return StreamingResponse(
iter(lambda: buffer.read(4096), b""),
media_type=media_type,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(total_size),
"Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}"
}
)
else:
return StreamingResponse(
#generate_audio(),
generate_audio_async(),
media_type=media_type,
headers={
"Transfer-Encoding": "chunked",
"Cache-Control": "no-store",
#"Content-Disposition": "inline",
#"Content-Length": str(1067890),
"Accept-Ranges": "bytes",
"ETag":audio_stream_id,
}
)
def handle_range_request(range_header: str, buffer:BytesIO, total_size: int, media_type: str):
    """Serve part of `buffer` for an HTTP Range header (e.g. "bytes=0-1023").

    Fixes:
    - suffix ranges ("bytes=-500") no longer crash on int('') — they resolve
      to the last N bytes as RFC 7233 specifies;
    - the 416 HTTPException is no longer swallowed by the generic handler
      and re-raised as 500.

    NOTE(review): original quirks kept — open-ended tail ranges answer 200
    instead of 206, and the body streams to end-of-buffer even when the
    Content-Length header advertises a shorter slice.
    """
    try:
        # Parse the Range header.
        range_type, range_spec = range_header.split('=')
        if range_type != 'bytes':
            raise ValueError("Unsupported range type")
        start_str, end_str = range_spec.split('-')
        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        else:
            # Suffix form "bytes=-N": the last N bytes.
            start = max(total_size - int(end_str), 0)
            end = total_size - 1
        # Validate the requested range.
        if start >= total_size or end >= total_size:
            raise HTTPException(status_code=416, headers={
                "Content-Range": f"bytes */{total_size}"
            })
        status_code = 206
        if start > 0 and end == total_size - 1:
            status_code = 200
        # Position the stream and compute the slice length.
        buffer.seek(start)
        content_length = end - start + 1
        return StreamingResponse(
            iter(lambda: buffer.read(4096), b''),  # fresh iter avoids stale state
            status_code=status_code,
            media_type=media_type,
            headers={
                "Content-Range": f"bytes {start}-{end}/{total_size}",
                "Content-Length": str(content_length),
                "Accept-Ranges": "bytes",
                "Cache-Control": "public, max-age=3600"
            }
        )
    except HTTPException:
        # Preserve the 416 raised above.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def cleanup_cache():
    """Remove audio cache entries older than CACHE_EXPIRE_SECONDS."""
    with cache_lock:
        now = datetime.datetime.now()
        stale_keys = [key for key, info in audio_text_cache.items()
                      if (now - info["created_at"]).total_seconds() > CACHE_EXPIRE_SECONDS]
        for key in stale_keys:
            logging.info(f"del audio_text_cache= {audio_text_cache[key]}")
            del audio_text_cache[key]
# 应用启动时启动清理线程
# start_background_cleaner()