从新提交到gitee 仓库
This commit is contained in:
@@ -912,6 +912,12 @@ class Dialog(DataBaseModel):
|
||||
help_text="is it validate(0: wasted, 1: validate)",
|
||||
default="1",
|
||||
index=True)
|
||||
# tts_id added by cyx 为每一个对话助理设置相应的tts
|
||||
tts_id = CharField(
|
||||
max_length=256,
|
||||
null=True,
|
||||
help_text="default tts model ID",
|
||||
index=True)
|
||||
|
||||
class Meta:
|
||||
db_table = "dialog"
|
||||
|
||||
@@ -32,6 +32,37 @@ from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.nlp.search import index_name
|
||||
from rag.utils import rmSpace, num_tokens_from_string, encoder
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from peewee import fn
|
||||
import threading, queue
|
||||
|
||||
|
||||
# 创建一个 TTS 生成线程
|
||||
class TTSWorker(threading.Thread):
|
||||
def __init__(self, tenant_id, tts_id, tts_text_queue, tts_audio_queue):
|
||||
super().__init__()
|
||||
self.tts_mdl = LLMBundle(tenant_id, LLMType.TTS, tts_id)
|
||||
self.tts_text_queue = tts_text_queue
|
||||
self.tts_audio_queue = tts_audio_queue
|
||||
self.daemon = True # 设置为守护线程,主线程退出时,子线程也会自动退出
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
# 从队列中获取数据
|
||||
delta_ans = self.tts_text_queue.get()
|
||||
if delta_ans is None: # 如果队列中没有数据,退出线程
|
||||
break
|
||||
try:
|
||||
# 调用 TTS 生成音频数据
|
||||
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
|
||||
if tts_input_is_valid:
|
||||
logging.info(f"--tts threading {delta_ans} {tts_input_is_valid} {sanitized_text}")
|
||||
bin = b""
|
||||
for chunk in self.tts_mdl.tts(sanitized_text):
|
||||
bin += chunk
|
||||
# 将生成的音频数据存储到队列中或直接处理
|
||||
self.tts_audio_queue.put(bin)
|
||||
except Exception as e:
|
||||
logging.error(f"Error generating TTS for text '{delta_ans}': {e}")
|
||||
|
||||
|
||||
class DialogService(CommonService):
|
||||
@@ -65,22 +96,61 @@ class ConversationService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id , name):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id ==dialog_id)
|
||||
def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id, name, cols=None):
|
||||
# 构建基础查询
|
||||
print("--ConversationService get_list enter", page_number, items_per_page) # cyx
|
||||
query = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
||||
|
||||
# 如果指定了ID,则添加ID筛选
|
||||
if id:
|
||||
sessions = sessions.where(cls.model.id == id)
|
||||
query = query.where(cls.model.id == id)
|
||||
|
||||
# 如果指定了名称,则添加名称筛选
|
||||
if name:
|
||||
sessions = sessions.where(cls.model.name == name)
|
||||
query = query.where(cls.model.name == name)
|
||||
|
||||
# 如果指定了列筛选,则只选择指定的列
|
||||
if cols:
|
||||
query = query.select(*[getattr(cls.model, col) for col in cols])
|
||||
|
||||
# 获取记录总数
|
||||
total = query.count()
|
||||
# 添加排序
|
||||
if desc:
|
||||
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
||||
query = query.order_by(cls.model.getter_by(orderby).desc())
|
||||
else:
|
||||
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
|
||||
query = query.order_by(cls.model.getter_by(orderby).asc())
|
||||
|
||||
sessions = sessions.paginate(page_number, items_per_page)
|
||||
# 执行分页查询
|
||||
paginated_query = query.paginate(page_number, items_per_page)
|
||||
data = list(paginated_query.dicts())
|
||||
# logging.info("--ConversationService get_list",total, data) #cyx
|
||||
# 返回分页数据和记录总数
|
||||
return total, data
|
||||
|
||||
return list(sessions.dicts())
|
||||
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def query_sessions_summary(cls):
|
||||
# 按 id 分组,统计每个 id 的最旧记录
|
||||
query = (
|
||||
cls.model
|
||||
.select(
|
||||
cls.model.id,
|
||||
cls.model.dialog_id,
|
||||
cls.model.name,
|
||||
fn.MIN(cls.model.create_time).alias("create_time"),
|
||||
fn.MIN(cls.model.create_date).alias("create_date")
|
||||
)
|
||||
.group_by(cls.model.id, cls.model.dialog_id, cls.model.name)
|
||||
.order_by(
|
||||
fn.MIN(cls.model.create_time).desc(),
|
||||
)
|
||||
)
|
||||
# 转换为字典列表返回
|
||||
return list(query.dicts())
|
||||
|
||||
def message_fit_in(msg, max_length=4000):
|
||||
def count():
|
||||
nonlocal msg
|
||||
@@ -128,6 +198,42 @@ def llm_id2llm_type(llm_id):
|
||||
if llm_id == llm["llm_name"]:
|
||||
return llm["model_type"].strip(",")[-1]
|
||||
|
||||
# cyx 2024 12 04
|
||||
# 用于校验和修正语音合成的输入文本。该函数会去除非法字符、修正内容,并返回一个结果:包括是否有效和修正后的文本。
|
||||
def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
|
||||
"""
|
||||
检验并修正语音合成的输入文本。
|
||||
|
||||
Args:
|
||||
delta_ans (str): 输入的待校验文本。
|
||||
max_length (int): 文本允许的最大长度。
|
||||
|
||||
Returns:
|
||||
tuple: (is_valid, sanitized_text)
|
||||
- is_valid (bool): 文本是否有效。
|
||||
- sanitized_text (str): 修正后的文本(如果无效,为空字符串)。
|
||||
"""
|
||||
# 1. 确保输入为字符串
|
||||
if not isinstance(delta_ans, str):
|
||||
return False, ""
|
||||
|
||||
# 2. 去除前后空白并检查是否为空
|
||||
delta_ans = delta_ans.strip()
|
||||
if len(delta_ans) == 0:
|
||||
return False, ""
|
||||
|
||||
# 3. 替换全角符号为半角
|
||||
delta_ans = re.sub(r'[?]', '?', delta_ans)
|
||||
|
||||
# 4. 移除非法字符(仅保留中文、英文、数字及常见标点符号)
|
||||
delta_ans = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s,.!?\'";;。,!?:”“()\-()]', '', delta_ans)
|
||||
|
||||
# 5. 检查长度
|
||||
if len(delta_ans) == 0 or len(delta_ans) > max_length:
|
||||
return False, ""
|
||||
|
||||
# 如果通过所有检查,返回有效标志和修正后的文本
|
||||
return True, delta_ans
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
@@ -175,8 +281,10 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
prompt_config = dialog.prompt_config
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
tts_mdl = None
|
||||
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,dialog.tts_id)
|
||||
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
@@ -184,7 +292,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
|
||||
# logging.info(f"dialog_service--1 chat prompt_config{prompt_config['parameters']} {prompt_config}") # cyx
|
||||
for p in prompt_config["parameters"]:
|
||||
if p["key"] == "knowledge":
|
||||
continue
|
||||
@@ -223,6 +331,7 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
logging.debug(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
|
||||
retrieval_tm = timer()
|
||||
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
@@ -245,7 +354,6 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
gen_conf["max_tokens"] = min(
|
||||
gen_conf["max_tokens"],
|
||||
max_tokens - used_token_count)
|
||||
|
||||
def decorate_answer(answer):
|
||||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
|
||||
refs = []
|
||||
@@ -281,22 +389,44 @@ def chat(dialog, messages, stream=True, **kwargs):
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
||||
|
||||
answer = ans
|
||||
delta_ans = ans[len(last_ans):]
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
|
||||
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
|
||||
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
tts_input_is_valid = False
|
||||
if tts_input_is_valid:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
|
||||
else:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
|
||||
|
||||
delta_ans = answer[len(last_ans):]
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# cyx 2024 12 04 修正delta_ans 为空调用tts 出错
|
||||
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
|
||||
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
tts_input_is_valid = False
|
||||
if tts_input_is_valid:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
|
||||
else:
|
||||
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
||||
logging.debug("User: {}|Assistant: {}".format(
|
||||
msg[-1]["content"], answer))
|
||||
res = decorate_answer(answer)
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary
|
||||
tts_input_is_valid = False
|
||||
else:
|
||||
res["audio_binary"] = tts(tts_mdl, answer)
|
||||
yield res
|
||||
|
||||
|
||||
|
||||
@@ -140,6 +140,7 @@ class TenantLLMService(CommonService):
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
# 初始化 tts model cyx
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
model_config["api_key"],
|
||||
model_config["llm_name"],
|
||||
@@ -201,6 +202,8 @@ class LLMBundle(object):
|
||||
assert self.mdl, "Can't find model for {}/{}/{}".format(
|
||||
tenant_id, llm_type, llm_name)
|
||||
self.max_length = 8192
|
||||
if llm_type == LLMType.TTS:
|
||||
logging.info(f"dale--TTS model {tenant_id} {llm_type} {llm_name}")
|
||||
for lm in LLMService.query(llm_name=llm_name):
|
||||
self.max_length = lm.max_tokens
|
||||
break
|
||||
@@ -245,7 +248,7 @@ class LLMBundle(object):
|
||||
"LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
|
||||
return txt
|
||||
|
||||
def tts(self, text):
|
||||
def tts(self, text): # tts 调用 cyx
|
||||
for chunk in self.mdl.tts(text):
|
||||
if isinstance(chunk,int):
|
||||
if not TenantLLMService.increase_usage(
|
||||
@@ -253,7 +256,10 @@ class LLMBundle(object):
|
||||
logging.error(
|
||||
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
|
||||
return
|
||||
yield chunk
|
||||
yield chunk
|
||||
|
||||
def end_tts(self): # 结束 tts流式 调用 cyx
|
||||
self.mdl.end_tts()
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
txt, used_tokens = self.mdl.chat(system, history, gen_conf)
|
||||
|
||||
Reference in New Issue
Block a user