在生成对话文字时,同时在后台生成tts音频,增加朗读音色选择,增加博物馆的概况接口
This commit is contained in:
@@ -988,7 +988,44 @@ class CanvasTemplate(DataBaseModel):
|
||||
class Meta:
|
||||
db_table = "canvas_template"
|
||||
|
||||
# ------------ added by cyx: museum overview model
|
||||
class MesumOverview(DataBaseModel):
    """Museum overview record, persisted in the ``mesum_overview`` table.

    Every column is a character field; only ``name`` is mandatory.
    """

    # Museum name (required, at most 128 characters).
    name = CharField(max_length=128, null=False, help_text="mesum name", primary_key=False)

    # Coordinates are stored as optional strings of at most 40 characters.
    longitude = CharField(max_length=40, null=True, help_text="Longitude", index=False)
    latitude = CharField(max_length=40, null=True, help_text="latitude", index=False)

    # Optional text of at most 1024 characters.
    # NOTE(review): presumably a description of the collection - confirm with the API caller.
    antique = CharField(max_length=1024, null=True, help_text="antique", index=False)

    # Optional short introduction text (at most 1024 characters).
    brief = CharField(max_length=1024, null=True, help_text="brief", index=False)

    class Meta:
        db_table = "mesum_overview"

    def __str__(self):
        """Return the museum name as the printable form of the record."""
        return self.name
|
||||
|
||||
#-------------------------------------------
|
||||
def migrate_db():
|
||||
with DB.transaction():
|
||||
migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB)
|
||||
|
||||
31
api/db/services/brief_service.py
Normal file
31
api/db/services/brief_service.py
Normal file
@@ -0,0 +1,31 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from datetime import datetime
|
||||
|
||||
import peewee
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from api.db import UserTenantRole
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import User, Tenant, MesumOverview
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
|
||||
from api.db import StatusEnum
|
||||
|
||||
|
||||
class MesumOverviewService(CommonService):
    """Service wrapper that binds the shared ``CommonService`` helpers to ``MesumOverview``."""

    # Model class the inherited service helpers operate on.
    model = MesumOverview
|
||||
|
||||
@@ -33,37 +33,103 @@ from rag.nlp.search import index_name
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from peewee import fn
|
||||
import threading, queue
|
||||
import threading, queue,uuid,time
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
# 创建一个 TTS 生成线程
|
||||
class TTSWorker(threading.Thread):
|
||||
def __init__(self, tenant_id, tts_id, tts_text_queue, tts_audio_queue):
|
||||
super().__init__()
|
||||
self.tts_mdl = LLMBundle(tenant_id, LLMType.TTS, tts_id)
|
||||
self.tts_text_queue = tts_text_queue
|
||||
self.tts_audio_queue = tts_audio_queue
|
||||
self.daemon = True # 设置为守护线程,主线程退出时,子线程也会自动退出
|
||||
|
||||
def run(self):
|
||||
class StreamSessionManager:
    """Registry of background text-to-speech (TTS) streaming sessions.

    Each session is a dict stored in ``self.sessions`` under a UUID string:

    * ``tts_model``  - object whose ``tts(text)`` yields audio chunks
    * ``buffer``     - bounded queue of generated audio chunks (or an
      ``"ERROR:..."`` string when generation failed)
    * ``task_queue`` - unbounded queue of text waiting to be synthesised
    * ``active``     - cleared by ``close_session``
    * ``last_active`` / ``audio_chunk_count`` - bookkeeping updated per chunk

    One daemon thread per session drains ``task_queue``; the synthesis itself
    runs on a shared thread pool.
    """

    BUFFER_SIZE = 100          # capacity of the per-session audio buffer
    MAX_MERGED_TEXTS = 5       # at most this many text pieces per TTS request
    MERGE_WAIT_SECONDS = 0.05  # how long to wait for each further text piece
    PUT_RETRY_SECONDS = 0.5    # re-check interval while the audio buffer is full

    def __init__(self, max_workers=30, gc_interval=300, cleanup_delay=10):
        """Create an empty manager.

        Args:
            max_workers: size of the shared synthesis thread pool.
            gc_interval: seconds without a produced audio chunk after which
                a session is closed.
            cleanup_delay: seconds between marking a session inactive and
                dropping it from ``self.sessions``.
        """
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.gc_interval = gc_interval
        self.cleanup_delay = cleanup_delay

    def create_session(self, tts_model):
        """Register a session for ``tts_model``, start its worker thread and return its id."""
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=self.BUFFER_SIZE),
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
            }
        # One daemon worker per session consumes that session's task queue.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue ``text`` for synthesis; unknown session ids are ignored."""
        with self.lock:
            session = self.sessions.get(session_id)
        if not session:
            return
        try:
            session['task_queue'].put_nowait(text)
        except queue.Full:
            logging.warning(f"Session {session_id} task queue full")

    def _collect_texts(self, session):
        """Pull up to ``MAX_MERGED_TEXTS`` pending pieces, waiting briefly for each."""
        texts = []
        while len(texts) < self.MAX_MERGED_TEXTS:
            try:
                texts.append(session['task_queue'].get(timeout=self.MERGE_WAIT_SECONDS))
            except queue.Empty:
                break
        return texts

    def _process_tasks(self, session_id):
        """Worker loop of one session; exits once the session is gone or inactive."""
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                texts = self._collect_texts(session)
                if texts:
                    # Merge the pieces to reduce the number of TTS requests and
                    # wait for completion so audio stays in text order.
                    future = self.executor.submit(
                        self._generate_audio, session_id, ' '.join(texts))
                    future.result()
                # Close sessions that have produced no audio for too long.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _put_audio(self, session, item):
        """Put ``item`` into the audio buffer without blocking forever.

        While the buffer has room this behaves like a plain ``put``.  When it
        is full the put is retried until the session is closed or has been
        idle longer than ``gc_interval``; then ``False`` is returned so the
        pool worker is released instead of hanging on an abandoned stream.
        """
        while True:
            try:
                session['buffer'].put(item, timeout=self.PUT_RETRY_SECONDS)
                return True
            except queue.Full:
                idle = time.time() - session['last_active']
                if not session['active'] or idle > self.gc_interval:
                    return False

    def _generate_audio(self, session_id, text):
        """Synthesise ``text`` and push the audio chunks to the session buffer (pool thread)."""
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            for chunk in session['tts_model'].tts(text):
                if not self._put_audio(session, chunk):
                    logging.warning(f"Session {session_id} audio buffer not consumed, dropping audio")
                    return
                session['last_active'] = time.time()
                session['audio_chunk_count'] += 1
            logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
        except Exception as e:
            # Report the failure in-band as an "ERROR:" string on the buffer.
            self._put_audio(session, f"ERROR:{str(e)}")

    def close_session(self, session_id):
        """Mark the session inactive and schedule its removal after ``cleanup_delay`` seconds."""
        with self.lock:
            session = self.sessions.get(session_id)
            if session is None:
                return
            session['active'] = False
        # The entry is kept for ``cleanup_delay`` seconds before removal.
        # The timer is a daemon so a pending cleanup never delays exit.
        timer = threading.Timer(self.cleanup_delay, self._clean_session, args=[session_id])
        timer.daemon = True
        timer.start()

    def _clean_session(self, session_id):
        """Drop the session entry; a no-op when it is already gone."""
        with self.lock:
            self.sessions.pop(session_id, None)
