Generate TTS audio in the background while the dialogue text is being generated; add selectable narration voices; add a museum overview endpoint

This commit is contained in:
qcloud
2025-02-23 09:52:30 +08:00
parent c88312a914
commit a5e83f4d3b
7 changed files with 653 additions and 224 deletions
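For reference, a minimal client-side sketch of the synchronous TTS flow added by this commit. The base URL, API prefix, token and chat id below are illustrative assumptions, not part of this change; the endpoint paths match the routes added below.

import requests

BASE = "http://localhost:9380/api/v1"                 # assumed host and API prefix
HEADERS = {"Authorization": "Bearer YOUR_API_TOKEN"}  # token checked by @token_required
CHAT_ID = "your-chat-id"                              # assumed existing dialog id

# 1. Ask the server to synthesize audio for a piece of text and cache it in memory.
resp = requests.post(f"{BASE}/chats/{CHAT_ID}/tts", headers=HEADERS,
                     json={"text": "欢迎参观博物馆", "delay_gen_audio": False})
info = resp.json()

# 2. Fetch the cached MP3 stream once via the returned URL.
audio = requests.get(f"{BASE}{info['tts_url']}", stream=True)
with open("answer.mp3", "wb") as f:
    for chunk in audio.iter_content(1024):
        f.write(chunk)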

View File

@@ -13,17 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from flask import request, Response, jsonify
from api import settings
from api.db import LLMType
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.dialog_service import DialogService, stream_manager
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.db.services.brief_service import MesumOverviewService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
import logging
import base64
import io
import json
import re
import queue
import time
import uuid
from threading import Lock, Thread
from zhipuai import ZhipuAI
# Models the user has already added (cyx 2025-01-26)
@manager.route('/get_llms', methods=['GET'])
@@ -48,4 +55,352 @@ def my_llms(tenant_id):
})
return get_result(data=res)
except Exception as e:
return get_error_data_result(message=f"Get LLMS error {e}")
return get_error_data_result(message=f"Get LLMS error {e}")
main_antiquity = "浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"  # default candidate antiques used for photo/text matching
@manager.route('/photo/recongeText', methods=['POST'])
@token_required
def upload_file(tenant_id):
if 'file' not in request.files:
return jsonify({'error': 'No file part'}), 400
file = request.files['file']
if file.filename == '':
return jsonify({'error': 'No selected file'}), 400
if file and allowed_file(file.filename):
file_size = request.content_length
img_base = base64.b64encode(file.read()).decode('utf-8')
req_antique = request.form.get('antique',None)
if req_antique is None:
req_antique = main_antiquity
logging.info(f"recevie photo file {file.filename} {file_size} 识别中....")
client = ZhipuAI(api_key="5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3") # 填写您自己的APIKey
response = client.chat.completions.create(
model="glm-4v-plus", # 填写需要调用的模型名称
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": img_base
}
},
{
"type": "text",
"text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
f"你的回答有2个结果第一个结果是是从文字进行分析出匹配文物候选文物只能如下:{req_antique},"
f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
f",第二个结果是原始识别的所有文字"
"2个结果输出以{ }的json格式给出匹配文物的键值为antique如果有多个请加序号如:antique1,antique2,"
f"原始数据的键值为text输出是1个完整的JSON数据不要有多余的前置和后置内容确保前端能正确解析出JSON数据")
}
]
}
]
)
message = response.choices[0].message
logging.info(message.content)
return jsonify({'message': 'File uploaded successfully','text':message.content }), 200
def allowed_file(filename):
return '.' in filename and \
filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}
#get_all
@manager.route('/mesum/list', methods=['GET'])
@token_required
def mesum_list(tenant_id):
# request.args.get("id") 通过request.args.get 获取GET 方法传入的参数
# model_type = request.args.get("type")
try:
res = []
overviews=MesumOverviewService.get_all()
for o in overviews:
res.append(o.to_dict())
return get_result(data=res)
except Exception as e:
return get_error_data_result(message=f"Get LLMS error {e}")
@manager.route('/mesum/set_antique', methods=['POST'])
@token_required
def mesum_set_antique(tenant_id):
global main_antiquity
# request.args.get("id") 通过request.args.get 获取GET 方法传入的参数
req_data = request.json
req_data_antique=req_data.get('antique',None)
try:
if req_data_antique:
main_antiquity = req_data_antique
print(main_antiquity)
return get_result({'statusCode':200,'code':0,'message': 'antique set successfully'})
except Exception as e:
return get_error_data_result(message=f"Get LLMS error {e}")
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes
# Mapping from full-width characters to their half-width equivalents
def fullwidth_to_halfwidth(s):
    full_to_half_map = {
        '！': '!', '＂': '"', '＃': '#', '＄': '$', '％': '%', '＆': '&', '＇': "'",
        '（': '(', '）': ')', '＊': '*', '＋': '+', '，': ',', '－': '-', '．': '.',
        '／': '/', '：': ':', '；': ';', '＜': '<', '＝': '=', '＞': '>', '？': '?',
        '＠': '@', '［': '[', '＼': '\\', '］': ']', '＾': '^', '＿': '_', '｀': '`',
        '｛': '{', '｜': '|', '｝': '}', '～': '~', '“': '', '”': '', '‘': '',
        '’': '', '、': ',', '。': '.', '—': '-', '…': '.', '《': '', '》': '',
        '【': '', '】': '', '∶': ':'
    }
    return ''.join(full_to_half_map.get(char, char) for char in s)
def split_text_at_punctuation(text, chunk_size=100):
    # use a regular expression to find all punctuation and special characters
    punctuation_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    tokens = re.split(punctuation_pattern, text)
    # drop empty strings
    tokens = [token for token in tokens if token]
    # final list of text chunks
    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            # if adding the current token keeps the chunk within chunk_size, append it
            current_chunk += (token + ' ')
        else:
            # otherwise finish the current chunk and start a new one
            chunks.append(current_chunk.strip())
            current_chunk = token + ' '
    # append the last chunk, if anything remains
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    # remove Markdown headings
    text = re.sub(r'#\s*[^#]+', '', markdown_text)
    # remove inline code
    text = re.sub(r'`[^`]+`', '', text)
    # remove fenced code blocks
    text = re.sub(r'```[\s\S]*?```', '', text)
    # remove bold and italic markup
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # remove links
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # remove images
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # normalize punctuation
    # text = re.sub(r'[^\w\s]', '', text)
    text = fullwidth_to_halfwidth(text)
    # collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def clean_audio_cache():
"""定时清理过期缓存"""
with cache_lock:
now = time.time()
expired_keys = [
k for k, v in audio_text_cache.items()
if now - v['created_at'] > CACHE_EXPIRE_SECONDS
]
for k in expired_keys:
entry = audio_text_cache.pop(k, None)
if entry and entry.get('audio_stream'):
entry['audio_stream'].close()
def start_background_cleaner():
"""启动后台清理线程"""
def cleaner_loop():
while True:
            time.sleep(180)  # clean up every 3 minutes
clean_audio_cache()
cleaner_thread = Thread(target=cleaner_loop, daemon=True)
cleaner_thread.start()
# start the cleanup thread when the application starts
start_background_cleaner()
@manager.route('/tts_stream/<session_id>')
def tts_stream(session_id):
def generate():
retry_count = 0
session = None
        count = 0
try:
while retry_count < 1:
session = stream_manager.sessions.get(session_id)
if not session or not session['active']:
break
try:
                    chunk = session['buffer'].get(timeout=5)  # 5-second timeout
count = count + 1
if isinstance(chunk, str) and chunk.startswith("ERROR"):
logging.info("---tts stream error!!!!")
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
break
yield chunk
                    retry_count = 0  # reset the retry counter once data is received
except queue.Empty:
retry_count += 1
                    yield b''  # keep the connection alive
finally:
            # make sure the session is closed after the stream ends
            if session:
                # delay closing the session so that all buffered data can be sent
                time.sleep(5)  # wait 5 seconds to let the stream finish
stream_manager.close_session(session_id)
logging.info(f"Session {session_id} closed.")
resp = Response(generate(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
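# Hypothetical client usage (illustration only, not part of this commit): the
# streaming chat flow in dialog_service.chat() returns
# "audio_stream_url": "/tts_stream/<session_id>" with its first chunk, and a
# client can fetch the MP3 stream from this endpoint while the answer is still
# being generated, e.g. with requests (base URL assumed):
#
#   import requests
#   with requests.get(f"{BASE}/tts_stream/{session_id}", stream=True) as r:
#       with open("answer.mp3", "wb") as f:
#           for chunk in r.iter_content(4096):
#               f.write(chunk)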
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
with cache_lock:
        tts_info = audio_text_cache.pop(audio_stream_id, None)  # the entry is removed as soon as it is retrieved
try:
req = tts_info
if not req:
return get_error_data_result(message="Audio stream not found or expired.")
audio_stream = req.get('audio_stream')
tenant_id = req.get('tenant_id')
chat_id = req.get('chat_id')
text = req.get('text', "..")
model_name = req.get('model_name')
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
if not dia:
return get_error_data_result(message="You do not own the chat")
tts_model_name = dia.tts_id
if model_name: tts_model_name = model_name
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
def stream_audio():
try:
for chunk in tts_mdl.tts(text):
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')
def generate():
data = audio_stream.read(1024)
while data:
yield data
data = audio_stream.read(1024)
if audio_stream:
            # make sure the stream is positioned at the beginning
audio_stream.seek(0)
resp = Response(generate(), mimetype="audio/mpeg")
else:
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
except Exception as e:
logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
return get_error_data_result(message="音频流传输失败")
finally:
# 确保资源释放
if tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
tts_info['audio_stream'].close()
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
try:
req = request.json
if not req.get("text"):
return get_error_data_result(message="Please input your question.")
delay_gen_audio = req.get('delay_gen_audio', False)
# text = extract_text_from_markdown(req.get('text'))
text = req.get('text')
model_name = req.get('model_name')
audio_stream_id = req.get('audio_stream_id', None)
if audio_stream_id is None:
audio_stream_id = str(uuid.uuid4())
        # generate the audio stream here (unless deferred) and keep it in memory
        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = dia.tts_id
        if model_name: tts_model_name = model_name
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)  # dia.tts_id)
if delay_gen_audio:
audio_stream = None
else:
audio_stream = io.BytesIO()
        # structured cache entry
        tts_info = {
            'text': text,
            'tenant_id': tenant_id,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': audio_stream,  # keep the original behaviour
            'model_name': req.get('model_name'),
            'delay_gen_audio': delay_gen_audio,  # store the state explicitly
            'audio_stream_id': audio_stream_id
}
with cache_lock:
audio_text_cache[audio_stream_id] = tts_info
if delay_gen_audio is False:
try:
"""
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
try:
if txt is None or txt.strip() == "":
continue
for chunk in tts_mdl.tts(txt):
audio_stream.write(chunk)
except Exception as e:
continue
"""
if text is None or text.strip() == "":
audio_stream.write(b'\x00' * 100)
else:
                    # make sure new audio is appended at the end of the stream
audio_stream.seek(0, io.SEEK_END)
for chunk in tts_mdl.tts(text):
audio_stream.write(chunk)
except Exception as e:
logging.info(f"--error:{e}")
with cache_lock:
audio_text_cache.pop(audio_stream_id, None)
return get_error_data_result(message="get tts audio stream error.")
        # build the audio-stream URL
        audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
        logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
        # return the audio-stream URL
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
except Exception as e:
logging.error(f"请求处理失败: {str(e)}", exc_info=True)
return get_error_data_result(message="服务器内部错误")

View File

@@ -19,22 +19,22 @@ import logging
from copy import deepcopy
from uuid import uuid4
from api.db import LLMType
from flask import request, Response, jsonify
from flask import request, Response, jsonify, stream_with_context
from api.db.services.dialog_service import ask
from agent.canvas import Canvas
from api.db import StatusEnum
from api.db.db_models import API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.canvas_service import UserCanvasService
from api.db.services.dialog_service import DialogService, ConversationService, chat
from api.db.services.dialog_service import DialogService, ConversationService, chat, stream_manager
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result
from api.utils.api_utils import get_result, token_required
from api.db.services.llm_service import LLMBundle
import uuid
import queue
import queue, time
from threading import Lock,Thread
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
@token_required
@@ -239,186 +239,6 @@ def completion(tenant_id, chat_id): # chat_id 和 别的文件中的dialog_id
break
return get_result(data=answer)
# Mapping from full-width characters to their half-width equivalents
def fullwidth_to_halfwidth(s):
    full_to_half_map = {
        '！': '!', '＂': '"', '＃': '#', '＄': '$', '％': '%', '＆': '&', '＇': "'",
        '（': '(', '）': ')', '＊': '*', '＋': '+', '，': ',', '－': '-', '．': '.',
        '／': '/', '：': ':', '；': ';', '＜': '<', '＝': '=', '＞': '>', '？': '?',
        '＠': '@', '［': '[', '＼': '\\', '］': ']', '＾': '^', '＿': '_', '｀': '`',
        '｛': '{', '｜': '|', '｝': '}', '～': '~', '“': '', '”': '', '‘': '',
        '’': '', '、': ',', '。': '.', '—': '-', '…': '.', '《': '', '》': '',
        '【': '', '】': '', '∶': ':'
    }
    return ''.join(full_to_half_map.get(char, char) for char in s)
def is_dale(s):
    full_to_half_map = {
        '！': '!', '＂': '"', '＃': '#', '＄': '$', '％': '%', '＆': '&', '＇': "'",
        '（': '(', '）': ')', '＊': '*', '＋': '+', '，': ',', '－': '-', '．': '.',
        '／': '/', '：': ':', '；': ';', '＜': '<', '＝': '=', '＞': '>', '？': '?',
        '＠': '@', '［': '[', '＼': '\\', '］': ']', '＾': '^', '＿': '_', '｀': '`',
        '｛': '{', '｜': '|', '｝': '}', '～': '~', '“': '', '”': '', '‘': '',
        '’': '', '、': ',', '。': '.', '—': '-', '…': '.', '《': '', '》': '',
        '【': '', '】': '', '∶': ':', '·': '.'
    }
def extract_text_from_markdown(markdown_text):
    # remove Markdown headings
    text = re.sub(r'#\s*[^#]+', '', markdown_text)
    # remove inline code
    text = re.sub(r'`[^`]+`', '', text)
    # remove fenced code blocks
    text = re.sub(r'```[\s\S]*?```', '', text)
    # remove bold and italic markup
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # remove links
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # remove images
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # normalize punctuation
    # text = re.sub(r'[^\w\s]', '', text)
    text = fullwidth_to_halfwidth(text)
    # collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def split_text_at_punctuation(text, chunk_size=100):
    # use a regular expression to find all punctuation and special characters
    punctuation_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    tokens = re.split(punctuation_pattern, text)
    # drop empty strings
    tokens = [token for token in tokens if token]
    # final list of text chunks
    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            # if adding the current token keeps the chunk within chunk_size, append it
            current_chunk += (token + ' ')
        else:
            # otherwise finish the current chunk and start a new one
            chunks.append(current_chunk.strip())
            current_chunk = token + ' '
    # append the last chunk, if anything remains
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
audio_text_cache = {}
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
tts_info = audio_text_cache.pop(audio_stream_id, None)
req = tts_info
if not req:
return get_error_data_result(message="Audio stream not found or expired.")
audio_stream = req.get('audio_stream')
tenant_id = req.get('tenant_id')
chat_id = req.get('chat_id')
text = req.get('text', "..")
model_name = req.get('model_name')
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
if not dia:
return get_error_data_result(message="You do not own the chat")
tts_model_name = dia.tts_id
if model_name: tts_model_name = model_name
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
def stream_audio():
try:
for chunk in tts_mdl.tts(text):
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e)}},
ensure_ascii=False)).encode('utf-8')
def generate():
data = audio_stream.read(1024)
while data:
yield data
data = audio_stream.read(1024)
if audio_stream:
        # make sure the stream is positioned at the beginning
audio_stream.seek(0)
resp = Response(generate(), mimetype="audio/mpeg")
else:
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
return resp
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
req = request.json
if not req.get("text"):
return get_error_data_result(message="Please input your question.")
delay_gen_audio = req.get('delay_gen_audio', False)
# text = extract_text_from_markdown(req.get('text'))
text = req.get('text')
audio_stream_id = req.get('audio_stream_id')
# logging.info(f"request tts audio url:{text} audio_stream_id:{audio_stream_id} ")
if audio_stream_id is None:
audio_stream_id = str(uuid.uuid4())
    # generate the audio stream here and keep it in memory
model_name = req.get('model_name')
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
tts_model_name = dia.tts_id
if model_name: tts_model_name = model_name
logging.info(f"---tts {tts_model_name}")
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
if delay_gen_audio:
audio_stream = None
else:
audio_stream = io.BytesIO()
audio_text_cache[audio_stream_id] = {'text': text, 'chat_id': chat_id, "tenant_id": tenant_id,
                                         'audio_stream': audio_stream, 'model_name': model_name}  # cache the text so the audio can be generated later
if delay_gen_audio is False:
try:
"""
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
try:
if txt is None or txt.strip() == "":
continue
for chunk in tts_mdl.tts(txt):
audio_stream.write(chunk)
except Exception as e:
continue
"""
if text is None or text.strip() == "":
audio_stream.write(b'\x00' * 100)
else:
for chunk in tts_mdl.tts(text):
audio_stream.write(chunk)
except Exception as e:
return get_error_data_result(message="get tts audio stream error.")
    # build the audio-stream URL
audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
    # return the audio-stream URL
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
@manager.route('/agents/<agent_id>/completions', methods=['POST'])
@token_required
def agent_completion(tenant_id, agent_id):

View File

@@ -988,7 +988,44 @@ class CanvasTemplate(DataBaseModel):
class Meta:
db_table = "canvas_template"
# ------------added by cyx for mesum overview
class MesumOverview(DataBaseModel):
    name = CharField(
        max_length=128,
        null=False,
        help_text="museum name",
        primary_key=False)
longitude = CharField(
max_length=40,
null=True,
help_text="Longitude",
index=False)
latitude = CharField(
max_length=40,
null=True,
help_text="latitude",
index=False)
    antique = CharField(
max_length=1024,
null=True,
help_text="antique",
index=False)
brief = CharField(
max_length=1024,
null=True,
help_text="brief",
index=False)
def __str__(self):
return self.name
class Meta:
db_table = "mesum_overview"
#-------------------------------------------
def migrate_db():
with DB.transaction():
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)

View File

@@ -0,0 +1,31 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant, MesumOverview
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum
class MesumOverviewService(CommonService):
model = MesumOverview
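# Hypothetical usage sketch (illustration only): the /mesum/list route added in
# this commit reads the table through the generic CommonService helpers, roughly:
#
#   overviews = MesumOverviewService.get_all()
#   data = [o.to_dict() for o in overviews]   # name, longitude, latitude, antique, brief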

View File

@@ -33,37 +33,103 @@ from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
from peewee import fn
import threading, queue
import threading, queue, uuid, time
from concurrent.futures import ThreadPoolExecutor
# Create a TTS generation worker thread
class TTSWorker(threading.Thread):
def __init__(self, tenant_id, tts_id, tts_text_queue, tts_audio_queue):
super().__init__()
self.tts_mdl = LLMBundle(tenant_id, LLMType.TTS, tts_id)
self.tts_text_queue = tts_text_queue
self.tts_audio_queue = tts_audio_queue
        self.daemon = True  # daemon thread: it exits automatically when the main thread exits
def run(self):
class StreamSessionManager:
def __init__(self):
self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}}
self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
        self.gc_interval = 300  # clean up every 5 minutes
def create_session(self, tts_model):
session_id = str(uuid.uuid4())
with self.lock:
self.sessions[session_id] = {
'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=100),  # thread-safe queue
'task_queue': queue.Queue(),
'active': True,
'last_active': time.time(),
'audio_chunk_count':0
}
        # start the per-session task-processing thread
threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
return session_id
def append_text(self, session_id, text):
with self.lock:
session = self.sessions.get(session_id)
if not session: return
        # put the text onto the task queue (non-blocking)
try:
session['task_queue'].put(text, block=False)
except queue.Full:
logging.warning(f"Session {session_id} task queue full")
def _process_tasks(self, session_id):
"""任务处理线程(每个会话独立)"""
while True:
# 从队列中获取数据
delta_ans = self.tts_text_queue.get()
if delta_ans is None: # 如果队列中没有数据,退出线程
session = self.sessions.get(session_id)
if not session or not session['active']:
break
try:
                # call TTS to generate the audio data
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
if tts_input_is_valid:
logging.info(f"--tts threading {delta_ans} {tts_input_is_valid} {sanitized_text}")
bin = b""
for chunk in self.tts_mdl.tts(sanitized_text):
bin += chunk
                # store the generated audio in the queue (or process it directly)
self.tts_audio_queue.put(bin)
except Exception as e:
logging.error(f"Error generating TTS for text '{delta_ans}': {e}")
                # merge multiple text chunks, waiting at most 50 ms
                texts = []
                while len(texts) < 5:  # merge at most 5 text chunks
try:
text = session['task_queue'].get(timeout=0.05)
texts.append(text)
except queue.Empty:
break
                if texts:
                    # submit to the thread pool for processing
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts)  # join the texts to reduce the number of TTS requests
                    )
                    future.result()  # wait for the conversion task to finish
                # session timeout check
if time.time() - session['last_active'] > self.gc_interval:
self.close_session(session_id)
break
except Exception as e:
logging.error(f"Task processing error: {str(e)}")
def _generate_audio(self, session_id, text):
"""实际生成音频(线程池执行)"""
session = self.sessions.get(session_id)
if not session: return
# logging.info(f"_generate_audio:{text}")
try:
for chunk in session['tts_model'].tts(text):
session['buffer'].put(chunk)
session['last_active'] = time.time()
session['audio_chunk_count'] = session['audio_chunk_count'] + 1
logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
except Exception as e:
session['buffer'].put(f"ERROR:{str(e)}")
def close_session(self, session_id):
with self.lock:
if session_id in self.sessions:
                # mark the session as inactive
                self.sessions[session_id]['active'] = False
                # clean up resources after a 10-second delay
                threading.Timer(10, self._clean_session, args=[session_id]).start()
def _clean_session(self, session_id):
with self.lock:
if session_id in self.sessions:
del self.sessions[session_id]
stream_manager = StreamSessionManager()
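# Hypothetical end-to-end flow (illustration only) tying the pieces together:
#   session_id = stream_manager.create_session(tts_mdl)     # done in chat() below
#   stream_manager.append_text(session_id, "第一段文本。")     # called as answer deltas arrive
#   # the /tts_stream/<session_id> route drains session['buffer'] and streams the
#   # MP3 chunks to the client; close_session() releases the session afterwards.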
class DialogService(CommonService):
model = Dialog
@@ -235,6 +301,73 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
    # if all checks pass, return the valid flag and the sanitized text
return True, delta_ans
def _should_flush(text_chunk, chunk_buffer, last_flush_time):
    """Heuristically decide whether audio should be generated right away."""
    # Rule 1: the chunk ends with sentence-final punctuation
    if re.search(r'[。!?,]$', text_chunk):
        return True
    if re.search(r'(\d{4})(年|月|日|,)', text_chunk):
        return False  # do not flush; keep merging
    # Rule 2: the buffer has reached the maximum length (200 characters)
    if sum(len(c) for c in chunk_buffer) >= 200:
        return True
    # Rule 3: more than 500 ms since the last flush
    if time.time() - last_flush_time > 0.5:
        return True
    return False
MAX_BUFFER_LEN = 200  # maximum buffer length
FLUSH_TIMEOUT = 0.5  # forced-flush interval (seconds)
# Find the best split point in the text (punctuation / semantic units / phrase boundaries)
def find_split_position(text):
    """Find the best position at which to split the text."""
    # prefer sentence-ending punctuation
    sentence_end = list(re.finditer(r'[。!?]', text))
    if sentence_end:
        return sentence_end[-1].end()
    # otherwise look for natural pause marks
    pause_mark = list(re.finditer(r'[,;、]', text))
    if pause_mark:
        return pause_mark[-1].end()
    # avoid truncating date/number phrases
    date_pattern = re.search(r'\d+(年|月|日)(?!\d)', text)
    if date_pattern:
        return date_pattern.end()
    # avoid splitting common phrases
    for phrase in ["青少年", "博物馆", "参观"]:
        idx = text.rfind(phrase)
        if idx != -1 and idx + len(phrase) <= len(text):
            return idx + len(phrase)
    return None
# Manage the text buffer: split it by semantic rules and return the semantically complete part for processing
def process_buffer(chunk_buffer, force_flush=False):
    """Process the text buffer; return the text to send and the remaining buffer."""
    current_text = "".join(chunk_buffer)
    if not current_text:
        return "", []
    split_pos = find_split_position(current_text)
    # forced-flush logic
    if force_flush or len(current_text) >= MAX_BUFFER_LEN:
        # even on a forced flush, try to pick a reasonable split point
        if split_pos is None or split_pos < len(current_text) // 2:
            split_pos = max(split_pos or 0, MAX_BUFFER_LEN)
            split_pos = min(split_pos, len(current_text))
    if split_pos is not None and split_pos > 0:
        return current_text[:split_pos], [current_text[split_pos:]]
    return "", chunk_buffer
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
st = timer()
@@ -283,7 +416,10 @@ def chat(dialog, messages, stream=True, **kwargs):
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,dialog.tts_id)
if kwargs.get('tts_model'):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,kwargs.get('tts_model'))
else:
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id)
# try to use sql if field mapping is good to go
if field_map:
@@ -388,34 +524,83 @@ def chat(dialog, messages, stream=True, **kwargs):
if stream:
last_ans = ""
answer = ""
        # create the TTS session up front (initialized early)
tts_session_id = stream_manager.create_session(tts_mdl)
audio_url = f"/tts_stream/{tts_session_id}"
first_chunk = True
        chunk_buffer = []  # text buffer (newly added)
        last_flush_time = time.time()  # initialize the flush timestamp
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
if num_tokens_from_string(delta_ans) < 24:
continue
last_ans = answer
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
# cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
if kwargs.get('tts_disable'):
tts_input_is_valid =False
if tts_input_is_valid:
                # buffer the text until punctuation is reached
                chunk_buffer.append(sanitized_text)
                # process the buffered content
                while True:
                    # decide whether a forced flush is needed
                    force = time.time() - last_flush_time > FLUSH_TIMEOUT
                    to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
                    if not to_send:
                        break
                    # send the accumulated content
                    stream_manager.append_text(tts_session_id, to_send)
                    chunk_buffer = remaining
                    last_flush_time = time.time()
"""
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
            # return the audio URL with the first chunk
if first_chunk:
yield {
"answer": answer,
"delta_ans": sanitized_text,
"audio_stream_url": audio_url,
"session_id": tts_session_id,
"reference": {}
}
first_chunk = False
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
delta_ans = answer[len(last_ans):]
if delta_ans:
# stream_manager.append_text(tts_session_id, delta_ans)
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
            # cyx 2024-12-04: fix the error where TTS was called with an empty delta_ans
            # cyx 2024-12-04: fix the error where TTS was called with an empty delta_ans
            tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
            #if kwargs.get('tts_disable'): # cyx 2025-01-18: if the front end passes tts_disable, do not generate TTS audio, i.e. no audio_binary
            tts_input_is_valid = False
            if kwargs.get('tts_disable'):  # cyx 2025-01-18: if the front end passes tts_disable, do not generate TTS audio, i.e. no audio_binary
                tts_input_is_valid = False
            if tts_input_is_valid:
                # 2025-02-21 change: generate the audio data on the backend
chunk_buffer.append(sanitized_text)
stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
"""
yield decorate_answer(answer)
else: