Files
ragflow_python/asr-monitor-test/backup.txt

264 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# SECURITY: hardcoded third-party API credential committed to source.
# This key should be revoked/rotated and loaded from an environment
# variable or a secret store instead of living in the repository.
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
class QwenTTS:
    """Streaming text-to-speech wrapper around the DashScope SDK.

    A ``model_name`` of the form ``"cosyvoice-v1/<voice>"`` selects the
    CosyVoice v1 engine with the given voice preset; anything else leaves
    ``is_cosyvoice`` False.  Call :meth:`init_streaming_call` once to open a
    streaming session, then :meth:`streaming_call` per text fragment and
    :meth:`end_streaming_call` to flush.
    """

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longxiaochun"):
        """Store connection/voice settings.

        Args:
            key: DashScope API key (assigned to ``dashscope.api_key`` globally).
            format: Output audio container, e.g. "mp3" / "pcm" / "wav".
            sample_rate: Desired output sample rate in Hz.
            model_name: "<engine>/<voice>" string; see class docstring.
        """
        # dashscope is imported lazily so the module can load without the SDK.
        import dashscope
        print("---begin--init dialog_service QwenTTS--")  # cyx
        self.model_name = model_name
        dashscope.api_key = key
        self.synthesizer = None   # created by init_streaming_call()
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        if '/' in model_name:
            # Split into engine name and voice preset, e.g. "cosyvoice-v1/longxiaochun".
            parts = model_name.split('/', 1)
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]

    def init_streaming_call(self, audio_call_back):
        """Open a streaming synthesis session.

        Args:
            audio_call_back: callable(bytes) invoked for every audio chunk the
                service returns; may be None to drop audio.

        Raises:
            RuntimeError: if session setup fails.
        """
        from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
        # CosyVoice streaming uses the tts_v2 API surface.
        from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat
        from dashscope.audio.tts import SpeechSynthesisResult
        from collections import deque
        print(f"--QwenTTS--tts_stream begin-- {self.is_cosyvoice} {self.voice}")  # cyx

        class Callback_v2(ResultCallback):
            """Receives service events; forwards audio to audio_call_back."""

            def __init__(self) -> None:
                self.dque = deque()

            def _run(self):
                # NOTE(review): busy-wait generator draining the deque
                # (time.sleep(0) spins the CPU).  Appears to be dead code now
                # that on_data() forwards chunks directly — confirm and remove.
                while True:
                    if not self.dque:
                        time.sleep(0)
                        continue
                    val = self.dque.popleft()
                    if val:
                        yield val
                    else:
                        break

            def on_open(self):
                logging.info("Qwen tts open")

            def on_complete(self):
                # None is the end-of-stream sentinel consumed by _run().
                self.dque.append(None)

            def on_error(self, response: SpeechSynthesisResponse):
                print("Qwen tts error", str(response))
                raise RuntimeError(str(response))

            def on_close(self):
                logging.info("Qwen tts close")

            def on_event(self, message):
                # Raw protocol events are ignored; audio arrives via on_data().
                pass

            def on_data(self, data: bytes) -> None:
                # CosyVoice delivers synthesized audio here chunk by chunk.
                if len(data) > 0:
                    if audio_call_back:
                        audio_call_back(data)

        try:
            self.callback = Callback_v2()
            format = self.get_audio_format(self.format, self.sample_rate)
            try:
                self.synthesizer = SpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,
                    callback=self.callback,
                    format=format
                )
            except Exception as e:
                print(f"---dale---20 error {e}")  # cyx
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def get_audio_format(self, format: str, sample_rate: int):
        """Map (sample_rate, format) to a DashScope ``AudioFormat`` constant.

        Unknown combinations fall back to MP3_16000HZ_MONO_128KBPS.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        logging.info(f"QwenTTS--get_audio_format-- {format} {sample_rate}")
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            # BUG FIX: these keys were 48800, so a 48000 Hz request silently
            # fell through to the 16 kHz MP3 default.
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)

    def streaming_call(self, text):
        """Feed one text fragment into the open streaming session (no-op if none)."""
        if self.synthesizer:
            self.synthesizer.streaming_call(text)

    def end_streaming_call(self):
        """Signal that no more text will be sent, flushing remaining audio."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()
