重新提交到 gitee 仓库

This commit is contained in:
qcloud
2025-02-06 23:34:26 +08:00
parent e678819f70
commit c88312a914
62 changed files with 211935 additions and 7500 deletions

View File

@@ -912,6 +912,12 @@ class Dialog(DataBaseModel):
help_text="is it validate(0: wasted, 1: validate)",
default="1",
index=True)
# tts_id added by cyx 为每一个对话助理设置相应的tts
tts_id = CharField(
max_length=256,
null=True,
help_text="default tts model ID",
index=True)
class Meta:
db_table = "dialog"

View File

@@ -32,6 +32,37 @@ from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
from peewee import fn
import threading, queue
# Background worker thread that turns queued text into TTS audio.
class TTSWorker(threading.Thread):
    """Consume text chunks from a queue and synthesize audio for them.

    Text chunks are read from ``tts_text_queue``; the synthesized audio
    bytes for each valid chunk are pushed onto ``tts_audio_queue``.
    A ``None`` item on the text queue is the shutdown sentinel.
    """

    def __init__(self, tenant_id, tts_id, tts_text_queue, tts_audio_queue):
        super().__init__()
        # TTS model bundle for this tenant; tts_id selects the model.
        self.tts_mdl = LLMBundle(tenant_id, LLMType.TTS, tts_id)
        self.tts_text_queue = tts_text_queue
        self.tts_audio_queue = tts_audio_queue
        # Daemon thread: it exits automatically when the main thread exits.
        self.daemon = True

    def run(self):
        while True:
            # Blocks until an item arrives.
            delta_ans = self.tts_text_queue.get()
            if delta_ans is None:  # explicit shutdown sentinel
                break
            try:
                # Validate/clean the text before handing it to the TTS engine.
                tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
                if tts_input_is_valid:
                    # Lazy %-args: avoid f-string formatting on every chunk.
                    logging.info("--tts threading %s %s %s",
                                 delta_ans, tts_input_is_valid, sanitized_text)
                    # join() instead of bytes += (quadratic); also avoids
                    # shadowing the builtin `bin` as the original did.
                    audio = b"".join(self.tts_mdl.tts(sanitized_text))
                    # Hand the finished audio blob to the consumer side.
                    self.tts_audio_queue.put(audio)
            except Exception as e:
                logging.error("Error generating TTS for text '%s': %s", delta_ans, e)
class DialogService(CommonService):
@@ -65,22 +96,61 @@ class ConversationService(CommonService):
@classmethod
@DB.connection_context()
def get_list(cls, dialog_id, page_number, items_per_page, orderby, desc, id, name, cols=None):
    """Return ``(total, rows)`` for the conversations of one dialog.

    Args:
        dialog_id: dialog whose sessions are listed.
        page_number: 1-based page index.
        items_per_page: page size.
        orderby: model field name to sort by.
        desc: sort descending when truthy, ascending otherwise.
        id: optional exact-match filter on the session id.
        name: optional exact-match filter on the session name.
        cols: optional list of column names to project; all columns otherwise.

    Returns:
        tuple: (total_match_count, list of row dicts for the requested page).
    """
    # Base query: all sessions of this dialog.
    query = cls.model.select().where(cls.model.dialog_id == dialog_id)
    # Optional exact-match filters.
    if id:
        query = query.where(cls.model.id == id)
    if name:
        query = query.where(cls.model.name == name)
    # Optional column projection.
    if cols:
        query = query.select(*[getattr(cls.model, col) for col in cols])
    # Count before pagination so callers get the full match count.
    total = query.count()
    # Ordering, then pagination.
    if desc:
        query = query.order_by(cls.model.getter_by(orderby).desc())
    else:
        query = query.order_by(cls.model.getter_by(orderby).asc())
    data = list(query.paginate(page_number, items_per_page).dicts())
    return total, data
@classmethod
@DB.connection_context()
def query_sessions_summary(cls):
    """Summarize sessions: one row per (id, dialog_id, name) group.

    Each group carries the oldest ``create_time`` / ``create_date`` seen
    for that session; groups are ordered newest-first by that time.

    Returns:
        list[dict]: one dict per group.
    """
    # Aggregates: oldest timestamp/date within each group.
    oldest_time = fn.MIN(cls.model.create_time).alias("create_time")
    oldest_date = fn.MIN(cls.model.create_date).alias("create_date")
    summary = (
        cls.model
        .select(cls.model.id,
                cls.model.dialog_id,
                cls.model.name,
                oldest_time,
                oldest_date)
        .group_by(cls.model.id, cls.model.dialog_id, cls.model.name)
        .order_by(fn.MIN(cls.model.create_time).desc())
    )
    # Materialize as plain dictionaries for the caller.
    return list(summary.dicts())