|
||||
|
||||
# Module-level StreamSessionManager instance, created once at import time.
stream_manager = StreamSessionManager()
|
||||
|
||||
class DialogService(CommonService):
|
||||
model = Dialog
|
||||
@@ -235,6 +301,73 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
|
||||
# 如果通过所有检查,返回有效标志和修正后的文本
|
||||
return True, delta_ans
|
||||
|
||||
def _should_flush(text_chunk, chunk_buffer, last_flush_time,
                  max_buffer_len=200, flush_timeout=0.5):
    """Decide whether buffered text should be sent to TTS right now.

    Args:
        text_chunk: the newest piece of text.
        chunk_buffer: iterable of strings buffered so far.
        last_flush_time: ``time.time()`` value of the previous flush.
        max_buffer_len: total buffered characters that force a flush.
        flush_timeout: seconds since the last flush that force a flush.

    Returns:
        ``True`` to flush now, ``False`` to keep accumulating.
    """
    # Rule 1: the chunk ends on sentence/clause punctuation.  Both the
    # half-width and the full-width forms are accepted.
    if re.search(r'[。!?,!?,]$', text_chunk):
        return True

    # A four-digit number followed by a date unit or a comma: keep merging.
    # This check deliberately precedes the size and time limits below.
    if re.search(r'(\d{4})(年|月|日|,|,)', text_chunk):
        return False

    # Rule 2: the buffer has reached its maximum size.
    if sum(len(piece) for piece in chunk_buffer) >= max_buffer_len:
        return True

    # Rule 3: too much time has passed since the last flush.
    return time.time() - last_flush_time > flush_timeout
|
||||
|
||||
|
||||
MAX_BUFFER_LEN = 200  # maximum buffered text length
|
||||
FLUSH_TIMEOUT = 0.5  # forced-flush interval (seconds)
|
||||
|
||||
# Find the best place to cut buffered text (punctuation / date unit / phrase boundary).
def find_split_position(text, protected_phrases=("青少年", "博物馆", "参观")):
    """Return the index at which ``text`` can be cut, or ``None``.

    Candidates are tried in priority order:

    1. end of the last sentence terminator;
    2. end of the last pause mark;
    3. end of the first ``<digits>年/月/日`` group not followed by a digit,
       so a date such as ``2024年5月`` is not cut in the middle;
    4. end of the last occurrence of the first phrase from
       ``protected_phrases`` that appears in the text.

    Half-width and full-width punctuation are both recognised.

    Args:
        text: the buffered text to inspect.
        protected_phrases: phrases that must not be split.

    Returns:
        The cut position (text before it is complete), or ``None`` when no
        suitable boundary exists.
    """
    # Sentence terminators first, then natural pause marks; in both cases the
    # right-most occurrence wins so as much text as possible is released.
    for pattern in (r'[。!?!?]', r'[,;、,;]'):
        last = None
        for last in re.finditer(pattern, text):
            pass
        if last is not None:
            return last.end()

    # Do not cut through a date / numeric phrase.
    date_match = re.search(r'\d+(年|月|日)(?!\d)', text)
    if date_match:
        return date_match.end()

    # Do not cut through a known phrase: split right after it.
    for phrase in protected_phrases:
        idx = text.rfind(phrase) if phrase else -1
        if idx != -1:
            return idx + len(phrase)

    return None
