977 lines
36 KiB
Python
977 lines
36 KiB
Python
|
|
import array
import base64
import binascii
import datetime
import gzip
import io
import json
import logging
import os
import queue
import re
import threading
import time
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import timedelta
from io import BytesIO
from threading import Lock, Thread
from timeit import default_timer as timer
from typing import Any, Dict, Optional

from fastapi import (
    APIRouter,
    Body,
    FastAPI,
    File,
    Form,
    Header,
    HTTPException,
    Query,
    Request,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import JSONResponse, StreamingResponse
|
|||
|
|
# Default synthesis parameters (create_tts_request falls back to FORMAT and
# TTS_SAMPLERATE when the request omits them).
TTS_SAMPLERATE = 44_100  # other values used before: 22050, 16000
FORMAT = "mp3"
CHANNELS = 1  # mono
SAMPLE_WIDTH = 2  # bytes per sample (16-bit PCM)
|
|||
|
|
|
|||
|
|
# Module logger (most code below still logs through the root ``logging`` module).
logger = logging.getLogger(__name__)

# Router that carries every TTS endpoint declared in this module.
tts_router = APIRouter()
|
|||
|
|
|
|||
|
|
class StreamSessionManager:
    """Per-session streaming TTS pipeline.

    Each session owns a text ``task_queue`` fed by :meth:`append_text` and an
    audio ``buffer`` filled by :meth:`_generate_audio`. One daemon thread per
    session (:meth:`_process_tasks`) batches queued text and runs synthesis on a
    shared thread pool. ``session['finished']`` is set once no audio has been
    produced for ``gc_tts`` seconds.
    """

    # How long a blocked buffer put waits before re-checking whether the session
    # is still active (seconds).
    BUFFER_PUT_POLL_SECONDS = 0.5

    def __init__(self):
        # {session_id: session dict built by create_session()}
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size pool shared by all sessions
        self.gc_interval = 300  # seconds since last audio before the session is closed
        self.gc_tts = 3  # seconds since last audio before the session is marked finished

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        """Register a new session, start its task thread and return its id.

        ``tts_model`` must expose ``tts(text, sample_rate, stream_format)``
        yielding audio chunks.
        """
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio chunk queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
            }
        # Start the per-session task thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue ``text`` for synthesis; unknown sessions are ignored."""
        with self.lock:
            session = self.sessions.get(session_id)
        if not session:
            return
        try:
            session['task_queue'].put(text, block=False)
        except queue.Full:
            # Only reachable if task_queue is ever given a maxsize.
            logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Task thread body (one per session): batch text, synthesize, watch idleness."""
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge up to 5 queued text pieces, waiting at most 50 ms for each.
                texts = []
                while len(texts) < 5:
                    try:
                        texts.append(session['task_queue'].get(timeout=0.05))
                    except queue.Empty:
                        break
                if texts:
                    # Synthesize on the shared pool and wait, so the batches of one
                    # session are processed one at a time.
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts),  # fewer, larger TTS requests
                    )
                    future.result()
                # Idle checks, measured from the last produced audio chunk.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _put_chunk(self, session, item):
        """Put ``item`` into the session buffer, waiting while the buffer is full.

        Returns False (dropping ``item``) as soon as the session is inactive, so a
        worker whose consumer has gone away does not block a pool thread forever
        on a full buffer.
        """
        while session['active']:
            try:
                session['buffer'].put(item, timeout=self.BUFFER_PUT_POLL_SECONDS)
                return True
            except queue.Full:
                continue
        return False

    def _generate_audio(self, session_id, text):
        """Synthesize ``text`` and push its chunks into the session buffer (pool thread).

        On failure an ``"ERROR:<message>"`` string is pushed instead.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        try:
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav' and first_chunk:
                    # audio_fade_in counts 16-bit samples, not bytes: fade at most
                    # the first 1024 samples (2048 bytes) of the first chunk.
                    chunk = audio_fade_in(chunk, min(1024, len(chunk) // 2))
                    first_chunk = False
                if not self._put_chunk(session, chunk):
                    logging.info(f"Session {session_id} inactive; dropping remaining audio")
                    return
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # 20250510: the TTS backend has returned data; the front end can be notified.
                    session['tts_chunk_data_valid'] = True
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            self._put_chunk(session, f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def close_session(self, session_id):
        """Mark the session inactive and schedule its removal one second later."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                # Remove after a 1 s delay (presumably so in-flight readers can finish).
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Drop the session's bookkeeping entry if it still exists."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the session dict, or None if unknown."""
        return self.sessions.get(session_id)


stream_manager = StreamSessionManager()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def allowed_file(filename):
|
|||
|
|
return '.' in filename and \
|
|||
|
|
filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}
|
|||
|
|
|
|||
|
|
|
|||
|
|
# NOTE: two dead definitions used to sit here and have been removed:
#   * A first copy of audio_text_cache / cache_lock / CACHE_EXPIRE_SECONDS.
#     The same names are bound again next to the text helpers further down
#     this module. That later binding (a different dict and Lock) is the one
#     every function sees at call time, so this copy was never read.
#   * A legacy Flask endpoint ``/tts_stream/<session_id>`` kept as a bare string
#     literal. It never executed, and it referenced names that do not exist in
#     this module (``manager``, ``Response``, ``stream_with_context``,
#     ``get_error_data_result``). Its consumer contract for
#     StreamSessionManager sessions was:
#       - poll ``session['buffer']`` until ``session['finished']`` is set or
#         ``session['active']`` is False;
#       - stop on a ``str`` item starting with "ERROR";
#       - for ``stream_format == "wav"`` send
#         ``encode_gzip_base64(chunk) + "\r\n"``, otherwise send raw bytes;
#       - call ``stream_manager.close_session(session_id)`` when the stream ends.
|
|||
|
|
def generate_mp3_header(bitrate_kbps=128, padding=0):
    """Build a 4-byte MPEG-1 Layer III frame header (44.1 kHz, mono, no CRC).

    Args:
        bitrate_kbps: an MPEG-1 Layer III bitrate (32-320 kbps).
        padding: truthy to set the padding bit.

    Returns:
        The header as 4 big-endian bytes.

    Raises:
        KeyError: ``bitrate_kbps`` is not a valid Layer III bitrate.
    """
    bitrate_table = {  # bitrate index table (MPEG-1 Layer III)
        32: 0b0001, 40: 0b0010, 48: 0b0011, 56: 0b0100,
        64: 0b0101, 80: 0b0110, 96: 0b0111, 112: 0b1000,
        128: 0b1001, 160: 0b1010, 192: 0b1011, 224: 0b1100,
        256: 0b1101, 320: 0b1110,
    }
    try:
        bitrate_index = bitrate_table[bitrate_kbps]
    except KeyError:
        raise KeyError(f"unsupported MPEG-1 Layer III bitrate: {bitrate_kbps} kbps") from None

    sync = 0b11111111111  # frame sync (11 bits)
    version = 0b11  # MPEG-1 (2 bits)
    layer = 0b01  # Layer III (2 bits)
    # The protection bit is inverted in the spec: 1 = NOT protected (no CRC).
    # The previous value 0 announced a 16-bit CRC that is never written.
    protection = 0b1
    sampling_rate = 0b00  # 44.1 kHz for MPEG-1 (2 bits)
    padding_bit = 1 if padding else 0  # keep it to a single bit
    private_bit = 0b0
    mode = 0b11  # single channel (2 bits)
    mode_ext = 0b00  # mode extension (2 bits)
    copyright_bit = 0b0
    original_bit = 0b0
    emphasis = 0b00  # no emphasis (2 bits)

    header = (
        (sync << 21)
        | (version << 19)
        | (layer << 17)
        | (protection << 16)
        | (bitrate_index << 12)
        | (sampling_rate << 10)
        | (padding_bit << 9)
        | (private_bit << 8)
        | (mode << 6)
        | (mode_ext << 4)
        | (copyright_bit << 3)
        | (original_bit << 2)
        | emphasis
    )
    return header.to_bytes(4, byteorder='big')
|
|||
|
|
|
|||
|
|
|
|||
|
|
# ------------------------------------------------
|
|||
|
|
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit mono PCM data.

    Args:
        audio_data: raw 16-bit PCM bytes. Samples are decoded with
            ``array('h')``, i.e. in the host's native byte order (little-endian on
            typical hosts; NOTE(review): confirm for the deployment target).
        fade_length: number of SAMPLES (not bytes) to fade. It is clamped to the
            number of samples available, so short chunks get a complete 0->1 ramp
            instead of raising IndexError.

    Returns:
        New bytes of the same length. A trailing odd byte (half a sample) is
        passed through unchanged instead of raising ValueError.
    """
    even_length = len(audio_data) - (len(audio_data) % 2)
    samples = array.array('h', audio_data[:even_length])
    ramp = min(fade_length, len(samples))
    for i in range(ramp):
        samples[i] = int(samples[i] * (i / ramp))
    return samples.tobytes() + bytes(audio_data[even_length:])
|
|||
|
|
|
|||
|
|
|
|||
|
|
def parse_markdown_json(json_string):
    """Extract and decode the first ```json fenced block in a Markdown string.

    Returns:
        ``{'success': True, 'data': <decoded JSON>}`` on success, otherwise
        ``{'success': False, 'data': <reason>}``.
    """
    found = re.search(r'```json\n(.*?)\n```', json_string, re.DOTALL)
    if found is None:
        return {'success': False, 'data': 'not a valid markdown json string'}
    try:
        payload = json.loads(found.group(1))
    except json.JSONDecodeError as err:
        # Report the decoder's message instead of raising.
        return {'success': False, 'data': str(err)}
    return {'success': True, 'data': payload}
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Shared in-memory cache of TTS requests and generated audio, keyed by audio_stream_id.
audio_text_cache = dict()
# Guards every access to audio_text_cache from request handlers and cleaners.
cache_lock = Lock()
# Entries older than this many seconds are evicted (10 minutes).
CACHE_EXPIRE_SECONDS = 10 * 60
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Fullwidth -> halfwidth punctuation map, written with escapes so it survives
# re-encoding. The previous literal had been normalized: the fullwidth keys had
# turned into their ASCII look-alikes, leaving identity pairs plus two syntax
# errors (''' and '\'). This rebuilds the evident intent.
# Fullwidth ASCII punctuation U+FF01..U+FF5E equals ASCII + 0xFEE0.
_FULLWIDTH_PUNCTUATION_SOURCE = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
_FULL_TO_HALF_MAP = {chr(ord(c) + 0xFEE0): c for c in _FULLWIDTH_PUNCTUATION_SOURCE}
# Halfwidth CJK forms. NOTE(review): targets reconstructed from the mangled
# literal (later duplicate entries taken as winning); confirm against the
# pre-normalization source.
_FULL_TO_HALF_MAP.update({
    "\uff5f": "\u2985",  # fullwidth left white paren -> U+2985
    "\uff60": "\u2986",  # fullwidth right white paren -> U+2986
    "\uff61": ".",       # halfwidth ideographic full stop
    "\uff62": "\u300c",  # halfwidth left corner bracket
    "\uff63": "\u300d",  # halfwidth right corner bracket
    "\uff64": "\u3001",  # halfwidth ideographic comma
    "\uff65": "\u30fb",  # halfwidth katakana middle dot
    "\uff70": "-",       # halfwidth prolonged sound mark
})
_FULL_TO_HALF_TABLE = str.maketrans(_FULL_TO_HALF_MAP)


def fullwidth_to_halfwidth(s):
    """Return ``s`` with fullwidth punctuation replaced by its ASCII/halfwidth form.

    Letters, digits and all other characters are left unchanged.
    """
    return s.translate(_FULL_TO_HALF_TABLE)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def split_text_at_punctuation(text, chunk_size=100):
    """Split ``text`` on whitespace/ASCII punctuation and pack the words into chunks.

    Punctuation is dropped; words are re-joined with single spaces. A chunk
    grows while ``len(chunk_so_far_with_trailing_space) + len(word) <=
    chunk_size``. A single word longer than ``chunk_size`` becomes its own
    (oversized) chunk. No empty chunks are produced.
    """
    punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
    tokens = [token for token in re.split(punctuation_pattern, text) if token]

    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += token + ' '
        else:
            # An oversized first token used to emit an empty '' chunk here.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
|
|||
|
|
|
|||
|
|
|
|||
|
|
def extract_text_from_markdown(markdown_text):
    """Reduce Markdown to plain text for speech synthesis.

    Removes fenced and inline code, heading lines, images, links (including
    their text) and HTML tags; keeps the words of bold/italic spans but drops
    their markers; normalizes punctuation with ``fullwidth_to_halfwidth`` and
    collapses whitespace.
    """
    text = markdown_text
    # Fenced code first: running the inline-code pattern first eats the inside of
    # a fence and leaves stray backticks behind.
    text = re.sub(r'```[\s\S]*?```', '', text)
    text = re.sub(r'`[^`]+`', '', text)
    # Heading lines only. The old r'#\s*[^#]+' also swallowed every following
    # line up to the next '#'.
    text = re.sub(r'^[ \t]{0,3}#{1,6}(?:[ \t][^\n]*)?$', '', text, flags=re.MULTILINE)
    # Images before links, otherwise the link pattern leaves the leading "!".
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # Bold/italic: keep the words, drop the markers. The old pattern deleted the
    # emphasized words and mangled snake_case identifiers.
    text = re.sub(r'(\*{1,3})(?=\S)(.+?)(?<=\S)\1', r'\2', text)
    text = re.sub(r'(?<!\w)(_{1,3})(?=\S)(.+?)(?<=\S)\1(?!\w)', r'\2', text)
    # HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Punctuation normalization.
    text = fullwidth_to_halfwidth(text)
    # Collapse runs of whitespace.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
|
|||
|
|
|
|||
|
|
|
|||
|
|
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress ``original_data`` and return it as single-line Base64 text.

    ``base64.b64encode`` emits no line breaks, which corresponds to Android's
    ``Base64.NO_WRAP`` as expected by the Android client.
    """
    sink = BytesIO()
    try:
        with gzip.GzipFile(fileobj=sink, mode='wb') as compressor:
            compressor.write(original_data)
        compressed = sink.getvalue()
    finally:
        sink.close()
    return base64.b64encode(compressed).decode('utf-8')
|
|||
|
|
|
|||
|
|
|
|||
|
|
def clean_audio_cache():
    """Evict cache entries older than CACHE_EXPIRE_SECONDS and close their audio streams.

    ``created_at`` is a ``time.time()`` float for entries written by
    create_tts_request but a ``datetime`` for entries written by
    create_tts_request1. Both are accepted; the datetime case used to raise
    TypeError and abort the cleanup. Entries without ``created_at`` are kept.
    """
    with cache_lock:
        now = time.time()
        expired_keys = []
        for key, entry in audio_text_cache.items():
            created = entry.get('created_at')
            if isinstance(created, datetime.datetime):
                created = created.timestamp()  # naive datetimes are read as local time, like time.time()
            if created is not None and now - created > CACHE_EXPIRE_SECONDS:
                expired_keys.append(key)
        for key in expired_keys:
            entry = audio_text_cache.pop(key, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def start_background_cleaner():
    """Start a daemon thread that calls clean_audio_cache() every 3 minutes."""
    interval_seconds = 180

    def _sweep_forever():
        while True:
            time.sleep(interval_seconds)
            clean_audio_cache()

    Thread(target=_sweep_forever, daemon=True).start()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def test_qwen_chat():
    """Manual smoke test: ask ``qwen-plus`` "你是谁?" via DashScope and print the reply.

    Makes a real network call with ALI_KEY and prints either the answer or the
    HTTP status, error code and message.
    """
    # ``Generation`` was used without ever being imported.
    from dashscope import Generation

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': '你是谁?'}
    ]
    response = Generation.call(
        # Without a configured environment variable, replace with the Bailian key: api_key="sk-xxx"
        api_key=ALI_KEY,
        model="qwen-plus",  # model list: https://help.aliyun.com/zh/model-studio/getting-started/models
        messages=messages,
        result_format="message",
    )

    if response.status_code == 200:
        print(response.output.choices[0].message.content)
    else:
        print(f"HTTP返回码:{response.status_code}")
        print(f"错误码:{response.code}")
        print(f"错误信息:{response.message}")
        print("请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code")
|
|||
|
|
|
|||
|
|
# DashScope API key; the DASHSCOPE_API_KEY environment variable takes precedence.
# SECURITY NOTE(review): a real key is committed below as a fallback so existing
# deployments keep working. Rotate it and remove the literal.
ALI_KEY = os.environ.get("DASHSCOPE_API_KEY", "sk-a47a3fb5f4a94f66bbaf713779101c75")
|
|||
|
|
|
|||
|
|
class QwenTTS:
    """Streaming text-to-speech client for Alibaba DashScope.

    A ``model_name`` of the form ``"cosyvoice-v1/<voice>"`` selects the CosyVoice
    API (``dashscope.audio.tts_v2``) with that voice. Any other name is passed
    unchanged to the legacy ``dashscope.audio.tts`` API.
    """

    DEFAULT_MODEL_NAME = "cosyvoice-v1/longxiaochun"

    def __init__(self, key, model_name="cosyvoice-v1/longxiaochun", base_url=""):
        """Configure the DashScope SDK and parse ``model_name``.

        Args:
            key: DashScope API key (assigned to ``dashscope.api_key``).
            model_name: see the class docstring. ``None``/empty falls back to
                ``DEFAULT_MODEL_NAME`` because request handlers may pass an
                optional field straight through.
            base_url: unused; kept for interface compatibility.
        """
        import dashscope
        import ssl

        logging.debug("---begin--init QwenTTS--")
        if not model_name:
            model_name = self.DEFAULT_MODEL_NAME
        self.model_name = model_name
        dashscope.api_key = key
        # SECURITY NOTE(review): this replaces the process-wide default HTTPS
        # context (used by urllib / http.client), disabling certificate
        # verification for those clients everywhere, not only for DashScope.
        # Kept to preserve existing behaviour.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        family, slash, voice = model_name.partition('/')
        self.is_cosyvoice = bool(slash) and family == 'cosyvoice-v1'
        self.voice = voice if self.is_cosyvoice else ""

    def tts(self, text, sample_rate=None, stream_format=None):
        """Synthesize ``text`` and yield audio chunks (``bytes``) as they arrive.

        Args:
            text: text to synthesize.
            sample_rate: CosyVoice output sample rate. Defaults to the
                module-level ``TTS_SAMPLERATE`` read when iteration starts.
            stream_format: CosyVoice output format ('mp3' / 'wav' / 'pcm').
                Defaults to the module-level ``FORMAT`` read when iteration
                starts. The legacy API always requests "wav" and ignores both
                arguments.

        Raises:
            RuntimeError: ``"**ERROR**: ..."`` if synthesis cannot be started or
                the service reports an error while streaming.
        """
        if sample_rate is None:
            sample_rate = TTS_SAMPLERATE
        if stream_format is None:
            stream_format = FORMAT

        if self.is_cosyvoice:
            # cyx 2025-01-19: CosyVoice uses the tts_v2 API.
            from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer
        else:
            from dashscope.audio.tts import ResultCallback, SpeechSynthesizer

        logging.info(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")

        end_of_stream = object()
        # The legacy API delivers audio via on_event; CosyVoice delivers it via
        # on_data and sends status messages to on_event.
        frames_via_on_event = not self.is_cosyvoice

        class _StreamCallback(ResultCallback):
            """Forward SDK callbacks into a blocking queue drained by ``_run``."""

            def __init__(self) -> None:
                self.chunks = queue.Queue()

            def _run(self):
                # Block on the queue instead of spinning on an empty deque.
                while True:
                    item = self.chunks.get()
                    if item is end_of_stream:
                        return
                    if isinstance(item, BaseException):
                        raise item
                    if item:  # skip empty frames instead of ending the stream early
                        yield item

            def on_open(self):
                logging.info("Qwen tts open")

            def on_complete(self):
                self.chunks.put(end_of_stream)

            def on_error(self, response):
                # Raising here does not reach the consumer of tts() (for CosyVoice
                # this presumably runs on the SDK's own thread), and _run would
                # wait forever if on_complete never followed. Queue the error
                # instead so _run re-raises it.
                logging.error(f"Qwen tts error {response}")
                self.chunks.put(RuntimeError(str(response)))

            def on_close(self):
                logging.info("Qwen tts close")

            def on_event(self, result):
                if frames_via_on_event and result.get_audio_frame() is not None:
                    self.chunks.put(result.get_audio_frame())

            def on_data(self, data: bytes) -> None:
                if len(data) > 0:
                    self.chunks.put(data)

        try:
            self.callback = _StreamCallback()
            if self.is_cosyvoice:
                self.synthesizer = SpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,
                    callback=self.callback,
                    format=self.get_audio_format(stream_format, sample_rate),
                )
                self.synthesizer.call(text)
            else:
                SpeechSynthesizer.call(model=self.model_name,
                                       text=text,
                                       callback=self.callback,
                                       format="wav")
        except Exception as e:
            # Previously only printed, after which the generator waited forever
            # for audio that would never arrive.
            logging.error(f"QwenTTS failed to start synthesis: {e}")
            raise RuntimeError(f"**ERROR**: {e}") from e

        try:
            yield from self.callback._run()
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}") from e

    def get_audio_format(self, format: str, sample_rate: int):
        """Map ``(format, sample_rate)`` to a CosyVoice ``AudioFormat``.

        Unknown combinations fall back to MP3 16 kHz mono 128 kbps.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            # 16 kHz mp3/wav were missing, so e.g. wav@16k silently produced MP3.
            (16000, 'mp3'): AudioFormat.MP3_16000HZ_MONO_128KBPS,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (16000, 'wav'): AudioFormat.WAV_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)

    def end_tts(self):
        """Tell an active CosyVoice synthesizer that no more text will follow."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()
|
|||
|
|
|
|||
|
|
@tts_router.get("/audio/pcm_mp3")
async def stream_mp3():
    """Debug endpoint: stream ``./test.mp3`` in 1 KiB pieces as ``audio/mpeg``.

    Read errors (e.g. a missing file) are logged and simply end the stream.
    """

    def audio_generator():
        try:
            with open('./test.mp3', 'rb') as mp3_file:
                for piece in iter(lambda: mp3_file.read(1024), b''):
                    yield piece
        except Exception as e:
            logging.error(f"MP3 streaming error: {str(e)}")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/mpeg",
        headers={"Cache-Control": "no-store"},
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def add_wav_header(pcm_data: bytes, sample_rate: int,
                   channels: int = 1, sample_width: int = 2) -> bytes:
    """Wrap raw PCM in a RIFF/WAVE container and return the complete file bytes.

    Args:
        pcm_data: PCM frames, written as-is (WAV expects little-endian samples).
        sample_rate: frames per second recorded in the header.
        channels: channel count (default 1 = mono, the previous fixed value).
        sample_width: bytes per sample (default 2 = 16-bit, the previous fixed value).
    """
    with BytesIO() as wav_buffer:
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(channels)
            wav_file.setsampwidth(sample_width)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(pcm_data)
        return wav_buffer.getvalue()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_silence_header(duration_ms: int = 500) -> bytes:
    """Return ``duration_ms`` of digital silence as zero-valued raw PCM bytes.

    Sized with the module's TTS_SAMPLERATE, SAMPLE_WIDTH and CHANNELS. The
    result is raw PCM, not MP3 frames, although it was meant as a pre-buffer
    for MP3 streaming.
    """
    sample_count = int(TTS_SAMPLERATE * duration_ms / 1000)
    byte_count = sample_count * SAMPLE_WIDTH * CHANNELS
    return b'\x00' * byte_count
|
|||
|
|
|
|||
|
|
|
|||
|
|
@tts_router.get("/chats/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio(
    chat_id: str,
    audio_stream_id: str,
    request: Request,
    range: str = Header(None),  # raw Range header, e.g. "bytes=0-1023"
):
    """Serve the audio of a job created by ``POST /chats/{chat_id}/tts``.

    If the audio was synthesized when the job was created, it is served from
    the cache and a ``Range`` header is honoured via ``handle_range_request``.
    Otherwise the audio is synthesized on the fly and any Range is ignored.

    Raises:
        HTTPException(404): unknown or expired ``audio_stream_id``.
        HTTPException(400/416): invalid Range (from handle_range_request).
        HTTPException(500): the response could not be created.
    """
    # ``range`` shadows the builtin inside this function; the name must stay
    # because FastAPI derives the header name from it.
    with cache_lock:
        tts_info = audio_text_cache.get(audio_stream_id, None)
        if not tts_info:
            raise HTTPException(404, detail="音频流已过期")
        audio_stream = tts_info['audio_stream']
        # Snapshot the bytes under the lock. The cached BytesIO is shared by every
        # request for this id, so seeking/reading it directly lets concurrent
        # downloads corrupt each other (and clean_audio_cache may close it).
        audio_bytes = audio_stream.getvalue() if audio_stream is not None else None
        text = tts_info['text']
        # The stored value may be None (optional request field).
        model_name = tts_info.get('model_name') or "cosyvoice-v1/longxiaochun"
        format = tts_info['format']
        sample_rate = tts_info['sample_rate']

    media_type_map = {
        'mp3': 'audio/mpeg',
        'wav': 'audio/wav',
        'pcm': f'audio/L16; rate={sample_rate}; channels=1',
    }

    def read_buffer():
        """Yield the cached audio in 1 KiB pieces."""
        offset = 0
        while offset < len(audio_bytes):
            yield audio_bytes[offset:offset + 1024]
            offset += 1024

    def generate_audio():
        """Synthesize the audio now (it was not generated at creation time)."""
        tts = QwenTTS(ALI_KEY, model_name)
        # QwenTTS reads these module globals, so override them while streaming.
        # NOTE(review): not safe for concurrent requests with different formats.
        global FORMAT, TTS_SAMPLERATE
        original_format = FORMAT
        original_sample_rate = TTS_SAMPLERATE
        FORMAT = format
        TTS_SAMPLERATE = sample_rate
        try:
            for chunk in tts.tts(text):
                if format == 'wav':
                    # NOTE(review): every chunk gets its own WAV header; confirm
                    # clients expect a sequence of WAV files.
                    yield add_wav_header(chunk, sample_rate)
                else:
                    yield chunk
        finally:
            FORMAT = original_format
            TTS_SAMPLERATE = original_sample_rate

    try:
        media_type = media_type_map[format]
        if audio_bytes is not None:
            if range:
                return handle_range_request(range, BytesIO(audio_bytes), len(audio_bytes), media_type)
            logging.info("return audio stream buffer")
            return StreamingResponse(
                read_buffer(),
                media_type=media_type,
                headers={
                    "Transfer-Encoding": "chunked",
                    "Cache-Control": "no-store",
                    "Content-Disposition": "inline",
                    "Access-Control-Allow-Origin": "*",
                },
            )
        return StreamingResponse(
            generate_audio(),
            media_type=media_type,
            headers={
                "Transfer-Encoding": "chunked",
                "Cache-Control": "no-store",
                "Content-Disposition": "inline",
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"音频流错误: {str(e)}")
        raise HTTPException(500, detail="音频生成失败")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@tts_router.post("/chats/{chat_id}/tts")
async def create_tts_request(chat_id: str, request: Request):
    """Register a TTS job and return the URL its audio can be fetched from.

    JSON body fields:
        text (required), format ('mp3'/'wav'/'pcm', default FORMAT),
        sample_rate (8000/16000/22050/44100, default TTS_SAMPLERATE),
        model_name (optional), delay_gen_audio (default False).

    Unless ``delay_gen_audio`` is true, the audio is synthesized now (in a
    worker thread, off the event loop) and cached.

    Raises:
        HTTPException(400): invalid body or parameters.
        HTTPException(500): any other failure, including synthesis errors.
    """
    import asyncio

    try:
        try:
            request_data = await request.json()
        except ValueError:
            raise HTTPException(400, detail="请求体必须是JSON对象")
        if not isinstance(request_data, dict):
            raise HTTPException(400, detail="请求体必须是JSON对象")

        format = request_data.get('format', FORMAT)
        if format not in ['mp3', 'wav', 'pcm']:
            raise HTTPException(400, detail="不支持的音频格式")

        sample_rate = request_data.get('sample_rate', TTS_SAMPLERATE)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")

        text = request_data.get("text")
        if not isinstance(text, str) or not text.strip():
            raise HTTPException(400, detail="文本内容不能为空")

        audio_stream_id = str(uuid.uuid4())
        tts_info = {
            'text': text,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': None,
            # Normalize a missing/null model_name so readers can use it directly.
            'model_name': request_data.get('model_name') or "cosyvoice-v1/longxiaochun",
            'format': format,
            'sample_rate': sample_rate,
        }

        if request_data.get('delay_gen_audio', False) is False:
            def _synthesize():
                """Run the blocking TTS call; return (buffer, byte_count)."""
                global FORMAT, TTS_SAMPLERATE
                tts = QwenTTS(ALI_KEY, tts_info['model_name'])
                original_format, original_sample_rate = FORMAT, TTS_SAMPLERATE
                # QwenTTS reads these module globals; override them for this call
                # so the cached audio matches the requested format/sample rate.
                # NOTE(review): not safe for concurrent requests with different formats.
                FORMAT, TTS_SAMPLERATE = format, sample_rate
                try:
                    buffer = io.BytesIO()
                    total = 0
                    for chunk in tts.tts(text):
                        total += len(chunk)
                        buffer.write(chunk)
                finally:
                    FORMAT, TTS_SAMPLERATE = original_format, original_sample_rate
                buffer.seek(0)
                return buffer, total

            # Synthesis is blocking network I/O; keep it off the event loop.
            buffer, audio_stream_len = await asyncio.get_running_loop().run_in_executor(None, _synthesize)
            tts_info['audio_stream'] = buffer
            tts_info['audio_stream_len'] = audio_stream_len

        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info
        logging.info(f"create tts stream return {audio_stream_id} {tts_info}")

        return {
            "tts_url": f"/chats/{chat_id}/tts/{audio_stream_id}",
            "audio_stream_id": audio_stream_id,
            "sample_rate": sample_rate,
        }
    except HTTPException:
        # Validation errors used to be swallowed below and turned into 500s.
        raise
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="内部服务器错误")
|
|||
|
|
|
|||
|
|
# 辅助函数
|
|||
|
|
|
|||
|
|
# ------------------------ API路由 ------------------------
|
|||
|
|
@tts_router.post("/chats1/{chat_id}/tts")
async def create_tts_request1(chat_id: str, request: Request):
    """Synthesize ``text`` now, cache it, and return its download URL (``/chats1`` variant).

    ``format`` and ``sample_rate`` are validated, but the audio is always
    produced with QwenTTS defaults and recorded as mp3 / 44100 Hz.

    Raises:
        HTTPException(400): invalid body or parameters.
        HTTPException(500): synthesis failure or any other error.
    """
    import asyncio

    try:
        data = await request.json()
        logging.info(f"API--create_tts_request1-- {data}")
        if not isinstance(data, dict):
            raise HTTPException(400, detail="请求体必须是JSON对象")

        text = data.get("text", "")
        if not isinstance(text, str) or not text.strip():
            raise HTTPException(400, detail="文本内容不能为空")
        text = text.strip()

        format = data.get("format", "mp3")
        if format not in ["mp3", "wav", "pcm"]:
            raise HTTPException(400, detail="不支持的音频格式")

        sample_rate = data.get("sample_rate", 16000)
        if sample_rate not in [8000, 16000, 22050, 44100]:
            raise HTTPException(400, detail="不支持的采样率")

        # NOTE(review): the validated values are overridden. QwenTTS(ALI_KEY)
        # synthesizes with the module-level FORMAT / TTS_SAMPLERATE (mp3 at
        # 44100 Hz by default), so the cache entry is labelled accordingly.
        format = "mp3"
        sample_rate = 44100

        audio_stream_id = str(uuid.uuid4())

        def _synthesize():
            """Run the blocking TTS call and return a rewound buffer."""
            buffer = io.BytesIO()
            for chunk in QwenTTS(ALI_KEY).tts(text):
                buffer.write(chunk)
            buffer.seek(0)
            return buffer

        try:
            # Synthesis is blocking network I/O; keep it off the event loop.
            buffer = await asyncio.get_running_loop().run_in_executor(None, _synthesize)
        except Exception as e:
            logging.error(f"TTS生成失败: {str(e)}")
            raise HTTPException(500, detail="音频生成失败")

        with cache_lock:
            audio_text_cache[audio_stream_id] = {
                "buffer": buffer,
                "format": format,
                "sample_rate": sample_rate,
                "created_at": datetime.datetime.now(),
                "size": buffer.getbuffer().nbytes,
            }

        return JSONResponse(
            status_code=200,
            content={
                "tts_url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
                "url": f"/chats1/{chat_id}/tts/{audio_stream_id}",
                "expires_at": (datetime.datetime.now() + timedelta(seconds=CACHE_EXPIRE_SECONDS)).isoformat(),
            },
        )
    except HTTPException:
        # 400/500 raised above used to be re-wrapped as a generic 500.
        raise
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}")
        raise HTTPException(500, detail="服务器内部错误")
|
|||
|
|
|
|||
|
|
|
|||
|
|
@tts_router.get("/chats1/{chat_id}/tts/{audio_stream_id}")
async def get_tts_audio1(
    chat_id: str,
    audio_stream_id: str,
    range: str = Header(None),  # raw Range header, if any
):
    """Serve audio cached by ``POST /chats1/{chat_id}/tts`` (supports Range).

    Raises:
        HTTPException(404): unknown/expired id, or an entry without a ``buffer``.
        HTTPException(400/416): invalid Range (from handle_range_request).
    """
    # Evict expired entries first.
    cleanup_cache()

    with cache_lock:
        item = audio_text_cache.get(audio_stream_id)
        # Entries created by the /chats endpoint have no "buffer" field.
        buffer = item.get("buffer") if item else None
        # Copy the bytes under the lock. Reading the shared BytesIO directly left
        # its position at EOF, so a second download got an empty body while
        # still advertising the full Content-Length.
        data = buffer.getvalue() if buffer is not None else None

    if data is None:
        raise HTTPException(404, detail="音频流不存在或已过期")

    format = item["format"]
    total_size = len(data)
    media_type = {
        "mp3": "audio/mpeg",
        "wav": "audio/wav",
        "pcm": f"audio/L16; rate={item['sample_rate']}; channels=1"
    }[format]

    if range:
        return handle_range_request(range, BytesIO(data), total_size, media_type)

    # Full response from a private stream so concurrent downloads don't interfere.
    stream = BytesIO(data)
    return StreamingResponse(
        iter(lambda: stream.read(4096), b""),
        media_type=media_type,
        headers={
            "Accept-Ranges": "bytes",
            "Content-Length": str(total_size),
            "Cache-Control": f"max-age={CACHE_EXPIRE_SECONDS}",
        },
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def handle_range_request(range: str, buffer: BytesIO, total_size: int, media_type: str):
    """Return a 206 response for a single ``bytes=`` range of ``buffer``.

    Accepts ``bytes=start-end``, ``bytes=start-`` and suffix ``bytes=-N`` (the
    last N bytes). An end beyond the last byte is clamped (RFC 7233 §2.1). Data
    is sliced from ``buffer.getvalue()``, so the buffer's position is neither
    used nor changed.

    Raises:
        HTTPException(400): malformed header, multiple ranges, or start > end.
        HTTPException(416): the range starts at or beyond ``total_size``.
    """
    # ``range`` shadows the builtin here (callers pass it by this name).
    try:
        unit, _, spec = range.partition("=")
        if unit.strip() != "bytes" or "," in spec:
            raise ValueError
        start_str, end_str = spec.strip().split("-")  # exactly one "-" expected
        if start_str:
            start = int(start_str)
            end = int(end_str) if end_str else total_size - 1
        elif end_str:
            # Suffix range: the final N bytes.
            start = max(total_size - int(end_str), 0)
            end = total_size - 1
        else:
            raise ValueError
        if start < 0:
            raise ValueError
    except ValueError:
        raise HTTPException(400, detail="无效的Range头")

    if start >= total_size:
        raise HTTPException(416, detail="请求范围无效", headers={
            "Content-Range": f"bytes */{total_size}"
        })
    end = min(end, total_size - 1)
    if end < start:
        raise HTTPException(400, detail="无效的Range头")

    content_length = end - start + 1
    payload = buffer.getvalue()[start:end + 1]

    def chunk_generator():
        offset = 0
        while offset < len(payload):
            yield payload[offset:offset + 4096]
            offset += 4096

    return StreamingResponse(
        chunk_generator(),
        status_code=206,
        headers={
            "Content-Range": f"bytes {start}-{end}/{total_size}",
            "Content-Length": str(content_length),
            "Content-Type": media_type,
            "Accept-Ranges": "bytes"
        }
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def cleanup_cache():
    """Remove cache entries older than CACHE_EXPIRE_SECONDS.

    ``created_at`` is a ``datetime`` for entries written by create_tts_request1
    but a ``time.time()`` float for entries written by create_tts_request.
    Both are accepted; the float case used to raise TypeError here and fail the
    calling request. Entries without ``created_at`` are kept.
    """
    with cache_lock:
        now = time.time()
        expired = []
        for key, entry in audio_text_cache.items():
            created = entry.get("created_at")
            if isinstance(created, datetime.datetime):
                created = created.timestamp()  # naive datetimes are read as local time, like time.time()
            if created is not None and now - created > CACHE_EXPIRE_SECONDS:
                expired.append(key)
        for key in expired:
            del audio_text_cache[key]
|
|||
|
|
|
|||
|
|
# 应用启动时启动清理线程
|
|||
|
|
# start_background_cleaner()
|