This commit is contained in:
qcloud
2025-05-15 15:26:06 +08:00
parent 330976812d
commit e29f79b9ac
2804 changed files with 1044973 additions and 83 deletions

View File

@@ -653,6 +653,7 @@ class TenantLLM(DataBaseModel):
api_base = CharField(max_length=255, null=True, help_text="API Base")
max_tokens = IntegerField(default=8192, index=True)
used_tokens = IntegerField(default=0, index=True)
description = CharField(max_length=255, null=True, help_text="描述")
def __str__(self):
return self.llm_name
@@ -1008,8 +1009,8 @@ class MesumOverview(DataBaseModel):
help_text="latitude",
index=False)
antique=CharField(
max_length=1024,
category=CharField(
max_length=2048,
null=True,
help_text="antique",
index=False)
@@ -1019,6 +1020,22 @@ class MesumOverview(DataBaseModel):
null=True,
help_text="brief",
index=False)
photo_url = CharField(
max_length=255,
null=True,
help_text="图片地址",
index=False)
address = CharField(
max_length=1024,
null=True,
help_text="地址",
index=False)
chat_id = CharField(
max_length=255,
null=False,
help_text="AI对话ID",
index=False)
def __str__(self):
return self.name
@@ -1032,16 +1049,26 @@ class MesumAntique(DataBaseModel):
description = TextField(null=True)
category = CharField(max_length=100, null=True)
group = CharField(max_length=100, null=True)
background = TextField(null=True)
value = TextField(null=True)
discovery = TextField(null=True)
id = AutoField(primary_key=True)
mesum_id = CharField(max_length=100, null=True)
combined = TextField(null=True)
ttsUrl_adult = CharField(max_length=256, null=True)
ttsUrl_child = CharField(max_length=256, null=True)
class Meta:
db_table = 'mesum_antique'
class AppInfo(DataBaseModel):
    """Registry of downloadable application packages (table `app_info`)."""
    id = AutoField(primary_key=True)  # surrogate auto-increment PK (the original table had no primary key)
    app_name = CharField(max_length=255, null=False)    # application name
    app_version = CharField(max_length=30, null=False)  # version string
    package_url = CharField(max_length=512, null=False) # download URL of the package
    app_type = CharField(max_length=30, null=True)      # optional platform/type tag
    upload_time = DateTimeField(null=True)              # when the package was uploaded
    description = TextField(null=True)  # TextField used instead of varchar(1024)

    class Meta:
        # Explicit database table name.
        # NOTE(review): sibling models use `db_table`; peewee 3.x expects
        # `table_name` — confirm the others still work as intended.
        table_name = 'app_info'
#-------------------------------------------
def migrate_db():
with DB.transaction():

View File

@@ -22,9 +22,10 @@ from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant, MesumAntique
from api.db.services.common_service import CommonService
from api.db.services.brief_service import MesumOverviewService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum
import logging
class MesumAntiqueService(CommonService):
model = MesumAntique
@@ -37,10 +38,39 @@ class MesumAntiqueService(CommonService):
@classmethod
@DB.connection_context()
def get_all_categories(cls, mesum_id):
    """Return the distinct antique categories of one museum, ordered as in its brief.

    The museum overview (brief) stores the intended category order as a
    semicolon-separated string; categories actually present on the museum's
    antiques are sorted to match that order, with unknown ones placed last.

    :param mesum_id: museum identifier used to filter both the brief and antiques
    :return: list of category strings, brief-ordered
    """
    # Preferred ordering from the museum brief's `category` field.
    brief_order = []
    mesum_brief = MesumOverviewService.query(id=mesum_id)
    if mesum_brief and mesum_brief[0].category:
        categories_text = mesum_brief[0].category
        # Normalize full-width (Chinese) semicolons to ASCII and drop a
        # trailing separator before splitting.
        categories_text = categories_text.replace(";", ";").rstrip(";")
        brief_order = [part.strip() for part in categories_text.split(";") if part.strip()]

    # Distinct non-empty categories actually used by this museum's antiques.
    categories = [row.category
                  for row in (cls.model.select(cls.model.category)
                              .where(cls.model.mesum_id == mesum_id)
                              .distinct()
                              .execute())
                  if row.category]

    # Map each brief category to its index for O(1) rank lookup; categories
    # missing from the brief get a rank past the end so they sort last.
    rank = {value: idx for idx, value in enumerate(brief_order)}
    return sorted(categories, key=lambda item: rank.get(item, len(brief_order)))