|
||||
|
||||
# Split the buffered text at a boundary chosen by find_split_position and
# return the part that is ready to be sent.
def process_buffer(chunk_buffer, force_flush=False):
    """Split the buffered text into a part to send and a remainder.

    Args:
        chunk_buffer: list of buffered text pieces.
        force_flush: when true, release text even if the split point chosen
            by ``find_split_position`` is missing or lies in the first half.

    Returns:
        ``(to_send, remaining)``: ``to_send`` is the text to emit now (``""``
        when nothing should be sent) and ``remaining`` is the new buffer
        list.  ``remaining`` is empty when the whole text was consumed.
    """
    current_text = "".join(chunk_buffer)
    if not current_text:
        return "", []

    split_pos = find_split_position(current_text)

    # Forced flush, also triggered by an over-long buffer: a split point in
    # the second half is kept; otherwise cut at MAX_BUFFER_LEN characters,
    # or at the end of the text when it is shorter than that.
    if force_flush or len(current_text) >= MAX_BUFFER_LEN:
        if split_pos is None or split_pos < len(current_text) // 2:
            split_pos = min(max(split_pos or 0, MAX_BUFFER_LEN), len(current_text))

    if not split_pos:
        # No usable boundary yet: keep buffering.
        return "", chunk_buffer

    remainder = current_text[split_pos:]
    # Avoid leaving a [""] residue in the buffer when everything was sent.
    return current_text[:split_pos], ([remainder] if remainder else [])
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
st = timer()
|
||||
@@ -283,7 +416,10 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
tts_mdl = None
|
||||
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,dialog.tts_id)
|
||||
if kwargs.get('tts_model'):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,kwargs.get('tts_model'))
|
||||
else:
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id)
|
||||
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
@@ -388,34 +524,83 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if stream:
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
# 创建TTS会话(提前初始化)
|
||||
tts_session_id = stream_manager.create_session(tts_mdl)
|
||||
audio_url = f"/tts_stream/{tts_session_id}"
|
||||
first_chunk = True
|
||||
chunk_buffer = [] # 新增文本缓冲
|
||||
last_flush_time = time.time() # 初始化时间戳
|
||||
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
||||
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
if num_tokens_from_string(delta_ans) < 24:
|
||||
continue
|
||||
|
||||
last_ans = answer
|
||||
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
|
||||
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
|
||||
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
tts_input_is_valid = False
|
||||
# cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
if kwargs.get('tts_disable'):
|
||||
tts_input_is_valid =False
|
||||
if tts_input_is_valid:
|
||||
# 缓冲文本直到遇到标点
|
||||
chunk_buffer.append(sanitized_text)
|
||||
# 处理缓冲区内容
|
||||
while True:
|
||||
# 判断是否需要强制刷新
|
||||
force = time.time() - last_flush_time > FLUSH_TIMEOUT
|
||||
to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
|
||||
|
||||
if not to_send:
|
||||
break
|
||||
|
||||
# 发送有效内容
|
||||
stream_manager.append_text(tts_session_id, to_send)
|
||||
chunk_buffer = remaining
|
||||
last_flush_time = time.time()
|
||||
"""
|
||||
if tts_input_is_valid:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
|
||||
else:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
|
||||
"""
|
||||
|
||||
# 首块返回音频URL
|
||||
if first_chunk:
|
||||
yield {
|
||||
"answer": answer,
|
||||
"delta_ans": sanitized_text,
|
||||
"audio_stream_url": audio_url,
|
||||
"session_id": tts_session_id,
|
||||
"reference": {}
|
||||
}
|
||||
first_chunk = False
|
||||
else:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
|
||||
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
# stream_manager.append_text(tts_session_id, delta_ans)
|
||||
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# cyx 2024 12 04 修正delta_ans 为空调用tts 出错
|
||||
# cyx 2024 12 04 修正delta_ans 为空 调用tts 出错
|
||||
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
|
||||
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
tts_input_is_valid = False
|
||||
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
tts_input_is_valid = False
|
||||
if tts_input_is_valid:
|
||||
# 20250221 修改,在后端生成音频数据
|
||||
chunk_buffer.append(sanitized_text)
|
||||
stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
|
||||
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
|
||||
"""
|
||||
if tts_input_is_valid:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
|
||||
else:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
|
||||
"""
|
||||
|
||||
yield decorate_answer(answer)
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user