def message_fit_in(msg, max_length=4000):
def count():
nonlocal msg
@@ -128,6 +198,42 @@ def llm_id2llm_type(llm_id):
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]
# Validate and sanitize text before it is sent to the TTS engine.
def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
    """Validate and clean a candidate text chunk for speech synthesis.

    Args:
        delta_ans (str): candidate text to speak.
        max_length (int): maximum accepted text length after cleaning.

    Returns:
        tuple: (is_valid, sanitized_text)
            - is_valid (bool): whether the text is usable for TTS.
            - sanitized_text (str): cleaned text ("" when invalid).
    """
    # 1. Reject non-string input outright.
    if not isinstance(delta_ans, str):
        return False, ""
    # 2. Trim surrounding whitespace; empty text cannot be spoken.
    delta_ans = delta_ans.strip()
    if not delta_ans:
        return False, ""
    # 3. Normalize the full-width question mark to its half-width form.
    #    NOTE(review): the original pattern here was an empty character
    #    class (r'[]'), which raises re.error at runtime — the full-width
    #    characters were evidently lost; restored with full-width '?'.
    #    TODO confirm the intended symbol set.
    delta_ans = re.sub(r'[?]', '?', delta_ans)
    # 4. Keep only CJK, ASCII letters/digits, whitespace and common punctuation.
    delta_ans = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s,.!?\'";;。,!?:”“()\-()]', '', delta_ans)
    # 5. Enforce length bounds after cleaning.
    if not delta_ans or len(delta_ans) > max_length:
        return False, ""
    # All checks passed: return the cleaned text.
    return True, delta_ans
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@@ -175,8 +281,10 @@ def chat(dialog, messages, stream=True, **kwargs):
prompt_config = dialog.prompt_config
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
tts_mdl = None
if prompt_config.get("tts"):
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,dialog.tts_id)
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
@@ -184,7 +292,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if ans:
yield ans
return
# logging.info(f"dialog_service--1 chat prompt_config{prompt_config['parameters']} {prompt_config}") # cyx
for p in prompt_config["parameters"]:
if p["key"] == "knowledge":
continue
@@ -223,6 +331,7 @@ def chat(dialog, messages, stream=True, **kwargs):
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_tm = timer()
if not knowledges and prompt_config.get("empty_response"):
@@ -245,7 +354,6 @@ def chat(dialog, messages, stream=True, **kwargs):
gen_conf["max_tokens"] = min(
gen_conf["max_tokens"],
max_tokens - used_token_count)
def decorate_answer(answer):
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
refs = []
@@ -281,22 +389,44 @@ def chat(dialog, messages, stream=True, **kwargs):
last_ans = ""
answer = ""
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 16:
continue
last_ans = answer
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
delta_ans = answer[len(last_ans):]
if delta_ans:
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
#if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
yield decorate_answer(answer)
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
logging.debug("User: {}|Assistant: {}".format(
msg[-1]["content"], answer))
res = decorate_answer(answer)
res["audio_binary"] = tts(tts_mdl, answer)
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
else:
res["audio_binary"] = tts(tts_mdl, answer)
yield res

View File

@@ -140,6 +140,7 @@ class TenantLLMService(CommonService):
if llm_type == LLMType.TTS:
if model_config["llm_factory"] not in TTSModel:
return
# 初始化 tts model cyx
return TTSModel[model_config["llm_factory"]](
model_config["api_key"],
model_config["llm_name"],
@@ -201,6 +202,8 @@ class LLMBundle(object):
assert self.mdl, "Can't find model for {}/{}/{}".format(
tenant_id, llm_type, llm_name)
self.max_length = 8192
if llm_type == LLMType.TTS:
logging.info(f"dale--TTS model {tenant_id} {llm_type} {llm_name}")
for lm in LLMService.query(llm_name=llm_name):
self.max_length = lm.max_tokens
break
@@ -245,7 +248,7 @@ class LLMBundle(object):
"LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
return txt
def tts(self, text):
def tts(self, text): # tts 调用 cyx
for chunk in self.mdl.tts(text):
if isinstance(chunk,int):
if not TenantLLMService.increase_usage(
@@ -253,7 +256,10 @@ class LLMBundle(object):
logging.error(
"LLMBundle.tts can't update token usage for {}/TTS".format(self.tenant_id))
return
yield chunk
yield chunk
def end_tts(self):  # end the streaming TTS call (cyx)
    """Terminate the streaming TTS session by delegating to the underlying model."""
    self.mdl.end_tts()
def chat(self, system, history, gen_conf):
txt, used_tokens = self.mdl.chat(system, history, gen_conf)