@classmethod
@DB.connection_context()
@@ -71,6 +101,24 @@ class MesumAntiqueService(CommonService):
return grouped_data
@classmethod
@DB.connection_context()
def get_labels_with_id(cls, mesum_id):
    """Return [{'id': ..., 'label': ...}, ...] for every antique of *mesum_id*, ordered by row id.

    NOTE(review): a previous comment claimed empty categories are excluded,
    but the query only filters on mesum_id — confirm the intended behavior.
    """
    # Select all rows for this museum in primary-key order.
    query = cls.model.select().where(
        (cls.model.mesum_id == mesum_id)
    ).order_by(cls.model.id)
    # Pair each label with its row id.
    # presumably the model defines a `label` field — not visible here; verify.
    labels_data = []
    for obj in query.dicts():
        labels_data.append({
            'id': obj['id'],
            'label': obj['label']
        })
    return labels_data
@classmethod
@DB.connection_context()
def get_antique_by_id(cls, mesum_id,antique_id):

View File

@@ -0,0 +1,15 @@
from datetime import datetime
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant, AppInfo
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum
class AppInfoService(CommonService):
    """CRUD service for AppInfo records; all behavior is inherited from CommonService."""
    model = AppInfo  # peewee model this service operates on

View File

@@ -29,3 +29,5 @@ from api.db import StatusEnum
class MesumOverviewService(CommonService):
    """CRUD service for MesumOverview (museum brief) records via CommonService."""
    model = MesumOverview  # peewee model this service operates on

View File

@@ -69,7 +69,8 @@ class StreamSessionManager:
'audio_chunk_count':0,
'finished': threading.Event(), # 添加事件对象
'sample_rate':sample_rate,
'stream_format':stream_format
'stream_format':stream_format,
"tts_chunk_data_valid":False
}
# 启动任务处理线程
threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
@@ -126,6 +127,7 @@ class StreamSessionManager:
if not session: return
# logging.info(f"_generate_audio:{text}")
first_chunk = True
# logging.info(f"转换开始!!! {text}")
try:
for chunk in session['tts_model'].tts(text,session['sample_rate'],session['stream_format']):
if session['stream_format'] == 'wav':
@@ -142,6 +144,8 @@ class StreamSessionManager:
session['buffer'].put(chunk)
session['last_active'] = time.time()
session['audio_chunk_count'] = session['audio_chunk_count'] + 1
if session['tts_chunk_data_valid'] is False:
session['tts_chunk_data_valid'] = True #20250510 增加表示连接TTS后台已经返回可以通知前端了
logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
except Exception as e:
session['buffer'].put(f"ERROR:{str(e)}")
@@ -161,6 +165,9 @@ class StreamSessionManager:
if session_id in self.sessions:
del self.sessions[session_id]
def get_session(self, session_id):
    # Look up an active TTS streaming session by id; returns None if unknown.
    return self.sessions.get(session_id)