class StreamSessionManager1:
    """Manages concurrent streaming TTS sessions.

    Each session owns a text task queue and an audio output buffer; a
    dedicated daemon thread per session merges queued text and submits it to
    a shared thread pool for synthesis.
    """

    def __init__(self):
        # {session_id: {'tts_model': obj, 'buffer': Queue, 'task_queue': Queue, ...}}
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size worker pool
        self.gc_interval = 300  # reap idle sessions after 5 minutes (300 s)
        # Seconds of silence before the TTS stream is finalized; the LLM may
        # take a while to start emitting text (raised 3s -> 10s, 2025-05-24).
        self.gc_tts = 10
        self.inited = False
        self.tts_model = None

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        """Register a new session and start its background worker thread.

        Returns:
            The new session id (uuid4 string).
        """
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio chunk queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set when the stream is finalized
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False
            }
        # Per-session task-processing thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue a text fragment for synthesis (non-blocking; drops when full)."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Per-session worker loop: merge queued text and drive TTS synthesis."""
        session = self.sessions.get(session_id)
        # BUG FIX: first_chunk was previously undefined in the callback's
        # enclosing scope, so the 'wav' branch raised UnboundLocalError on the
        # first audio chunk.  Bind it here and declare nonlocal below.
        first_chunk = True

        def audio_call_back(chunk):
            nonlocal first_chunk
            logging.info(f"audio_call_back {len(chunk)}")
            if session['stream_format'] == 'wav':
                if first_chunk:
                    # Fade in the first chunk to avoid an audible click.
                    chunk_len = len(chunk)
                    fade_len = 1024 if chunk_len > 2048 else chunk_len
                    session['buffer'].put(audio_fade_in(chunk, fade_len))
                    first_chunk = False
                else:
                    session['buffer'].put(chunk)
            else:
                session['buffer'].put(chunk)
            session['last_active'] = time.time()
            session['audio_chunk_count'] = session['audio_chunk_count'] + 1
            if session['tts_chunk_data_valid'] is False:
                # First audio has come back from TTS; frontend can be notified.
                session['tts_chunk_data_valid'] = True

        while True:
            if not session or not session['active']:
                break
            if not self.inited:
                # NOTE(review): shared lazily-created TTS model guarded only by
                # a boolean flag — concurrent sessions could race here and the
                # first session's format/rate wins.  Confirm this is intended.
                self.inited = True
                self.tts_model = QwenTTS(ALI_KEY, session['stream_format'], session['sample_rate'])
                self.tts_model.init_streaming_call(audio_call_back)
            try:
                # Merge up to 5 queued fragments (waiting <=0.1 s each) so one
                # request covers several text blocks.
                texts = []
                while len(texts) < 5:
                    try:
                        texts.append(session['task_queue'].get(timeout=0.1))
                    except queue.Empty:
                        break
                if texts:
                    session['last_active'] = time.time()  # reset idle timer
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts)  # merged to reduce request count
                    )
                    future.result()  # wait for the synthesis submission to finish
                    session['last_active'] = time.time()
                # Idle-session reaping.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                # No text for gc_tts seconds: finalize the TTS stream.
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    if self.tts_model:
                        self.tts_model.end_streaming_call()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        """Send text to the TTS model (executed on the thread pool)."""
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            self.tts_model.streaming_call(text)
        except Exception as e:
            # Surface the failure to the audio consumer via the buffer.
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--streaming_call--error {str(e)}")

    def close_session(self, session_id):
        """Mark a session inactive and schedule its cleanup."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                # Clean up 1 second later so in-flight readers can drain
                # (the old comment said 2 s but the Timer was always 1 s).
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Remove the session entry; idempotent."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        # NOTE(review): unlocked read; single dict lookups are atomic under
        # CPython's GIL, but confirm if this must be portable.
        return self.sessions.get(session_id)
"""