#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import binascii
import os
import json
import re
import threading
import queue
import uuid
import time
import array
import datetime
from copy import deepcopy
from datetime import timedelta
from timeit import default_timer as timer
from concurrent.futures import ThreadPoolExecutor

from peewee import fn

from api.db import LLMType, ParserType, StatusEnum
from api.db.db_models import Dialog, Conversation, DB
from api.db.services.common_service import CommonService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMService, TenantLLMService, LLMBundle
from api.db.services.ali_tts_service import stream_manager_w_stream as stream_manager
from api import settings
from api.utils.file_utils import get_project_base_directory
from rag.app.resume import forbidden_select_fields4resume
from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder


def audio_fade_in(audio_data, fade_length):
    # Assumes the audio data is 16-bit mono PCM.
    # Convert the binary data into an array of signed 16-bit integers.
    samples = array.array('h', audio_data)

    # Apply a linear fade-in to the first fade_length samples.
    for i in range(fade_length):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)

    # Convert the integer array back to binary data.
    return samples.tobytes()


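# A minimal sketch of how audio_fade_in behaves (illustration only, never
# called by this module): a constant-amplitude 16-bit PCM buffer is ramped up
# from silence over the first `fade_length` samples.
def _audio_fade_in_example():
    raw = array.array('h', [1000] * 8).tobytes()  # 8 samples of 16-bit PCM
    faded = array.array('h', audio_fade_in(raw, 4))
    # The first 4 samples ramp up (0, 250, 500, 750); the rest stay at 1000.
    assert list(faded) == [0, 250, 500, 750, 1000, 1000, 1000, 1000]

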
class StreamSessionManager:
    def __init__(self):
        self.sessions = {}  # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
        self.gc_interval = 300  # garbage-collect idle sessions every 5 minutes (300 s)
        self.gc_tts = 10  # the LLM may take a while to emit its first text; raised from 3 s to 10 s on 2025-05-24

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # signalled when the session is done
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                "sentence_complete_event": threading.Event(),
                "current_processing": False  # whether a sentence is currently being processed
            }
        # Start the per-session task-processing thread.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            # Put the text on the task queue (non-blocking).
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Per-session task-processing thread."""
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge several text chunks (each get waits up to 100 ms).
                texts = []
                while len(texts) < 5:  # merge at most 5 text chunks
                    try:
                        text = session['task_queue'].get(timeout=0.1)
                        texts.append(text)
                    except queue.Empty:
                        break

                if texts:
                    session['last_active'] = time.time()  # reset the idle timer when there is text to process
                    # Submit to the thread pool.
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ' '.join(texts)  # merge the texts to reduce the number of requests
                    )
                    future.result()  # wait for the synthesis task to finish
                    session['last_active'] = time.time()
                # Session timeout checks.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break

            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio1(self, session_id, text):
        """Actually generate audio (runs in the thread pool) - streaming engines."""
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        logging.info(f"TTS conversion started: {text}")
        try:
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav':
                    if first_chunk:
                        chunk_len = len(chunk)
                        if chunk_len > 2048:
                            session['buffer'].put(audio_fade_in(chunk, 1024))
                        else:
                            session['buffer'].put(audio_fade_in(chunk, chunk_len))
                        first_chunk = False
                    else:
                        session['buffer'].put(chunk)
                else:
                    session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # Added 2025-05-10: the TTS backend has started returning
                    # data, so the frontend can now be notified.
                    session['tts_chunk_data_valid'] = True
            logging.info(f"TTS conversion finished: {session['audio_chunk_count']} chunks")
        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def _generate_audio(self, session_id, text):
        """Actually generate audio (sequential) - used for non-streaming engines."""
        session = self.sessions.get(session_id)
        if not session:
            return

        try:
            # Invoke TTS.
            session['tts_model'].text_tts_call(text)

            # Mark the sentence as done.
            session['sentence_complete_event'].set()
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1

            if not session['tts_chunk_data_valid']:
                session['tts_chunk_data_valid'] = True

        except Exception as e:
            session['buffer'].put(f"ERROR:{str(e)}".encode())
            session['sentence_complete_event'].set()  # make sure the event is always set

    def close_session(self, session_id):
        with self.lock:
            if session_id in self.sessions:
                # Mark the session inactive.
                self.sessions[session_id]['active'] = False
                # Clean up its resources after a 1-second delay.
                threading.Timer(1, self._clean_session, args=[session_id]).start()

    def _clean_session(self, session_id):
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        return self.sessions.get(session_id)


stream_manager_bk = StreamSessionManager()


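# A minimal sketch (hypothetical; nothing in this module calls it) of how the
# backup StreamSessionManager is meant to be driven: wrap a TTS engine, feed
# it text, and wait for the sentence-complete event. `DummyTTS` is an
# assumption standing in for a real non-streaming engine exposing
# text_tts_call(), which is what _generate_audio invokes.
def _stream_session_example():
    class DummyTTS:
        def text_tts_call(self, text):
            pass  # a real engine would synthesize and emit audio here

    sid = stream_manager_bk.create_session(DummyTTS())
    stream_manager_bk.append_text(sid, "Hello, world.")
    session = stream_manager_bk.get_session(sid)
    session['sentence_complete_event'].wait(timeout=5)  # set once the sentence is synthesized
    stream_manager_bk.close_session(sid)

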
class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        chats = cls.model.select()
        if id:
            chats = chats.where(cls.model.id == id)
        if name:
            chats = chats.where(cls.model.name == name)
        chats = chats.where(
            (cls.model.tenant_id == tenant_id)
            & (cls.model.status == StatusEnum.VALID.value)
        )
        if desc:
            chats = chats.order_by(cls.model.getter_by(orderby).desc())
        else:
            chats = chats.order_by(cls.model.getter_by(orderby).asc())

        chats = chats.paginate(page_number, items_per_page)

        return list(chats.dicts())


class ConversationService(CommonService):
    model = Conversation

    @classmethod
    @DB.connection_context()
    def get_list(cls, dialog_id, page_number, items_per_page, orderby, desc, id, name, cols=None):
        # Build the base query.
        logging.debug("--ConversationService get_list enter %s %s", page_number, items_per_page)  # cyx
        query = cls.model.select().where(cls.model.dialog_id == dialog_id)

        # Filter by ID if one was given.
        if id:
            query = query.where(cls.model.id == id)

        # Filter by name if one was given.
        if name:
            query = query.where(cls.model.name == name)

        # Restrict to the requested columns if a column list was given.
        if cols:
            query = query.select(*[getattr(cls.model, col) for col in cols])

        # Total number of matching records.
        total = query.count()
        # Apply ordering.
        if desc:
            query = query.order_by(cls.model.getter_by(orderby).desc())
        else:
            query = query.order_by(cls.model.getter_by(orderby).asc())

        # Run the paginated query.
        paginated_query = query.paginate(page_number, items_per_page)
        data = list(paginated_query.dicts())
        # Return the page of data together with the total record count.
        return total, data

    @classmethod
    @DB.connection_context()
    def query_sessions_summary(cls):
        # Group by id and report the oldest record for each id.
        query = (
            cls.model
            .select(
                cls.model.id,
                cls.model.dialog_id,
                cls.model.name,
                fn.MIN(cls.model.create_time).alias("create_time"),
                fn.MIN(cls.model.create_date).alias("create_date")
            )
            .group_by(cls.model.id, cls.model.dialog_id, cls.model.name)
            .order_by(
                fn.MIN(cls.model.create_time).desc(),
            )
        )
        # Return as a list of dicts.
        return list(query.dicts())


def message_fit_in(msg, max_length=4000):
    def count():
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append(
                {"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    if c < max_length:
        return c, msg

    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    ll = num_tokens_from_string(msg_[0]["content"])
    l = num_tokens_from_string(msg_[-1]["content"])
    if ll / (ll + l) > 0.8:
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[:max_length - l])
        msg[0]["content"] = m
        return max_length, msg

    m = msg_[1]["content"]
    m = encoder.decode(encoder.encode(m)[:max_length - l])
    msg[1]["content"] = m
    return max_length, msg


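# A small illustration (not used elsewhere) of message_fit_in: a conversation
# that fits the token budget is returned unchanged; an over-long one is
# reduced to the system prompt plus the latest message, with the longer of the
# two truncated to fit.
def _message_fit_in_example():
    msg = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "first question"},
        {"role": "assistant", "content": "first answer"},
        {"role": "user", "content": "latest question"},
    ]
    used, fitted = message_fit_in(msg, max_length=4000)
    # Within budget: the history is returned unchanged.
    assert fitted[-1]["content"] == "latest question"
    return used, fitted

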
def llm_id2llm_type(llm_id):
    llm_id = llm_id.split("@")[0]
    fnm = os.path.join(get_project_base_directory(), "conf")
    with open(os.path.join(fnm, "llm_factories.json"), "r") as f:
        llm_factories = json.load(f)
    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # model_type may be a comma-separated list; take the last entry.
                return llm["model_type"].split(",")[-1].strip()


followup_seperator = "继续追问:"


# cyx 2024-12-04
# Validate and sanitize input text for speech synthesis: strip illegal
# characters, repair the content, and return whether the result is valid
# together with the sanitized text.
def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
    """
    Validate and sanitize input text for speech synthesis.

    Args:
        delta_ans (str): The text to validate.
        max_length (int): Maximum allowed text length.

    Returns:
        tuple: (is_valid, sanitized_text)
            - is_valid (bool): Whether the text is valid.
            - sanitized_text (str): The sanitized text (empty string if invalid).
    """
    # 1. Make sure the input is a string.
    if not isinstance(delta_ans, str):
        return False, ""

    # 2. Strip surrounding whitespace and check for emptiness.
    delta_ans = delta_ans.strip()
    if len(delta_ans) == 0:
        return False, ""

    # 3. Replace full-width question marks with half-width ones.
    delta_ans = re.sub(r'[?]', '?', delta_ans)

    # 4. Remove illegal characters (keep only Chinese, English, digits and common punctuation).
    delta_ans = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s,.!?\'";;。,!?:”“()\-()]', '', delta_ans)

    # 5. Check the length.
    if len(delta_ans) == 0 or len(delta_ans) > max_length:
        return False, ""

    # Disabled: strip a trailing "继续追问:" follow-up block from the streamed
    # output by splitting once on the separator and keeping the first part.
    # if followup_seperator in delta_ans:
    #     delta_ans = delta_ans.split(followup_seperator, 1)[0]
    #     json_markdown_separator_found = True

    # Disabled: assuming a trailing ```json block, split it off and drop it.
    # The partial stream may look like:
    #     json
    #     "questions"
    #     "通州起义的具体 思想基础。
    # strings_split = delta_ans.split('```json', 1)
    # ans_remove_json = strings_split[0].rstrip()
    # if len(strings_split) > 1:
    #     found_last_json_markdown = True
    # pattern = r'^(.*?)json(?=.{0,10}"questions")'  # key change: require the double quote
    # if match := re.search(pattern, ans_remove_json, flags=re.DOTALL):
    #     ans_remove_json1 = match.group(1).rstrip()
    #     found_last_json_markdown = True
    # else:
    #     ans_remove_json1 = ans_remove_json

    # All checks passed: return the valid flag and the sanitized text.
    return True, delta_ans


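# A quick illustration (never called) of the sanitizer's behaviour: emoji and
# other disallowed symbols are stripped, a full-width '?' is narrowed, and
# empty or non-string input is rejected.
def _sanitize_example():
    assert validate_and_sanitize_tts_input("你好?😀") == (True, "你好?")
    assert validate_and_sanitize_tts_input("   ") == (False, "")
    assert validate_and_sanitize_tts_input(None) == (False, "")

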
def _should_flush(text_chunk, chunk_buffer, last_flush_time):
    """Heuristically decide whether audio should be generated right away."""
    # Rule 1: the chunk ends with sentence-ending punctuation (or a full-width comma).
    if re.search(r'[。!?,]$', text_chunk):
        return True

    # Don't flush in the middle of a date such as "2024年" - keep merging.
    if re.search(r'(\d{4})(年|月|日|,)', text_chunk):
        return False
    # Rule 2: the buffer has reached the maximum length (200 characters).
    if sum(len(c) for c in chunk_buffer) >= 200:
        return True
    # Rule 3: more than 500 ms since the last flush.
    if time.time() - last_flush_time > 0.5:
        return True
    return False


def extract_and_parse_json(llm_ans):
    # Match a fenced JSON code block.
    json_pattern = r'```json\n(.*?)\n```'
    match = re.search(json_pattern, llm_ans, re.DOTALL)
    if not match:
        # Fall back to matching a bare JSON object.
        json_pattern_fallback = r'\{.*\}'
        match = re.search(json_pattern_fallback, llm_ans, re.DOTALL)
        if not match:
            return None, "未检测到 JSON 数据"

    # The fenced pattern captures group 1; the fallback pattern has no capture group.
    json_str = match.group(1) if match.lastindex else match.group(0)

    try:
        # Clean up common formatting problems.
        json_str = json_str.strip()
        json_str = json_str.replace(",", ",")  # replace full-width commas
        json_str = re.sub(r'//.*?\n', '', json_str)  # drop line comments

        # Parse the JSON.
        parsed_data = json.loads(json_str)
        return parsed_data, None

    except json.JSONDecodeError as e:
        error_msg = f"JSON 解析失败:{str(e)}\n错误位置:第 {e.lineno} 行 {e.colno} 列"
        return None, error_msg
    except Exception as e:
        return None, f"解析异常:{str(e)}"


def extract_clear_parse_json(llm_ans, clean_text=True):
    parsed_data = None
    error = None
    matches = []
    cleaned_text = llm_ans

    # Match complete fenced JSON blocks (including their content).
    json_marker_pattern = r'```json.*?```'
    for match in re.finditer(json_marker_pattern, cleaned_text, re.DOTALL):
        start, end = match.start(), match.end()
        json_block = match.group(0)  # the whole ```json...``` block
        matches.append((start, end, json_block))

    # If no fenced block was found, fall back to bare JSON objects.
    if not matches:
        json_object_pattern = r'\{[^{}]*\}'
        for match in re.finditer(json_object_pattern, cleaned_text, re.DOTALL):
            start, end = match.start(), match.end()
            json_block = match.group(0)
            matches.append((start, end, json_block))

    # Clean the text (remove every JSON block).
    if clean_text:
        # Delete back to front to avoid shifting offsets.
        for start, end, _ in sorted(matches, key=lambda x: x[0], reverse=True):
            cleaned_text = cleaned_text[:start] + cleaned_text[end:]

    # Extract the JSON payloads.
    json_blocks = []
    for _, _, block in matches:
        # For fenced blocks, strip the ```json and ``` markers.
        if block.startswith('```json'):
            pure_json = re.sub(r'^```json\s*|\s*```$', '', block, flags=re.DOTALL)
            json_blocks.append(pure_json.strip())
        else:
            json_blocks.append(block.strip())

    # Try to parse each JSON block.
    for json_str in json_blocks:
        try:
            json_str = re.sub(r'//.*', '', json_str)  # drop inline comments
            json_str = json_str.replace(",", ",")  # handle full-width commas
            parsed_data = json.loads(json_str)
            error = None
            break  # stop at the first successful parse
        except json.JSONDecodeError as e:
            error = f"JSON 解析失败:{e},位置:第 {e.lineno} 行"
        except Exception as e:
            error = f"解析异常:{str(e)}"

    if not json_blocks:
        error = "未检测到 JSON 数据"

    return parsed_data, cleaned_text.strip(), error


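# A small illustration (not used elsewhere) of extract_clear_parse_json: the
# fenced JSON block is parsed, and removed from the surrounding prose.
def _extract_clear_parse_json_example():
    ans = '这是讲解。```json\n{"questions": []}\n``` 以上。'
    data, cleaned, err = extract_clear_parse_json(ans)
    assert data == {"questions": []}
    assert '```' not in cleaned
    assert err is None

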
# cyx 2025-05-10: generate follow-up questions for the answer, output as JSON.
def generate_structured_followups(chat_mdl, answer, max_questions=5):
    """
    Generate structured follow-up suggestions (with a JSON format explanation).
    :return: (JSON data, status message)
    """
    system_prompt = """你是一位精通数据结构的博物馆教育专家,请完成以下任务:
1. 根据讲解内容生成{max_questions}个追问问题,格式为严格遵循的JSON
2. 为生成的JSON添加格式解释
3. JSON需包含问题分类和置信度

JSON格式要求:
{{
    "questions": [
        {{
            "text": "问题文本",
            "type": "问题类型",
            "confidence": 置信度(0-1)
        }}
    ],
    "source_analysis": {{
        "main_topics": ["主要话题"],
        "missing_areas": ["未涉及领域"]
    }}
}}"""

    user_prompt = f"""原始讲解内容:
{answer}

请按以下步骤处理:
1. 分析内容的关键知识点
2. 生成{max_questions}个延伸问题
3. 生成JSON后添加格式解释
4. 用---分隔数据和解释"""

    gen_config = {
        "temperature": 0.5,
        "max_tokens": 800
    }

    try:
        ans = chat_mdl.chat(
            system=system_prompt.format(max_questions=max_questions),
            history=[{"role": "user", "content": user_prompt}],
            gen_conf=gen_config
        )
        # Separate the JSON from the explanation.
        json_data, error = extract_and_parse_json(ans)
        if json_data is not None:
            return json_data, "解析正确"
        else:
            return {}, "解析错误"
    except json.JSONDecodeError as e:
        logging.error(f"JSON解析失败: {str(e)}")
        return {}, "格式解析错误"
    except Exception as e:
        logging.error(f"生成失败: {str(e)}")
        return {}, "生成过程异常"


MAX_BUFFER_LEN = 200  # maximum buffer length
FLUSH_TIMEOUT = 0.5  # forced-flush interval (seconds)


# Find the best split point in the text (punctuation / semantic unit / phrase boundary).
def find_split_position(text):
    """Find the best split position."""
    # Prefer sentence-ending punctuation.
    sentence_end = list(re.finditer(r'[。!?]', text))
    if sentence_end:
        return sentence_end[-1].end()

    # Next, look for natural pause marks.
    pause_mark = list(re.finditer(r'[,;、]', text))
    if pause_mark:
        return pause_mark[-1].end()

    # Avoid cutting through a date or numeric phrase.
    date_pattern = re.search(r'\d+(年|月|日)(?!\d)', text)
    if date_pattern:
        return date_pattern.end()

    return None


# Manage the text buffer: split off a semantically complete piece according to
# the rules above and return it for processing.
def process_buffer(chunk_buffer, force_flush=False):
    """Process the text buffer; return the text to send and the remaining buffer."""
    current_text = "".join(chunk_buffer)
    if not current_text:
        return "", []
    split_pos = find_split_position(current_text)
    # Forced-flush logic.
    if force_flush or len(current_text) >= MAX_BUFFER_LEN:
        # Even when forced, try to find a reasonable split point.
        if split_pos is None or split_pos < len(current_text) // 2:
            split_pos = max(split_pos or 0, MAX_BUFFER_LEN)
        split_pos = min(split_pos, len(current_text))

    if split_pos is not None and split_pos > 0:
        to_tts_text = current_text[:split_pos]
        remaining_text = [current_text[split_pos:]]
        return to_tts_text, remaining_text

    return None, chunk_buffer


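# A short illustration (not called in this module) of the buffering rules:
# text accumulates until a sentence boundary appears, then the complete part
# is split off and the tail is carried over to the next round.
def _process_buffer_example():
    buf = ["公元1937年,", "卢沟桥事变爆发。其后"]
    to_send, remaining = process_buffer(buf)
    assert to_send == "公元1937年,卢沟桥事变爆发。"
    assert remaining == ["其后"]

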
def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
    st = timer()
    tmp = dialog.llm_id.split("@")
    fid = None
    llm_id = tmp[0]
    if len(tmp) > 1:
        fid = tmp[1]
    llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
    if not llm:
        llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
            TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id, llm_factory=fid)
        if not llm:
            raise LookupError("LLM(%s) not found" % dialog.llm_id)
        max_tokens = 8192
    else:
        max_tokens = llm[0].max_tokens
    kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
    embd_nms = list(set([kb.embd_id for kb in kbs]))
    if len(embd_nms) != 1:
        yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
        return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}

    is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
    retr = settings.retrievaler if not is_kg else settings.kg_retrievaler

    questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
    attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
    if "doc_ids" in messages[-1]:
        attachments = messages[-1]["doc_ids"]
        for m in messages[:-1]:
            if "doc_ids" in m:
                attachments.extend(m["doc_ids"])
    embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
    if not embd_mdl:
        raise LookupError("Embedding model(%s) not found" % embd_nms[0])

    if llm_id2llm_type(dialog.llm_id) == "image2text":
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
    else:
        chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)

    prompt_config = dialog.prompt_config
    field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
    tts_mdl = None

    if prompt_config.get("tts"):
        if kwargs.get('tts_model'):
            tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, kwargs.get('tts_model'))
        else:
            tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id)

    if not kwargs.get("voice"):  # cyx 2025-10-07: no `voice` parameter means no TTS is needed
        kwargs['tts_disable'] = True

    tts_sample_rate = kwargs.get("tts_sample_rate", 8000)  # default 8 kHz
    tts_stream_format = kwargs.get("tts_stream_format", "mp3")  # default mp3
    # try to use sql if field mapping is good to go
    if field_map:
        logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
        if ans:
            yield ans
            return
    for p in prompt_config["parameters"]:
        if p["key"] == "knowledge":
            continue
        if p["key"] not in kwargs and not p["optional"]:
            raise KeyError("Miss parameter: " + p["key"])
        if p["key"] not in kwargs:
            prompt_config["system"] = prompt_config["system"].replace(
                "{%s}" % p["key"], " ")

    if len(questions) > 1 and prompt_config.get("refine_multiturn"):
        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
    else:
        questions = questions[-1:]
    refineQ_tm = timer()
    keyword_tm = timer()
    rerank_mdl = None
    if dialog.rerank_id:
        rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)

    for _ in range(len(questions) // 2):
        questions.append(questions[-1])
    if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
        kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
    else:
        if prompt_config.get("keyword", False):
            questions[-1] += keyword_extraction(chat_mdl, questions[-1])
            keyword_tm = timer()

        tenant_ids = list(set([kb.tenant_id for kb in kbs]))
        kbinfos = retr.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
                                 dialog.similarity_threshold,
                                 dialog.vector_similarity_weight,
                                 doc_ids=attachments,
                                 top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
    logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
    retrieval_tm = timer()

    if not knowledges and prompt_config.get("empty_response"):
        empty_res = prompt_config["empty_response"]
        yield {"answer": empty_res, "reference": kbinfos, "audio_binary":
               tts(tts_mdl, empty_res, sample_rate=tts_sample_rate, stream_format=tts_stream_format)}
        return {"answer": prompt_config["empty_response"], "reference": kbinfos}

    kwargs["knowledge"] = "\n\n------\n\n".join(knowledges)
    gen_conf = dialog.llm_setting
    msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]

    msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                for m in messages if m["role"] != "system"])
    used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
    assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
    prompt = msg[0]["content"]
    prompt += "\n\n### Query:\n%s" % " ".join(questions)
    if "max_tokens" in gen_conf:
        gen_conf["max_tokens"] = min(
            gen_conf["max_tokens"],
            max_tokens - used_token_count)

    def decorate_answer(answer):
        nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_tm
        refs = []
        if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
            answer, idx = retr.insert_citations(answer,
                                                [ck["content_ltks"]
                                                 for ck in kbinfos["chunks"]],
                                                [ck["vector"]
                                                 for ck in kbinfos["chunks"]],
                                                embd_mdl,
                                                tkweight=1 - dialog.vector_similarity_weight,
                                                vtweight=dialog.vector_similarity_weight)
            # insert_citations sometimes leaves markers such as ##0$$ / ##1$$
            # in the answer; strip them along with surplus whitespace. (cyx 2025-04-07)
            answer = re.sub(r'##\d+\$\$', '', answer).strip()
            idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
            recall_docs = [
                d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
            if not recall_docs:
                recall_docs = kbinfos["doc_aggs"]
            kbinfos["doc_aggs"] = recall_docs

            refs = deepcopy(kbinfos)
            for c in refs["chunks"]:
                if c.get("vector"):
                    del c["vector"]
        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        done_tm = timer()
        prompt += "\n\n### Elapsed\n  - Refine Question: %.1f ms\n  - Keywords: %.1f ms\n  - Retrieval: %.1f ms\n  - LLM: %.1f ms" % (
            (refineQ_tm - st) * 1000, (keyword_tm - refineQ_tm) * 1000, (retrieval_tm - keyword_tm) * 1000,
            (done_tm - retrieval_tm) * 1000)
        # return {"answer": answer, "prompt": prompt, "reference": refs}
        # cyx 2025-04-22: don't send prompt or refs to the frontend; add a
        # `finished` flag instead.
        return {"answer": answer, "finished": True, "reference": ""}

    if stream:
        last_ans = ""
        answer = ""
        audio_url = None
        tts_session_id = None
        if not kwargs.get('tts_disable'):
            # Create the TTS session up front.
            tts_session_id = stream_manager.create_session(tts_mdl, sample_rate=tts_sample_rate,
                                                           stream_format=tts_stream_format,
                                                           voice=kwargs.get('tts_model'))
            tts_session = stream_manager.get_session(tts_session_id)
            audio_url = f"/tts_stream/{tts_session_id}"
        send_tts_url = False
        chunk_buffer = []  # text buffer
        last_flush_time = time.time()  # initialise the flush timestamp
        # Handle "nothing relevant found in the knowledge base" first. (cyx 2025-03-23)
        if not kwargs["knowledge"] or len(kwargs["knowledge"]) < 4:
            if not kwargs.get('tts_disable'):
                stream_manager.append_text(tts_session_id, "未找到相关内容")
            yield {
                "answer": "未找到相关内容",
                "delta_ans": "未找到相关内容",
                "session_id": tts_session_id,
                "reference": {},
                "audio_stream_url": audio_url,
                "sample_rate": tts_sample_rate,
                "stream_format": tts_stream_format,
            }
        else:
            for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
                answer = ans
                delta_ans = ans[len(last_ans):]
                if num_tokens_from_string(delta_ans) < 24:
                    continue
                last_ans = answer
                # cyx 2024-12-04: an empty delta_ans used to break the TTS call,
                # so validate and sanitize it first.
                tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
                # cyx 2025-01-18: if the frontend passed tts_disable, don't
                # generate TTS audio (i.e. no audio_binary).
                if kwargs.get('tts_disable'):
                    tts_input_is_valid = False
                if tts_input_is_valid:
                    # Buffer the text until punctuation is reached.
                    chunk_buffer.append(sanitized_text)
                    # Drain the buffer.
                    while True:
                        # Decide whether a forced flush is due.
                        force = time.time() - last_flush_time > FLUSH_TIMEOUT
                        to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
                        if not to_send:
                            break

                        # Send the complete part off for synthesis.
                        stream_manager.append_text(tts_session_id, to_send)
                        chunk_buffer = remaining
                        last_flush_time = time.time()
                # Disabled older variant: attach per-delta audio_binary instead
                # of a streaming URL.
                # if tts_input_is_valid:
                #     yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
                # else:
                #     yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
                # The first chunk also carries the audio URL.
                if not send_tts_url and not kwargs.get('tts_disable'):
                    if tts_session['tts_chunk_data_valid'] is True:
                        yield {
                            "answer": answer,
                            "delta_ans": sanitized_text,
                            "session_id": tts_session_id,
                            "reference": {},
                            "audio_stream_url": audio_url,
                            "sample_rate": tts_sample_rate,
                            "stream_format": tts_stream_format,
                        }
                        send_tts_url = True  # send the TTS URL to the frontend only once
                        logging.info(f"--chat return tts url {audio_url}")
                else:
                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}

        delta_ans = answer[len(last_ans):]
        if delta_ans:
            # cyx 2024-12-04: an empty delta_ans used to break the TTS call.
            tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
            if kwargs.get('tts_disable'):  # cyx 2025-01-18: tts_disable means no audio_binary
                tts_input_is_valid = False
            if tts_input_is_valid:
                # cyx 2025-02-21: generate the audio on the backend. This is
                # the final delta, so force a flush of whatever is buffered.
                chunk_buffer.append(sanitized_text)
                to_send, remaining = process_buffer(chunk_buffer, force_flush=True)
                if to_send:
                    stream_manager.append_text(tts_session_id, to_send)
            yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
        yield decorate_answer(answer)

    else:
        answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
        logging.debug("User: {}|Assistant: {}".format(
            msg[-1]["content"], answer))
        res = decorate_answer(answer)
        # cyx 2025-01-18: if the frontend passed tts_disable, don't attach audio_binary.
        if not kwargs.get('tts_disable'):
            res["audio_binary"] = tts(tts_mdl, answer, tts_sample_rate, tts_stream_format)
        yield res


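# A hedged sketch (illustration only; the `dialog` object and the messages
# shape are assumed to match what the API layer passes in) of how the chat()
# generator is consumed: stream the deltas, and treat the item carrying
# finished=True as the end of the answer.
def _chat_consumer_example(dialog, messages):
    full_answer = ""
    for item in chat(dialog, messages, stream=True, tts_disable=True):
        # item["delta_ans"] (when present) carries only the newly generated text.
        if item.get("finished"):
            full_answer = item["answer"]
            break
    return full_answer

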
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    sys_prompt = "你是一个DBA。你需要针对以下表的字段结构,根据用户的问题列表,写出最后一个问题对应的SQL。"
    user_prompt = """
表名:{};
数据库表字段说明如下:
{}

问题如下:
{}
请写出SQL, 且只要SQL,不要有其他说明及文字。
""".format(
        index_name(tenant_id),
        "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
        question
    )
    tried_times = 0

    def get_table():
        nonlocal sys_prompt, user_prompt, question, tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
            "temperature": 0.06})
        logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        sql = re.sub(r"([;;]|```).*", "", sql)
        if sql[:len("select ")] != "select ":
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            if sql[:len("select *")] != "select *":
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_prompt = """
表名:{};
数据库表字段说明如下:
{}

问题如下:
{}

你上一次给出的错误SQL如下:
{}

后台报错如下:
{}

请纠正SQL中的错误再写一遍,且只要SQL,不要有其他说明及文字。
""".format(
            index_name(tenant_id),
            "\n".join([f"{k}: {v}" for k, v in field_map.items()]),
            question, sql, tbl["error"]
        )
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    docid_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "doc_id"])
    docnm_idx = set([ii for ii, c in enumerate(
        tbl["columns"]) if c["name"] == "docnm_kwd"])
    clmn_idx = [ii for ii in range(
        len(tbl["columns"])) if ii not in (docid_idx | docnm_idx)]

    # compose markdown table
    clmns = "|" + "|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"],
                                                                          tbl["columns"][i]["name"])) for i in
                            clmn_idx]) + ("|Source|" if docid_idx and docnm_idx else "|")

    line = "|" + "|".join(["------" for _ in range(len(clmn_idx))]) + \
           ("|------|" if docid_idx and docnm_idx else "")

    rows = ["|" +
            "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") +
            "|" for r in tbl["rows"]]
    rows = [r for r in rows if re.sub(r"[ |]+", "", r)]
    if quota:
        rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
    else:
        rows = "\n".join(rows)
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not docnm_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([clmns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt
        }

    docid_idx = list(docid_idx)[0]
    docnm_idx = list(docnm_idx)[0]
    doc_aggs = {}
    for r in tbl["rows"]:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([clmns, line, rows]),
        "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
                      "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                                   doc_aggs.items()]},
        "prompt": sys_prompt
    }


def relevant(tenant_id, llm_id, question, contents: list):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    prompt = """
        You are a grader assessing relevance of a retrieved document to a user question.
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
        No other words needed except 'yes' or 'no'.
    """
    if not contents:
        return False
    contents = "Documents: \n" + "   - ".join(contents)
    contents = f"Question: {question}\n" + contents
    if num_tokens_from_string(contents) >= chat_mdl.max_length - 4:
        contents = encoder.decode(encoder.encode(contents)[:chat_mdl.max_length - 4])
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": contents}], {"temperature": 0.01})
    return ans.lower().find("yes") >= 0


def rewrite(tenant_id, llm_id, question):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    prompt = """
        You are an expert at query expansion to generate a paraphrasing of a question.
        I can't retrieve relevant information from the knowledge base by using user's question directly.
        You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
        writing the abbreviation in its entirety, adding some extra descriptions or explanations,
        changing the way of expression, translating the original question into another language (English/Chinese), etc.
        And return 5 versions of question and one is from translation.
        Just list the question. No other words are needed.
    """
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
    return ans


def keyword_extraction(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
  - Summarize the text content, and give top {topn} important keywords/phrases.
  - The keywords MUST be in language of the given piece of text content.
  - The keywords are delimited by ENGLISH COMMA.
  - Keywords ONLY in output.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def question_proposal(chat_mdl, content, topn=3):
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
  - Understand and summarize the text content, and propose top {topn} important questions.
  - The questions SHOULD NOT have overlapping meanings.
  - The questions SHOULD cover the main content of the text as much as possible.
  - The questions MUST be in language of the given piece of text content.
  - One question per line.
  - Question ONLY in output.

### Text Content
{content}

"""
    msg = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "}
    ]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def full_question(tenant_id, llm_id, messages):
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = []
    for m in messages:
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conv = "\n".join(conv)
    today = datetime.date.today().isoformat()
    yesterday = (datetime.date.today() - timedelta(days=1)).isoformat()
    tomorrow = (datetime.date.today() + timedelta(days=1)).isoformat()
    prompt = f"""
Role: A helpful assistant

Task and steps:
    1. Generate a full user question that would follow the conversation.
    2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.

Requirements & Restrictions:
  - Text generated MUST be in the same language of the original user's question.
  - If the user's latest question is already complete, don't do anything, just return the original question.
  - DON'T generate anything except a refined question.

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT: Fred Trump.
USER: And his mother?
ASSISTANT: Mary Trump.
User: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT: Cloudy.
USER: What's about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?
######################

# Real Data
## Conversation
{conv}
###############
    """
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]


def tts(tts_mdl, text, sample_rate=8000, stream_format="mp3"):
    if not tts_mdl or not text:
        return
    bin_data = b""
    for chunk in tts_mdl.tts(text, sample_rate, stream_format):
        bin_data += chunk
    return binascii.hexlify(bin_data).decode("utf-8")


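# Note on the return format: tts() hex-encodes the concatenated audio bytes,
# so a consumer (sketch below, illustrative only) recovers the raw audio with
# binascii.unhexlify before playing or saving it.
def _decode_tts_example(hex_audio):
    if hex_audio is None:
        return b""
    return binascii.unhexlify(hex_audio)

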
def ask(question, kb_ids, tenant_id):
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    tenant_ids = [kb.tenant_id for kb in kbs]
    embd_nms = list(set([kb.embd_id for kb in kbs]))

    is_kg = all([kb.parser_id == ParserType.KG for kb in kbs])
    retr = settings.retrievaler if not is_kg else settings.kg_retrievaler

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length

    kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

    used_token_count = 0
    for i, c in enumerate(knowledges):
        used_token_count += num_tokens_from_string(c)
        if max_tokens * 0.97 < used_token_count:
            knowledges = knowledges[:i]
            break

    prompt = """
Role: You're a smart assistant. Your name is Miss R.
Task: Summarize the information from knowledge bases and answer user's question.
Requirements and restriction:
  - DO NOT make things up, especially for numbers.
  - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
  - Answer with markdown format text.
  - Answer in language of user's question.

### Information from knowledge bases
%s

The above is information from knowledge bases.

""" % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        nonlocal knowledges, kbinfos, prompt
        answer, idx = retr.insert_citations(answer,
                                            [ck["content_ltks"]
                                             for ck in kbinfos["chunks"]],
                                            [ck["vector"]
                                             for ck in kbinfos["chunks"]],
                                            embd_mdl,
                                            tkweight=0.7,
                                            vtweight=0.3)
        idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
        recall_docs = [
            d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
        if not recall_docs:
            recall_docs = kbinfos["doc_aggs"]
        kbinfos["doc_aggs"] = recall_docs
        refs = deepcopy(kbinfos)
        for c in refs["chunks"]:
            if c.get("vector"):
                del c["vector"]

        if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": refs}

    answer = ""
    for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        answer = ans
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)