stream_manager = StreamSessionManager()
class DialogService(CommonService):
@@ -296,6 +303,7 @@ def llm_id2llm_type(llm_id):
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]
followup_seperator = "继续追问:"
# cyx 2024 12 04
# 用于校验和修正语音合成的输入文本。该函数会去除非法字符、修正内容,并返回一个结果:包括是否有效和修正后的文本。
def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
@@ -329,9 +337,47 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
# 5. 检查长度
if len(delta_ans) == 0 or len(delta_ans) > max_length:
return False, ""
# """清理流式输出中可能存在的 jsonjsonjson 及之后的内容"""
"""
检查子串存在性:使用 in 关键字判断字符串 ans 是否包含子串 "jsonjsonjson"
分割字符串:若存在,使用 split('jsonjsonjson', 1) 分割一次1 确保只分割首个匹配项,避免后续重复子串的影响。
提取前部分:取分割后的第一个元素 [0],即目标子串前的内容。
处理不存在情况:根据需求返回空字符串或原字符串,示例中返回空字符串。
if followup_seperator in delta_ans:
# 分割字符串,取第一个部分
delta_ans = delta_ans.split(followup_seperator, 1)[0]
json_markdown_separator_found = True
"""
"""
# 方法:split 分割 假设```json 在文本的最后,去除```json 的内容
strings_split = delta_ans.split('```json', 1)
ans_remove_json = strings_split[0].rstrip()
if len(strings_split)>1 :
found_last_json_markdown = True
# 中间结果可能是
# json
# "questions"
# "通州起义的具体 思想基础。
pattern = r'^(.*?)json(?=.{0,10}"questions")' # 关键修改点:添加双引号
if match := re.search(pattern, ans_remove_json, flags=re.DOTALL):
ans_remove_json1 =match.group(1).rstrip()
found_last_json_markdown = True
else:
ans_remove_json1 = ans_remove_json
#logging.info(f"--dale---3:1-{delta_ans} 2-{ans_remove_json1}")
"""
# logging.info(f"--dale---3--:{delta_ans}")
# 如果通过所有检查,返回有效标志和修正后的文本
return True, delta_ans
def _should_flush(text_chunk,chunk_buffer,last_flush_time):
"""智能判断是否需要立即生成音频"""
# 规则1遇到句子结束标点
@@ -348,9 +394,158 @@ def _should_flush(text_chunk,chunk_buffer,last_flush_time):
return True
return False
def extract_and_parse_json(llm_ans):
    """Extract the first JSON payload from an LLM answer and parse it.

    Looks for a fenced ```json ... ``` block first, then falls back to the
    first {...} object anywhere in the text.

    :param llm_ans: raw answer text from the LLM
    :return: (parsed_data, error) — parsed object and None on success, or
             (None, error_message) on failure.
    """
    # Fenced code block with an explicit json tag.
    json_pattern = r'```json\n(.*?)\n```'
    match = re.search(json_pattern, llm_ans, re.DOTALL)
    group_index = 1
    if not match:
        # Fallback: a bare JSON object anywhere in the text.
        # BUG FIX: original searched an undefined name `response_text`.
        json_pattern_fallback = r'\{.*\}'
        match = re.search(json_pattern_fallback, llm_ans, re.DOTALL)
        # BUG FIX: the fallback pattern has no capture group, so the original
        # `match.group(1)` raised IndexError whenever the fallback matched.
        group_index = 0
        if not match:
            return None, "未检测到 JSON 数据"
    json_str = match.group(group_index)
    try:
        # Handle common formatting issues before parsing.
        json_str = json_str.strip()
        # BUG FIX: original called replace("", ",") (the full-width comma
        # literal was lost), which inserted commas between every character.
        json_str = json_str.replace(",", ",")  # normalize full-width (Chinese) commas
        json_str = re.sub(r'//.*?\n', '', json_str)  # strip // line comments
        parsed_data = json.loads(json_str)
        return parsed_data, None
    except json.JSONDecodeError as e:
        error_msg = f"JSON 解析失败:{str(e)}\n错误位置:第 {e.lineno}{e.colno}"
        return None, error_msg
    except Exception as e:
        return None, f"解析异常:{str(e)}"
