Files
ragflow_python/api/db/services/ali_tts_service.py

367 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio,logging
from collections import deque
import threading, time,queue,uuid,time,array
from concurrent.futures import ThreadPoolExecutor
import os

# DashScope API key used for every TTS session created by this module.
# The DASHSCOPE_API_KEY environment variable (the DashScope SDK's standard
# variable) takes precedence when it is set and non-empty.
# NOTE(review): SECURITY -- a real key is committed here as a fallback for
# backward compatibility; rotate it and delete the literal once every
# deployment sets DASHSCOPE_API_KEY.
ALI_KEY = os.environ.get("DASHSCOPE_API_KEY") or "sk-a47a3fb5f4a94f66bbaf713779101c75"
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import (
ResultCallback as TTSResultCallback,
SpeechSynthesizer as TTSSpeechSynthesizer,
SpeechSynthesisResult as TTSSpeechSynthesisResult,
)
# cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本
from dashscope.audio.tts_v2 import (
ResultCallback as CosyResultCallback,
SpeechSynthesizer as CosySpeechSynthesizer,
AudioFormat,
)
class QwenTTS:
    """DashScope (Aliyun) text-to-speech wrapper.

    The back end is chosen from ``model_name``:

    * ``"cosyvoice-v1/<voice>"`` -- CosyVoice through ``dashscope.audio.tts_v2``,
      speaking with ``<voice>`` and the requested ``format``/``sample_rate``.
    * anything else -- passed as the model to the v1 ``dashscope.audio.tts``
      API, which is always asked for WAV output.

    ``tts()`` does a one-shot synthesis and yields audio chunks.
    ``init_streaming_call()`` / ``streaming_call()`` / ``end_streaming_call()``
    drive an incremental CosyVoice session whose audio is delivered to a
    caller-supplied callback.
    """

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longyuan"):
        """Configure the DashScope client.

        Args:
            key: DashScope API key; assigned to the module attribute
                ``dashscope.api_key``.
            format: Audio container requested from CosyVoice ("mp3", "wav", "pcm").
            sample_rate: Sample rate in Hz requested from CosyVoice.
            model_name: ``"cosyvoice-v1/<voice>"`` or a v1 TTS model name.
        """
        import dashscope
        import ssl

        logging.info(f"---QwenTTS Construtor-- {format} {sample_rate} {model_name}")  # cyx
        self.model_name = model_name
        dashscope.api_key = key
        # NOTE(review): SECURITY -- this disables TLS certificate verification
        # for every HTTPS connection made by the whole process, not only for
        # DashScope. Kept for compatibility; remove once the deployment's CA
        # bundle is fixed.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        if '/' in model_name:
            # "<model family>/<voice>" -- only the CosyVoice family uses the voice part.
            family, voice = model_name.split('/', 1)
            if family == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = voice

    class _FrameQueueMixin:
        """Audio-frame FIFO shared by the callback classes below.

        Producers (the DashScope SDK callbacks) push frames with ``_push``;
        ``_run`` drains them in order until the ``None`` end-of-stream
        sentinel arrives. ``dque`` keeps its original name and type.
        """

        def _init_frame_queue(self):
            self.dque = deque()
            self._frame_ready = threading.Event()

        def _push(self, item):
            # Append first, then signal, so a woken consumer always sees the item.
            self.dque.append(item)
            self._frame_ready.set()

        def _run(self):
            """Yield buffered frames until the ``None`` sentinel is received.

            Blocks on an Event while the queue is empty instead of spinning.
            Empty frames are skipped rather than treated as end-of-stream.
            """
            while True:
                # Clear BEFORE checking: a push landing after the check sets
                # the event again, so wait() below cannot miss it.
                self._frame_ready.clear()
                if not self.dque:
                    self._frame_ready.wait()
                    continue
                val = self.dque.popleft()
                if val is None:
                    break
                if val:
                    yield val

    class Callback(_FrameQueueMixin, TTSResultCallback):
        """Callback for the v1 ``dashscope.audio.tts`` API; buffers every frame."""

        def __init__(self) -> None:
            self._init_frame_queue()

        def on_open(self):
            pass

        def on_complete(self):
            self._push(None)

        def on_error(self, response: SpeechSynthesisResponse):
            logging.error(f"Qwen tts error {response}")
            # End the stream before raising; otherwise _run() would wait forever.
            self._push(None)
            raise RuntimeError(str(response))

        def on_close(self):
            # Connection closed: no further frames can arrive.
            self._push(None)

        def on_event(self, result: TTSSpeechSynthesisResult):
            frame = result.get_audio_frame()
            if frame is not None:
                self._push(frame)

    class Callback_Cosy(_FrameQueueMixin, CosyResultCallback):
        """Callback for the CosyVoice ``tts_v2`` API.

        When ``on_audio_data`` is given, each non-empty frame is forwarded to
        it immediately (streaming mode); otherwise frames are buffered for
        ``_run``.
        """

        def __init__(self, on_audio_data) -> None:
            self._init_frame_queue()
            self.on_audio_data = on_audio_data

        def on_open(self):
            logging.info("---Qwen tts on_open---")

        def on_complete(self):
            self._push(None)

        def on_error(self, response: SpeechSynthesisResponse):
            logging.error(f"Qwen tts error {response}")
            # End the stream before raising; otherwise _run() would wait forever.
            self._push(None)
            raise RuntimeError(str(response))

        def on_close(self):
            logging.info("---Qwen tts on_close---")
            # Connection closed: no further frames can arrive.
            self._push(None)

        def on_event(self, message):
            # Status messages are ignored; audio is delivered through on_data().
            pass

        def on_data(self, data: bytes) -> None:
            if not data:
                return
            if self.on_audio_data:
                self.on_audio_data(data)
            else:
                self._push(data)

    def _new_cosy_synthesizer(self, callback):
        """Build a CosyVoice synthesizer for this instance's voice and audio format."""
        return CosySpeechSynthesizer(
            model='cosyvoice-v1',
            voice=self.voice,
            callback=callback,
            format=self.get_audio_format(self.format, self.sample_rate),
        )

    def tts(self, text):
        """Synthesize ``text`` in a single request and yield the audio chunks.

        Raises:
            RuntimeError: if the request cannot be started or the stream fails.
        """
        logging.info(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")  # cyx
        try:
            if not self.is_cosyvoice:
                self.callback = self.Callback()
                # The v1 API is always asked for WAV, independent of self.format.
                TTSSpeechSynthesizer.call(model=self.model_name,
                                          text=text,
                                          callback=self.callback,
                                          format="wav")
            else:
                self.callback = self.Callback_Cosy(None)
                self.synthesizer = self._new_cosy_synthesizer(self.callback)
                self.synthesizer.call(text)
        except Exception as e:
            # Previously this was only printed, after which _run() waited
            # forever for an end-of-stream sentinel that never came.
            logging.error(f"---dale---20 error {e}")  # cyx
            raise RuntimeError(f"**ERROR**: {e}") from e
        try:
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def init_streaming_call(self, on_data):
        """Prepare an incremental CosyVoice session that sends audio to ``on_data``.

        On failure the error is logged and ``self.synthesizer`` is reset to
        None, which turns streaming_call/end_streaming_call into no-ops.
        """
        try:
            self.callback = self.Callback_Cosy(on_data)
            self.synthesizer = self._new_cosy_synthesizer(self.callback)
        except Exception as e:
            logging.error(f"---dale---30 error {e}")  # cyx
            self.synthesizer = None

    def streaming_call(self, text):
        """Send another text fragment to the open streaming session, if any."""
        if self.synthesizer:
            self.synthesizer.streaming_call(text)

    def end_streaming_call(self):
        """Tell the open streaming session that no more text will follow."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()

    def get_audio_format(self, format: str, sample_rate: int):
        """Map a (container, sample rate) request to a CosyVoice ``AudioFormat``.

        Unsupported combinations are logged and fall back to 16 kHz mono MP3.
        """
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'mp3'): AudioFormat.MP3_16000HZ_MONO_128KBPS,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (16000, 'wav'): AudioFormat.WAV_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT,
        }
        audio_format = format_map.get((sample_rate, format))
        if audio_format is None:
            logging.warning(f"Unsupported audio format {format}@{sample_rate}Hz, falling back to mp3@16000Hz")
            audio_format = AudioFormat.MP3_16000HZ_MONO_128KBPS
        return audio_format
class StreamSessionManager:
    """Manage concurrent streaming TTS sessions backed by QwenTTS (CosyVoice).

    Each session holds a QwenTTS model in streaming mode, a ``task_queue`` of
    text fragments fed by append_text(), and a bounded ``buffer`` queue that
    receives audio chunks (or ``"ERROR:<msg>"`` strings from _generate_audio).
    A per-session daemon thread (_process_tasks) forwards queued text to the
    TTS service and completes the stream after ``gc_tts`` idle seconds.
    """

    # Number of 16-bit samples faded in at the start of a WAV stream.
    FADE_IN_SAMPLES = 1024

    def __init__(self):
        # {session_id: session dict -- see create_session() for the keys}
        self.sessions = {}
        self.lock = threading.Lock()
        # Fixed-size pool; only referenced by the commented-out submission in _process_tasks.
        self.executor = ThreadPoolExecutor(max_workers=30)
        self.gc_interval = 300  # idle seconds before a session is closed (5 minutes)
        # Idle seconds before the TTS stream is completed. The LLM can take a
        # while to start emitting text (raised from 3 s to 10 s on 2025-05-24).
        self.gc_tts = 10

    @classmethod
    def _fade_in_first_chunk(cls, chunk):
        """Fade in the first audio chunk of a WAV stream.

        A leading RIFF header (up to and including the ``data`` sub-chunk size
        field) is passed through untouched so it is not scaled as if it were
        samples. At most FADE_IN_SAMPLES whole 16-bit samples are faded; a
        trailing half-sample byte is left unchanged.
        """
        header_len = 0
        if chunk[:4] == b'RIFF':
            data_pos = chunk.find(b'data', 12)
            if data_pos != -1:
                header_len = data_pos + 8  # 'data' tag + 4-byte size field
        pcm = chunk[header_len:]
        usable = len(pcm) - len(pcm) % 2  # whole 16-bit samples only
        fade = min(cls.FADE_IN_SAMPLES, usable // 2)
        return chunk[:header_len] + audio_fade_in(pcm[:usable], fade) + pcm[usable:]

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', voice='cosyvoice-v1/longxiaochun'):
        """Create a streaming TTS session and start its worker thread.

        Args:
            tts_model: Unused; a QwenTTS model keyed with ALI_KEY is always
                created. Kept for backward compatibility.
            sample_rate: Output sample rate in Hz.
            stream_format: Output container ("mp3", "wav", "pcm").
            voice: Model name passed to QwenTTS; anything from the first '@'
                onward is stripped.

        Returns:
            The new session id (a UUID4 string).
        """
        session_id = str(uuid.uuid4())

        def on_audio_data(chunk):
            # Invoked by the TTS SDK for every audio frame of this session.
            session = self.sessions.get(session_id)
            if session is None:
                # Session already cleaned up; drop late audio instead of crashing.
                return
            first_chunk = not session['tts_chunk_data_valid']
            if session['stream_format'] == 'wav' and first_chunk:
                chunk = self._fade_in_first_chunk(chunk)
            # Blocks when the bounded buffer is full (consumer falling behind).
            session['buffer'].put(chunk)
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1
            # Marks that the TTS backend has started returning audio, so the
            # front end can be notified (added 2025-05-10).
            session['tts_chunk_data_valid'] = True

        with self.lock:
            ali_tts_model = QwenTTS(ALI_KEY, stream_format, sample_rate, voice.split('@')[0])
            self.sessions[session_id] = {
                'tts_model': ali_tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio output queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set once the TTS stream has been completed
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                'tts_chunk_data_valid': False,
                'voice': voice,
            }
        ali_tts_model.init_streaming_call(on_audio_data)
        # One worker thread per session.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue a text fragment for synthesis; unknown sessions are ignored."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Worker loop for one session: forward queued text to the TTS stream.

        Up to 5 queued fragments are concatenated per streaming_call. After
        ``gc_interval`` idle seconds the session is closed; after ``gc_tts``
        idle seconds the TTS stream is completed and ``finished`` is set.
        The loop exits when the session disappears or becomes inactive.
        """
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge up to 5 fragments, waiting at most 100 ms for each one.
                texts = []
                while len(texts) < 5:
                    try:
                        texts.append(session['task_queue'].get(timeout=0.1))
                    except queue.Empty:
                        break
                if texts:
                    session['last_active'] = time.time()  # text arrived: session is active
                    session['tts_model'].streaming_call(''.join(texts))
                    session['last_active'] = time.time()

                idle = time.time() - session['last_active']
                if idle > self.gc_interval:
                    self.close_session(session_id)
                    break
                if idle > self.gc_tts:
                    try:
                        session['tts_model'].end_streaming_call()
                    except Exception as e:
                        logging.error(f"Task processing error: {str(e)}")
                    # Always signal and exit; previously a failing
                    # end_streaming_call() made this loop retry forever.
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        """Synthesize ``text`` in one request and push the audio into the session buffer.

        Not used by the active streaming path (its executor submission in
        _process_tasks is commented out). On failure an ``"ERROR:<msg>"``
        string is put into the buffer instead of audio.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        logging.info(f"转换开始!!! {text}")
        try:
            # QwenTTS.tts() takes only the text; sample rate and format were
            # fixed when the session's model was created.
            for chunk in session['tts_model'].tts(text):
                if session['stream_format'] == 'wav' and first_chunk:
                    chunk = self._fade_in_first_chunk(chunk)
                    first_chunk = False
                session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] += 1
                session['tts_chunk_data_valid'] = True  # backend has started returning audio (2025-05-10)
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def close_session(self, session_id):
        """Mark a session inactive and schedule its removal.

        The worker thread exits on its next iteration once ``active`` is
        False; the session entry itself is deleted by a 1-second timer.
        """
        with self.lock:
            if session_id in self.sessions:
                logging.info(f"--Session {session_id} close_session")
                self.sessions[session_id]['active'] = False
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Remove a session entry if it still exists."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the session dict for ``session_id``, or None if unknown."""
        return self.sessions.get(session_id)
# Module-level StreamSessionManager instance, created when this module is imported.
# NOTE(review): presumably the shared manager used by the API layer -- confirm against callers.
stream_manager_w_stream = StreamSessionManager()
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of 16-bit PCM audio.

    ``audio_data`` is interpreted as native-endian signed 16-bit samples
    (``array('h')``). Sample ``i`` of the first ``fade_length`` samples is
    multiplied by ``i / fade_length`` and truncated toward zero; later samples
    are unchanged.

    Args:
        audio_data: Raw PCM bytes. A trailing odd byte (half a sample) is
            returned unchanged instead of raising.
        fade_length: Number of samples to fade. Values larger than the number
            of samples are clamped so the fade spans the whole buffer; values
            <= 0 leave the audio unchanged.

    Returns:
        The processed audio as ``bytes``.
    """
    usable = len(audio_data) - len(audio_data) % 2  # whole samples only
    samples = array.array('h', audio_data[:usable])
    # Clamp: callers may pass a length longer than this chunk (previously IndexError).
    fade_length = min(fade_length, len(samples))
    for i in range(max(fade_length, 0)):
        samples[i] = int(samples[i] * (i / fade_length))
    return samples.tobytes() + bytes(audio_data[usable:])