# Backup of ragflow_python/asr-monitor-test/backup.txt (264 lines, 11 KiB, plain text).
# The triple-quoted string below holds a copy of the QwenTTS and StreamSessionManager1 code.

"""
# DashScope (Alibaba Cloud) API key. It is read from the environment instead of
# being hard-coded: a key committed to source is leaked to everyone with access
# to the repository, so the previously embedded key should be revoked/rotated.
import os
ALI_KEY = os.environ.get("DASHSCOPE_API_KEY", "")
class QwenTTS:
    # Streaming text-to-speech wrapper around DashScope CosyVoice
    # (dashscope.audio.tts_v2.SpeechSynthesizer). Synthesized audio chunks are
    # handed to a caller-supplied callback as they arrive.

    # (sample_rate, format) -> name of the dashscope tts_v2 AudioFormat member.
    # Member names (not members) are stored so this table needs no dashscope import.
    _AUDIO_FORMAT_NAMES = {
        (8000, 'mp3'): 'MP3_8000HZ_MONO_128KBPS',
        (8000, 'pcm'): 'PCM_8000HZ_MONO_16BIT',
        (8000, 'wav'): 'WAV_8000HZ_MONO_16BIT',
        (16000, 'pcm'): 'PCM_16000HZ_MONO_16BIT',
        (22050, 'mp3'): 'MP3_22050HZ_MONO_256KBPS',
        (22050, 'pcm'): 'PCM_22050HZ_MONO_16BIT',
        (22050, 'wav'): 'WAV_22050HZ_MONO_16BIT',
        (44100, 'mp3'): 'MP3_44100HZ_MONO_256KBPS',
        (44100, 'pcm'): 'PCM_44100HZ_MONO_16BIT',
        (44100, 'wav'): 'WAV_44100HZ_MONO_16BIT',
        # These were keyed as 48800 (typo), so a 48 kHz request never matched
        # and silently fell back to the default format.
        (48000, 'mp3'): 'MP3_48000HZ_MONO_256KBPS',
        (48000, 'pcm'): 'PCM_48000HZ_MONO_16BIT',
        (48000, 'wav'): 'WAV_48000HZ_MONO_16BIT',
    }
    # Used for any (sample_rate, format) pair not listed above.
    _DEFAULT_AUDIO_FORMAT_NAME = 'MP3_16000HZ_MONO_128KBPS'

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun"):
        # key: DashScope API key; assigned to the dashscope module attribute,
        #   so it applies to every dashscope call in the process.
        # format / sample_rate: requested output encoding (see get_audio_format).
        # model_name: "<model>/<voice>". Only the "cosyvoice-v1" prefix is
        #   recognized; the part after '/' then becomes the voice.
        import dashscope
        print("---begin--init dialog_service QwenTTS--")  # cyx
        self.model_name = model_name
        dashscope.api_key = key
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        if '/' in model_name:
            model, voice = model_name.split('/', 1)
            if model == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = voice

    def init_streaming_call(self, audio_call_back):
        # Create a CosyVoice streaming synthesizer whose audio chunks are passed
        # to audio_call_back(data: bytes).
        # Best-effort: if setup fails the error is logged and self.synthesizer
        # stays None, which makes streaming_call/end_streaming_call no-ops.
        # (2025-01-19: switched to the tts_v2 API to test CosyVoice.)
        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer
        from collections import deque
        print(f"--QwenTTS--tts_stream begin-- {self.is_cosyvoice} {self.voice}")  # cyx

        class Callback_v2(ResultCallback):
            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                # Generator draining dque until the None sentinel queued by
                # on_complete. NOTE(review): nothing here appends audio to dque
                # (on_data forwards chunks directly) and this loop busy-waits;
                # kept only for compatibility.
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break

            def on_open(self):
                logging.info("Qwen tts open")

            def on_complete(self):
                self.dque.append(None)

            def on_error(self, response: SpeechSynthesisResponse):
                print("Qwen tts error", str(response))
                raise RuntimeError(str(response))

            def on_close(self):
                logging.info("Qwen tts close")

            def on_event(self, message):
                pass

            # CosyVoice delivers synthesized audio through on_data.
            def on_data(self, data: bytes) -> None:
                if len(data) > 0 and audio_call_back:
                    audio_call_back(data)

        # Reset first so a failed re-initialization cannot leave an older
        # synthesizer (bound to an older callback) in place.
        self.synthesizer = None
        try:
            self.callback = Callback_v2()
            audio_format = self.get_audio_format(self.format, self.sample_rate)
            self.synthesizer = SpeechSynthesizer(
                model='cosyvoice-v1',
                voice=self.voice,
                callback=self.callback,
                format=audio_format,
            )
        except Exception as e:
            # Previously a debug print followed by an unreachable second
            # `except` clause; now logged with traceback, still best-effort.
            logging.exception(f"QwenTTS init_streaming_call failed: {e}")

    @classmethod
    def _audio_format_name(cls, fmt, sample_rate):
        # Return the AudioFormat member name for (fmt, sample_rate), or the
        # MP3 16 kHz default when the combination is not in the table.
        return cls._AUDIO_FORMAT_NAMES.get((sample_rate, fmt), cls._DEFAULT_AUDIO_FORMAT_NAME)

    def get_audio_format(self, format: str, sample_rate: int):
        # Map a format string ('mp3' / 'pcm' / 'wav') and sample rate to the
        # dashscope tts_v2 AudioFormat member, falling back to MP3 16 kHz.
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
        return getattr(AudioFormat, self._audio_format_name(format, sample_rate))

    def streaming_call(self, text):
        # Send a piece of text to the open stream; no-op when there is no synthesizer.
        if self.synthesizer:
            self.synthesizer.streaming_call(text)

    def end_streaming_call(self):
        # Signal end of input to the stream; no-op when there is no synthesizer.
        if self.synthesizer:
            self.synthesizer.streaming_complete()
class StreamSessionManager1:
    # Manages streaming TTS sessions: text is queued per session, synthesized
    # by that session's own QwenTTS stream, and the resulting audio chunks are
    # put on the session's 'buffer' queue for a consumer to read.

    def __init__(self):
        # {session_id: session dict}; see create_session for the keys.
        self.sessions = {}
        self.lock = threading.Lock()
        # Fixed-size thread pool that runs _generate_audio.
        self.executor = ThreadPoolExecutor(max_workers=30)
        # Idle seconds after which _process_tasks closes the session (5 minutes).
        self.gc_interval = 300
        # Idle seconds after which the TTS stream is finished. The LLM can take
        # a while to start emitting text (raised from 3 s to 10 s on 2025-05-24).
        self.gc_tts = 10
        # Kept for backward compatibility. They now point at the most recently
        # created per-session stream instead of one stream shared by all sessions.
        self.inited = False
        self.tts_model = None

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        # Register a new session, start its worker thread and return its id.
        # tts_model is stored as-is under 'tts_model'; synthesis uses the
        # QwenTTS stream the worker creates and stores under 'tts_stream'.
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio output queue
                'task_queue': queue.Queue(),  # pending text chunks
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set once the stream has been finished
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                # Becomes True once the TTS backend has returned audio, so the
                # frontend can be notified (added 2025-05-10).
                "tts_chunk_data_valid": False,
                'tts_stream': None,  # per-session QwenTTS, set by _process_tasks
            }
        # Start the per-session task-processing thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        # Queue a text chunk for synthesis; unknown session ids are ignored.
        with self.lock:
            session = self.sessions.get(session_id)
        if not session:
            return
        try:
            # Non-blocking put (the task queue is unbounded, so Full is not expected).
            session['task_queue'].put(text, block=False)
        except queue.Full:
            logging.warning(f"Session {session_id} task queue full")

    def _make_audio_callback(self, session, fade_in=None):
        # Build the callback QwenTTS invokes with each synthesized audio chunk.
        # fade_in: optional replacement for the module-level audio_fade_in
        #   (used only for the first chunk of 'wav' streams).
        # Previously `first_chunk` was assigned inside the callback without
        # being initialized in scope, so every 'wav' stream raised
        # UnboundLocalError on its first chunk.
        first_chunk = True

        def audio_call_back(chunk):
            nonlocal first_chunk
            logging.info(f"audio_call_back {len(chunk)}")
            if session['stream_format'] == 'wav' and first_chunk:
                fade = fade_in if fade_in is not None else audio_fade_in
                chunk_len = len(chunk)
                # Fade in the start of the stream; long chunks use a fade length of 1024.
                session['buffer'].put(fade(chunk, 1024 if chunk_len > 2048 else chunk_len))
                first_chunk = False
            else:
                session['buffer'].put(chunk)
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1
            if session['tts_chunk_data_valid'] is False:
                session['tts_chunk_data_valid'] = True

        return audio_call_back

    def _process_tasks(self, session_id):
        # Worker loop (one thread per session): batch queued text, synthesize
        # it, and finish the stream once the session has been idle for gc_tts s.
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            # One stream per session. Previously a single manager-wide model was
            # created for the first session only, so later sessions' audio went
            # to the first session's buffer and reused an already-completed stream.
            tts = QwenTTS(ALI_KEY, session['stream_format'], session['sample_rate'])
            tts.init_streaming_call(self._make_audio_callback(session))
        except Exception as e:
            logging.error(f"Task processing error: {str(e)}")
            session['finished'].set()  # don't leave consumers waiting on a dead session
            return
        session['tts_stream'] = tts
        self.tts_model = tts
        self.inited = True

        while session['active']:
            try:
                # Merge up to 5 queued text chunks, waiting at most 0.1 s for each.
                texts = []
                while len(texts) < 5:
                    try:
                        texts.append(session['task_queue'].get(timeout=0.1))
                    except queue.Empty:
                        break
                if texts:
                    session['last_active'] = time.time()  # processing text counts as activity
                    # Joining chunks reduces the number of synthesis requests.
                    future = self.executor.submit(self._generate_audio, session_id, ' '.join(texts))
                    future.result()  # wait for this batch before taking the next
                    session['last_active'] = time.time()

                idle = time.time() - session['last_active']
                # NOTE(review): with the defaults (gc_tts=10 < gc_interval=300)
                # the gc_tts branch below ends the loop long before this one fires.
                if idle > self.gc_interval:
                    self.close_session(session_id)
                    break
                if idle > self.gc_tts:
                    session['finished'].set()
                    try:
                        tts.end_streaming_call()
                    except Exception as e:
                        # Break regardless; retrying here used to spin forever.
                        logging.error(f"Task processing error: {str(e)}")
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        # Synthesize one batch of text (runs on the thread pool). Errors are
        # reported to the consumer as an "ERROR:..." string in the buffer.
        session = self.sessions.get(session_id)
        if not session:
            return
        tts = session.get('tts_stream')
        if tts is None:
            tts = self.tts_model
        try:
            tts.streaming_call(text)
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--streaming_call--error {str(e)}")

    def close_session(self, session_id):
        # Mark the session inactive and remove it after a 1-second delay.
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        # Drop the session from the registry (called by close_session's timer).
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        # Return the session dict, or None for an unknown id.
        return self.sessions.get(session_id)
"""