import re
import json
def extract_clear_parse_json(llm_ans, clean_text=True):
    """Extract JSON block(s) from an LLM answer, parse the first valid one,
    and optionally return the text with all JSON blocks removed.

    :param llm_ans: raw answer text from the LLM
    :param clean_text: when True, delete every detected JSON block from the text
    :return: (parsed_data, cleaned_text, error) — parsed_data is None when no
             block parses; error is None on success.
    """
    parsed_data = None
    error = None
    matches = []
    cleaned_text = llm_ans

    # Fenced ```json ... ``` blocks (full span including the fences).
    json_marker_pattern = r'```json.*?```'
    for match in re.finditer(json_marker_pattern, cleaned_text, re.DOTALL):
        matches.append((match.start(), match.end(), match.group(0)))

    # Fallback: bare single-level {...} objects (pattern does not support nesting).
    if not matches:
        json_object_pattern = r'\{[^{}]*\}'
        for match in re.finditer(json_object_pattern, cleaned_text, re.DOTALL):
            matches.append((match.start(), match.end(), match.group(0)))

    # Delete matched blocks from the text, back-to-front so earlier offsets
    # stay valid after each removal.
    if clean_text:
        for start, end, _ in sorted(matches, key=lambda x: x[0], reverse=True):
            cleaned_text = cleaned_text[:start] + cleaned_text[end:]

    # Strip the ```json fences where present.
    json_blocks = []
    for _, _, block in matches:
        if block.startswith('```json'):
            pure_json = re.sub(r'^```json\s*|\s*```$', '', block, flags=re.DOTALL)
            json_blocks.append(pure_json.strip())
        else:
            json_blocks.append(block.strip())

    # Parse blocks in order; stop at the first success.
    for json_str in json_blocks:
        try:
            json_str = re.sub(r'//.*', '', json_str)  # drop // line comments
            # BUG FIX: original called replace("", ",") (the full-width comma
            # literal was lost), which inserted commas between every character
            # and made every parse fail.
            json_str = json_str.replace(",", ",")  # normalize full-width (Chinese) commas
            parsed_data = json.loads(json_str)
            error = None
            break
        except json.JSONDecodeError as e:
            error = f"JSON 解析失败:{e},位置:第 {e.lineno}"
        except Exception as e:
            error = f"解析异常:{str(e)}"

    if not json_blocks:
        error = "未检测到 JSON 数据"
    return parsed_data, cleaned_text.strip(), error
# cyx 20250510 增加 生成后续可以追问的内容输出为json格式
def generate_structured_followups(chat_mdl, answer, max_questions=5):
    """Ask the chat model for structured follow-up question suggestions.

    cyx 20250510: generates follow-up content for the narration as JSON.

    :param chat_mdl: chat model wrapper exposing chat(system=, history=, gen_conf=)
    :param answer: the narration text to derive follow-up questions from
    :param max_questions: number of follow-up questions requested
    :return: (json_data, status_message) — json_data is {} on any failure.
             BUG FIX: the original docstring promised a 3-tuple and one error
             path returned an undefined name `tokens` (NameError); all paths
             now return a consistent 2-tuple.
    """
    system_prompt = """你是一位精通数据结构的博物馆教育专家,请完成以下任务:
1. 根据讲解内容生成{max_questions}个追问问题格式为严格遵循的JSON
2. 为生成的JSON添加格式解释
3. JSON需包含问题分类和置信度
JSON格式要求
{{
    "questions": [
        {{
            "text": "问题文本",
            "type": "问题类型",
            "confidence": 置信度(0-1)
        }}
    ],
    "source_analysis": {{
        "main_topics": ["主要话题"],
        "missing_areas": ["未涉及领域"]
    }}
}}"""
    user_prompt = f"""原始讲解内容:
{answer}
请按以下步骤处理:
1. 分析内容的关键知识点
2. 生成{max_questions}个延伸问题
3. 生成JSON后添加格式解释
4. 用---分隔数据和解释"""
    gen_config = {
        "temperature": 0.5,
        "max_tokens": 800
    }
    try:
        ans = chat_mdl.chat(
            system=system_prompt.format(max_questions=max_questions),
            history=[{"role": "user", "content": user_prompt}],
            gen_conf=gen_config
        )
        # Separate the JSON payload from the trailing explanation text.
        json_data, error = extract_and_parse_json(ans)
        if json_data is not None:
            return json_data, "解析正确"
        return {}, "解析错误"
    except json.JSONDecodeError as e:
        logging.warning(f"JSON解析失败: {str(e)}")
        return {}, "格式解析错误"
    except Exception as e:
        logging.warning(f"生成失败: {str(e)}")
        return {}, "生成过程异常"
