# NOTE: stray file-listing metadata ("367 lines / 15 KiB / Python") removed —
# it was extraction residue, not part of the Python source.
import asyncio,logging
|
||
from collections import deque
|
||
import threading, time,queue,uuid,time,array
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
ALI_KEY = "sk-a47a3fb5f4a94f66bbaf713779101c75"
|
||
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
|
||
from dashscope.audio.tts import (
|
||
ResultCallback as TTSResultCallback,
|
||
SpeechSynthesizer as TTSSpeechSynthesizer,
|
||
SpeechSynthesisResult as TTSSpeechSynthesisResult,
|
||
)
|
||
# cyx 2025 01 19 测试cosyvoice 使用tts_v2 版本
|
||
from dashscope.audio.tts_v2 import (
|
||
ResultCallback as CosyResultCallback,
|
||
SpeechSynthesizer as CosySpeechSynthesizer,
|
||
AudioFormat,
|
||
)
|
||
|
||
class QwenTTS:
    """Streaming TTS wrapper around DashScope's legacy (tts) and CosyVoice (tts_v2) APIs.

    A model name of the form "cosyvoice-v1/<voice>" selects the CosyVoice
    tts_v2 path with the named voice; any other model name goes through the
    legacy SpeechSynthesizer.
    """

    def __init__(self, key, format="mp3", sample_rate=44100, model_name="cosyvoice-v1/longyuan"):
        """Configure the synthesizer.

        Args:
            key: DashScope API key (set globally on the dashscope module).
            format: output container — "mp3", "pcm" or "wav".
            sample_rate: output sample rate in Hz.
            model_name: legacy model name, or "cosyvoice-v1/<voice>".
        """
        import dashscope
        import ssl
        logging.info(f"---QwenTTS Construtor-- {format} {sample_rate} {model_name}")  # cyx
        self.model_name = model_name
        dashscope.api_key = key
        # SECURITY: disables TLS certificate verification process-wide.
        # NOTE(review): acceptable only as a temporary workaround — confirm
        # whether this is still required.
        ssl._create_default_https_context = ssl._create_unverified_context
        self.synthesizer = None
        self.callback = None
        self.is_cosyvoice = False
        self.voice = ""
        self.format = format
        self.sample_rate = sample_rate
        if '/' in model_name:
            # "cosyvoice-v1/<voice>" → CosyVoice with the named voice.
            parts = model_name.split('/', 1)
            if parts[0] == 'cosyvoice-v1':
                self.is_cosyvoice = True
                self.voice = parts[1]

    class Callback(TTSResultCallback):
        """Legacy-API callback: buffers audio frames for _run() to drain."""

        def __init__(self) -> None:
            self.dque = deque()

        def _run(self):
            """Generator draining the deque; the None sentinel (see on_complete) ends it."""
            while True:
                if not self.dque:
                    time.sleep(0)  # busy-poll, yielding the GIL between checks
                    continue
                val = self.dque.popleft()
                if val:
                    yield val
                else:
                    break

        def on_open(self):
            pass

        def on_complete(self):
            # Sentinel telling _run() that synthesis has finished.
            self.dque.append(None)

        def on_error(self, response: SpeechSynthesisResponse):
            print("Qwen tts error", str(response))
            raise RuntimeError(str(response))

        def on_close(self):
            pass

        def on_event(self, result: TTSSpeechSynthesisResult):
            if result.get_audio_frame() is not None:
                self.dque.append(result.get_audio_frame())

    # --------------------------

    class Callback_Cosy(CosyResultCallback):
        """CosyVoice (tts_v2) callback.

        Audio arrives via on_data(): it is either forwarded to the
        caller-supplied on_audio_data hook (streaming mode) or buffered
        for _run() (one-shot mode, hook is None).
        """

        def __init__(self, on_audio_data) -> None:
            self.dque = deque()
            self.on_audio_data = on_audio_data

        def _run(self):
            """Generator draining the deque; the None sentinel (see on_complete) ends it."""
            while True:
                if not self.dque:
                    time.sleep(0)  # busy-poll, yielding the GIL between checks
                    continue
                val = self.dque.popleft()
                if val:
                    yield val
                else:
                    break

        def on_open(self):
            logging.info("---Qwen tts on_open---")

        def on_complete(self):
            # Sentinel telling _run() that synthesis has finished.
            self.dque.append(None)

        def on_error(self, response: SpeechSynthesisResponse):
            print("Qwen tts error", str(response))
            raise RuntimeError(str(response))

        def on_close(self):
            logging.info("---Qwen tts on_close---")

        def on_event(self, message):
            # Non-audio protocol messages are ignored (audio comes via on_data).
            pass

        # CosyVoice delivers raw audio bytes here.
        def on_data(self, data: bytes) -> None:
            if len(data) > 0:
                if self.on_audio_data:
                    self.on_audio_data(data)
                else:
                    self.dque.append(data)

    # --------------------------

    def tts(self, text):
        """One-shot synthesis: yields audio chunks for *text* as they arrive.

        Raises RuntimeError if draining the callback fails; errors while
        starting the request are only printed (best-effort), matching the
        original behavior.
        """
        print(f"--QwenTTS--tts_stream begin-- {text} {self.is_cosyvoice} {self.voice}")  # cyx
        try:
            if self.is_cosyvoice is False:
                self.callback = self.Callback()
                TTSSpeechSynthesizer.call(model=self.model_name,
                                          text=text,
                                          callback=self.callback,
                                          format="wav")  # format="mp3")
            else:
                self.callback = self.Callback_Cosy(None)
                format = self.get_audio_format(self.format, self.sample_rate)
                self.synthesizer = CosySpeechSynthesizer(
                    model='cosyvoice-v1',
                    voice=self.voice,
                    callback=self.callback,
                    format=format
                )
                self.synthesizer.call(text)
        except Exception as e:
            print(f"---dale---20 error {e}")  # cyx
        # -----------------------------------
        try:
            for data in self.callback._run():
                yield data
        except Exception as e:
            raise RuntimeError(f"**ERROR**: {e}")

    def init_streaming_call(self, on_data):
        """Prepare a CosyVoice streaming session; audio is pushed to *on_data*."""
        try:
            self.callback = self.Callback_Cosy(on_data)
            format = self.get_audio_format(self.format, self.sample_rate)
            self.synthesizer = CosySpeechSynthesizer(
                model='cosyvoice-v1',
                voice=self.voice,
                callback=self.callback,
                format=format
            )
        except Exception as e:
            print(f"---dale---30 error {e}")  # cyx
        # -----------------------------------

    def streaming_call(self, text):
        """Feed a text fragment into the open streaming session."""
        if self.synthesizer:
            self.synthesizer.streaming_call(text)

    def end_streaming_call(self):
        """Flush and close the streaming session."""
        if self.synthesizer:
            self.synthesizer.streaming_complete()

    def get_audio_format(self, format: str, sample_rate: int):
        """Map (sample_rate, format) to a dashscope AudioFormat enum value.

        Falls back to MP3 16 kHz / 128 kbps for unknown combinations.
        """
        from dashscope.audio.tts_v2 import AudioFormat
        format_map = {
            (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
            (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
            (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
            (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
            (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
            (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
            (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
            (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
            (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
            (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
            (48000, 'mp3'): AudioFormat.MP3_48000HZ_MONO_256KBPS,
            (48000, 'pcm'): AudioFormat.PCM_48000HZ_MONO_16BIT,
            (48000, 'wav'): AudioFormat.WAV_48000HZ_MONO_16BIT,
        }
        return format_map.get((sample_rate, format), AudioFormat.MP3_16000HZ_MONO_128KBPS)
class StreamSessionManager:
    """Manages concurrent TTS streaming sessions keyed by UUID.

    Each session owns a QwenTTS model, a thread-safe audio buffer queue and a
    text task queue serviced by a dedicated daemon worker thread.
    """

    def __init__(self):
        self.sessions = {}  # {session_id: {'tts_model': obj, 'buffer': Queue, 'task_queue': Queue, ...}}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size pool
        self.gc_interval = 300  # reap idle sessions after 5 minutes (5 x 60 s)
        # Seconds of silence before finishing the TTS stream; the LLM may take
        # a while to start emitting text (raised 3s → 10s on 2025-05-24).
        self.gc_tts = 10

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3', voice='cosyvoice-v1/longxiaochun'):
        """Create a session, start its worker thread and return the session id.

        NOTE(review): the *tts_model* argument is currently ignored — a fresh
        QwenTTS is always built from ALI_KEY; confirm whether that is intended.
        """
        session_id = str(uuid.uuid4())

        def on_audio_data(chunk):
            # Invoked from the TTS SDK thread for every audio chunk.
            session = self.sessions.get(session_id)
            if session is None:
                return  # session already reaped; drop late audio safely
            first_chunk = not session['tts_chunk_data_valid']
            if session['stream_format'] == 'wav':
                if first_chunk:
                    # Fade in the first chunk to avoid an audible click.
                    # audio_fade_in counts 16-bit SAMPLES, so use half the
                    # byte length, capped at 1024 samples (the previous code
                    # passed a byte length, which crashed for chunks <= 2048 B).
                    fade_samples = min(1024, len(chunk) // 2)
                    session['buffer'].put(audio_fade_in(chunk, fade_samples))
                else:
                    session['buffer'].put(chunk)
            else:
                session['buffer'].put(chunk)
            session['last_active'] = time.time()
            session['audio_chunk_count'] = session['audio_chunk_count'] + 1
            if session['tts_chunk_data_valid'] is False:
                # 2025-05-10: marks that the TTS backend has started replying,
                # so the frontend can be notified.
                session['tts_chunk_data_valid'] = True

        with self.lock:
            ali_tts_model = QwenTTS(ALI_KEY, stream_format, sample_rate, voice.split('@')[0])
            self.sessions[session_id] = {
                'tts_model': ali_tts_model,  # tts_model parameter is unused (see docstring)
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set when the stream completes
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                'voice': voice,
            }
            self.sessions[session_id]['tts_model'].init_streaming_call(on_audio_data)
        # Start the per-session worker thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue *text* for synthesis in the given session (non-blocking)."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Per-session worker loop: batch queued text and feed the TTS stream."""
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge up to 5 text chunks; each get() waits at most 100 ms.
                texts = []
                while len(texts) < 5:
                    try:
                        text = session['task_queue'].get(timeout=0.1)
                        texts.append(text)
                    except queue.Empty:
                        break

                if texts:
                    session['last_active'] = time.time()  # reset idle timer while there is work
                    session['tts_model'].streaming_call(''.join(texts))
                    session['last_active'] = time.time()
                # Idle-session handling.
                # NOTE(review): gc_tts (10s) < gc_interval (300s), so the
                # gc_interval branch below is effectively unreachable — the
                # stream is finished by the gc_tts branch first. Confirm the
                # intended order.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    # No new text for gc_tts seconds: flush and finish the stream.
                    session['tts_model'].end_streaming_call()
                    session['finished'].set()
                    break

            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio(self, session_id, text):
        """Generate audio via the one-shot tts() path (thread-pool execution).

        NOTE(review): currently unreferenced — the executor.submit() call in
        _process_tasks is commented out — and QwenTTS.tts() only accepts
        *text*, so the 3-argument call below would raise if this path were
        re-enabled. Confirm before reviving it.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        logging.info(f"转换开始!!! {text}")
        try:
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav':
                    if first_chunk:
                        # Fade in the first chunk (length in 16-bit samples,
                        # capped at 1024 — see on_audio_data for rationale).
                        fade_samples = min(1024, len(chunk) // 2)
                        session['buffer'].put(audio_fade_in(chunk, fade_samples))
                        first_chunk = False
                    else:
                        session['buffer'].put(chunk)
                else:
                    session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # 2025-05-10: backend has replied; frontend may be notified.
                    session['tts_chunk_data_valid'] = True
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def close_session(self, session_id):
        """Mark a session inactive and schedule its cleanup."""
        with self.lock:
            if session_id in self.sessions:
                logging.info(f"--Session {session_id} close_session")
                self.sessions[session_id]['active'] = False
                # Clean up after a 1-second grace period so in-flight
                # callbacks can drain (the old comment said 2s; the code
                # has always used 1s).
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        """Remove the session entry (invoked by the close_session timer)."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the session dict for *session_id*, or None if unknown."""
        return self.sessions.get(session_id)
# Module-level singleton: all web-stream handlers share this session manager.
stream_manager_w_stream = StreamSessionManager()
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of a 16-bit mono PCM buffer.

    Args:
        audio_data: raw PCM bytes, assumed 16-bit signed little-endian mono
            (must be an even number of bytes).
        fade_length: number of SAMPLES to ramp from silence to full volume.
            Clamped to the buffer length — previously a fade_length larger
            than the sample count (e.g. a byte length passed by mistake)
            raised IndexError.

    Returns:
        The faded buffer as bytes; samples past fade_length are unchanged.
    """
    # Reinterpret the bytes as signed 16-bit samples.
    samples = array.array('h', audio_data)

    # Ramp the first fade_length samples linearly; the slope is still
    # 1/fade_length even when clamped, so partial buffers fade consistently.
    for i in range(min(fade_length, len(samples))):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)

    # Back to raw bytes.
    return samples.tobytes()