MAX_BUFFER_LEN = 200 # 最大缓冲长度
FLUSH_TIMEOUT = 0.5 # 强制刷新时间(秒)
FLUSH_TIMEOUT = 1 # 强制刷新时间(秒)
# 智能查找文本最佳分割点(标点/语义单位/短语边界)
def find_split_position(text):
@@ -387,7 +582,10 @@ def process_buffer(chunk_buffer, force_flush=False):
split_pos = min(split_pos, len(current_text))
if split_pos is not None and split_pos > 0:
return current_text[:split_pos], [current_text[split_pos:]]
to_tts_text = current_text[:split_pos]
remaining_text = [current_text[split_pos:]]
return to_tts_text,remaining_text
return None, chunk_buffer
@@ -424,7 +622,6 @@ def chat(dialog, messages, stream=True, **kwargs):
for m in messages[:-1]:
if "doc_ids" in m:
attachments.extend(m["doc_ids"])
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
if not embd_mdl:
raise LookupError("Embedding model(%s) not found" % embd_nms[0])
@@ -469,7 +666,6 @@ def chat(dialog, messages, stream=True, **kwargs):
questions = questions[-1:]
refineQ_tm = timer()
keyword_tm = timer()
rerank_mdl = None
if dialog.rerank_id:
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
@@ -492,7 +688,7 @@ def chat(dialog, messages, stream=True, **kwargs):
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
logging.debug( "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
# 打印历史记录
#logging.info( "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
logging.info( "dale-----!!!:{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_tm = timer()
if not knowledges and prompt_config.get("empty_response"):
@@ -547,15 +743,18 @@ def chat(dialog, messages, stream=True, **kwargs):
prompt += "\n\n### Elapsed\n - Refine Question: %.1f ms\n - Keywords: %.1f ms\n - Retrieval: %.1f ms\n - LLM: %.1f ms" % (
(refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000,
(done_tm - retrieval_tm) * 1000)
return {"answer": answer, "reference": refs, "prompt": prompt}
#return {"answer": answer, "prompt": prompt,"reference": refs }
# cyx 增加 20250510 生成后续追问的内容
# cyx 修改 20250422 不向前端发送prompt 和 refs ,增加发送 finished 标志
return {"answer": answer, "finished":True,"reference":""}
if stream:
last_ans = ""
answer = ""
# 创建TTS会话提前初始化
tts_session_id = stream_manager.create_session(tts_mdl,sample_rate=tts_sample_rate,stream_format=tts_stream_format)
tts_session = stream_manager.get_session(tts_session_id)
audio_url = f"/tts_stream/{tts_session_id}"
first_chunk = True
send_tts_url = False
chunk_buffer = [] # 新增文本缓冲
last_flush_time = time.time() # 初始化时间戳
# 下面优先处理知识库中没有找到相关内容 cyx 20250323 修改
@@ -583,7 +782,7 @@ def chat(dialog, messages, stream=True, **kwargs):
# cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
if kwargs.get('tts_disable'):
tts_input_is_valid =False
if tts_input_is_valid:
if tts_input_is_valid :
# 缓冲文本直到遇到标点
chunk_buffer.append(sanitized_text)
# 处理缓冲区内容
@@ -605,7 +804,7 @@ def chat(dialog, messages, stream=True, **kwargs):
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
# 首块返回音频URL
if first_chunk:
if send_tts_url is False and tts_session['tts_chunk_data_valid'] is True:
yield {
"answer": answer,
"delta_ans": sanitized_text,
@@ -615,10 +814,11 @@ def chat(dialog, messages, stream=True, **kwargs):
"sample_rate":tts_sample_rate,
"stream_format":tts_stream_format,
}
first_chunk = False
send_tts_url = True # 发送一次tts url 给前端即可,不能重复发送
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
delta_ans = answer[len(last_ans):]
if delta_ans:
# stream_manager.append_text(tts_session_id, delta_ans)
@@ -627,10 +827,12 @@ def chat(dialog, messages, stream=True, **kwargs):
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
if tts_input_is_valid:
if tts_input_is_valid :
# 20250221 修改,在后端生成音频数据
chunk_buffer.append(sanitized_text)
stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
if to_send:
stream_manager.append_text(tts_session_id, to_send)
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
if tts_input_is_valid:

View File

@@ -54,7 +54,8 @@ class TenantLLMService(CommonService):
LLMFactories.tags,
cls.model.model_type,
cls.model.llm_name,
cls.model.used_tokens
cls.model.used_tokens,
cls.model.description # added by cyx 20250415
]
objs = cls.model.select(*fields).join(LLMFactories, on=(cls.model.llm_factory == LLMFactories.name)).where(
cls.model.tenant_id == tenant_id, ~cls.model.api_key.is_null()).dicts()