2024-08-15 09:17:36 +08:00
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
2024-11-14 17:13:48 +08:00
import logging
2024-09-03 19:49:14 +08:00
import binascii
2024-08-15 09:17:36 +08:00
import os
import json
import re
from copy import deepcopy
2024-09-09 12:08:50 +08:00
from timeit import default_timer as timer
2024-11-13 13:49:18 +08:00
import datetime
from datetime import timedelta
2024-10-12 13:48:43 +08:00
from api . db import LLMType , ParserType , StatusEnum
from api . db . db_models import Dialog , Conversation , DB
2024-08-15 09:17:36 +08:00
from api . db . services . common_service import CommonService
from api . db . services . knowledgebase_service import KnowledgebaseService
from api . db . services . llm_service import LLMService , TenantLLMService , LLMBundle
2024-11-15 17:30:56 +08:00
from api import settings
2024-08-15 09:17:36 +08:00
from rag . app . resume import forbidden_select_fields4resume
from rag . nlp . search import index_name
from rag . utils import rmSpace , num_tokens_from_string , encoder
from api . utils . file_utils import get_project_base_directory
2025-02-06 23:34:26 +08:00
from peewee import fn
2025-04-08 08:41:07 +08:00
import threading , queue , uuid , time , array
2025-02-23 09:52:30 +08:00
from concurrent . futures import ThreadPoolExecutor
2025-07-10 22:04:44 +08:00
from api . db . services . ali_tts_service import ( stream_manager_w_stream as stream_manager )
2025-02-23 09:52:30 +08:00
2025-04-08 08:41:07 +08:00
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of raw 16-bit PCM audio.

    The data is read as signed 16-bit samples (``array('h')``) in the machine's native
    byte order. Treating it as mono PCM is an assumption carried over from the original
    comment. The first ``n`` samples are scaled by ``i / n``, where ``n`` is ``fade_length``
    clamped to the number of available samples. A chunk shorter than the requested fade is
    therefore faded completely instead of raising IndexError.

    Args:
        audio_data: PCM bytes. For ``bytes``/``bytearray`` input, a trailing odd byte
            (half a sample) is passed through unchanged instead of raising ValueError.
        fade_length: number of *samples* (not bytes) to fade in. Values <= 0 leave the
            audio unchanged.

    Returns:
        bytes: the faded audio, the same length as the input.
    """
    tail = b""
    if isinstance(audio_data, (bytes, bytearray)):
        # array('h') only accepts whole samples; keep any dangling byte aside.
        item_size = array.array('h').itemsize
        cut = len(audio_data) - len(audio_data) % item_size
        audio_data, tail = audio_data[:cut], bytes(audio_data[cut:])
    samples = array.array('h', audio_data)
    n = min(fade_length, len(samples))
    for i in range(n):
        samples[i] = int(samples[i] * (i / n))
    return samples.tobytes() + tail
2025-02-23 09:52:30 +08:00
class StreamSessionManager:
    """Per-session text-to-speech streaming manager.

    Each session (keyed by a uuid4 string) owns:
      * ``task_queue``: text pieces queued by :meth:`append_text`;
      * ``buffer``: a bounded queue (maxsize 300) that receives audio chunks or
        error markers;
      * a dedicated daemon thread running :meth:`_process_tasks`. It batches queued
        text and runs the TTS call on the shared thread pool.
    """

    def __init__(self):
        # {session_id: {'tts_model': obj, 'buffer': Queue, 'task_queue': Queue, ...}}
        self.sessions = {}
        self.lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size pool shared by all sessions
        self.gc_interval = 300  # close a session after 5 minutes (300 s) without activity
        # Mark a session finished after 10 s without new text. The LLM may need a while
        # before it starts emitting text (raised from 3 s to 10 s on 2025-05-24).
        self.gc_tts = 10

    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
        """Register a new session, start its worker thread and return the session id."""
        session_id = str(uuid.uuid4())
        with self.lock:
            self.sessions[session_id] = {
                'tts_model': tts_model,
                'buffer': queue.Queue(maxsize=300),  # thread-safe audio output queue
                'task_queue': queue.Queue(),
                'active': True,
                'last_active': time.time(),
                'audio_chunk_count': 0,
                'finished': threading.Event(),  # set by _process_tasks after gc_tts seconds idle
                'sample_rate': sample_rate,
                'stream_format': stream_format,
                "tts_chunk_data_valid": False,
                "sentence_complete_event": threading.Event(),
                "current_processing": False  # "is a sentence being processed" (not updated in this class)
            }
        # One worker thread per session drains its task queue.
        threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
        return session_id

    def append_text(self, session_id, text):
        """Queue *text* for synthesis. Unknown session ids are silently ignored."""
        with self.lock:
            session = self.sessions.get(session_id)
            if not session:
                return
            # Non-blocking put. task_queue is unbounded, so Full is defensive only.
            try:
                session['task_queue'].put(text, block=False)
            except queue.Full:
                logging.warning(f"Session {session_id} task queue full")

    def _process_tasks(self, session_id):
        """Worker loop (one thread per session): batch queued text and synthesize it.

        Exits when the session is removed or inactive, after ``gc_tts`` idle seconds
        (setting ``finished``), or after ``gc_interval`` idle seconds (closing the session).
        """
        while True:
            session = self.sessions.get(session_id)
            if not session or not session['active']:
                break
            try:
                # Merge up to 5 text pieces; each get waits at most 100 ms.
                texts = []
                while len(texts) < 5:
                    try:
                        text = session['task_queue'].get(timeout=0.1)
                        texts.append(text)
                    except queue.Empty:
                        break
                if texts:
                    session['last_active'] = time.time()  # new text resets the idle clock
                    # Run on the shared pool, joining the pieces to reduce TTS requests,
                    # and wait for the conversion to finish.
                    future = self.executor.submit(
                        self._generate_audio,
                        session_id,
                        ''.join(texts)
                    )
                    future.result()
                    session['last_active'] = time.time()
                # Idle-timeout checks.
                if time.time() - session['last_active'] > self.gc_interval:
                    self.close_session(session_id)
                    break
                if time.time() - session['last_active'] > self.gc_tts:
                    session['finished'].set()
                    break
            except Exception as e:
                logging.error(f"Task processing error: {str(e)}")

    def _generate_audio1(self, session_id, text):
        """Streaming variant: push every chunk yielded by ``tts_model.tts(...)`` into the buffer.

        Not called by :meth:`_process_tasks`, which uses :meth:`_generate_audio`. For ``wav``
        output, the first chunk gets a fade-in via ``audio_fade_in``.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        first_chunk = True
        logging.info(f"转换开始!!! {text}")
        try:
            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
                if session['stream_format'] == 'wav' and first_chunk:
                    # audio_fade_in counts 16-bit *samples* (2 bytes each). Passing the byte
                    # length would overrun the sample array, so convert and cap at 1024.
                    session['buffer'].put(audio_fade_in(chunk, min(len(chunk) // 2, 1024)))
                    first_chunk = False
                else:
                    session['buffer'].put(chunk)
                session['last_active'] = time.time()
                session['audio_chunk_count'] = session['audio_chunk_count'] + 1
                if session['tts_chunk_data_valid'] is False:
                    # TTS backend has returned data; the frontend may now be notified (2025-05-10).
                    session['tts_chunk_data_valid'] = True
            logging.info(f"转换结束!!! {session['audio_chunk_count']}")
        except Exception as e:
            # NOTE(review): this puts a str, whereas _generate_audio puts bytes; confirm
            # what buffer consumers expect.
            session['buffer'].put(f"ERROR: {str(e)}")
            logging.info(f"--_generate_audio--error {str(e)}")

    def _generate_audio(self, session_id, text):
        """Synthesize *text* with ``tts_model.text_tts_call`` (sequential, non-streaming engines).

        Updates the session counters, then sets ``sentence_complete_event``. On failure, an
        encoded ``ERROR: ...`` marker is put on the buffer and the event is still set.
        """
        session = self.sessions.get(session_id)
        if not session:
            return
        try:
            session['tts_model'].text_tts_call(text)
            session['last_active'] = time.time()
            session['audio_chunk_count'] += 1
            if not session['tts_chunk_data_valid']:
                session['tts_chunk_data_valid'] = True
            # Signal completion last so waiters observe the updated counters.
            session['sentence_complete_event'].set()
        except Exception as e:
            session['buffer'].put(f"ERROR: {str(e)}".encode())
            session['sentence_complete_event'].set()  # never leave a waiter hanging

    def close_session(self, session_id):
        """Mark the session inactive and remove it after a 1-second delay."""
        with self.lock:
            if session_id in self.sessions:
                self.sessions[session_id]['active'] = False
                # Delayed removal, presumably so readers can finish draining the buffer.
                # The daemon timer does not block interpreter shutdown.
                timer = threading.Timer(1, self._clean_session, args=[session_id])
                timer.daemon = True
                timer.start()

    def _clean_session(self, session_id):
        """Drop the session's state from the registry."""
        with self.lock:
            if session_id in self.sessions:
                del self.sessions[session_id]

    def get_session(self, session_id):
        """Return the session dict, or None if it does not exist."""
        return self.sessions.get(session_id)


# Backup instance. The manager actually used here is ``stream_manager`` imported from ali_tts_service.
stream_manager_bk = StreamSessionManager()
2024-08-15 09:17:36 +08:00
class DialogService(CommonService):
    model = Dialog

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, name):
        """Return one page of *tenant_id*'s valid dialogs as a list of dicts.

        ``id`` and ``name`` act as exact-match filters when truthy. Rows are sorted by
        the ``orderby`` field, descending when ``desc`` is truthy, then paginated.
        """
        query = cls.model.select()
        if id:
            query = query.where(cls.model.id == id)
        if name:
            query = query.where(cls.model.name == name)
        owned_and_valid = (cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value)
        query = query.where(owned_and_valid)
        sort_field = cls.model.getter_by(orderby)
        query = query.order_by(sort_field.desc() if desc else sort_field.asc())
        page = query.paginate(page_number, items_per_page)
        return list(page.dicts())
2024-08-15 09:17:36 +08:00
class ConversationService(CommonService):
    model = Conversation

    @classmethod
    @DB.connection_context()
    def get_list(cls, dialog_id, page_number, items_per_page, orderby, desc, id, name, cols=None):
        """Return ``(total, rows)`` for the conversations of *dialog_id*.

        Args:
            dialog_id: dialog whose conversations are listed.
            page_number, items_per_page: pagination parameters.
            orderby: model field name to sort on; ``desc`` selects descending order.
            id, name: optional exact-match filters, applied when truthy.
            cols: optional list of model field names to select instead of all columns.

        Returns:
            tuple: ``(total, data)``. ``total`` is the number of matching rows before
            pagination; ``data`` is the requested page as a list of dicts.
        """
        logging.debug("--ConversationService get_list enter %s %s", page_number, items_per_page)
        # Base query: conversations belonging to this dialog.
        query = cls.model.select().where(cls.model.dialog_id == dialog_id)
        if id:
            query = query.where(cls.model.id == id)
        if name:
            query = query.where(cls.model.name == name)
        if cols:
            # Restrict the selected columns to the requested fields.
            query = query.select(*[getattr(cls.model, col) for col in cols])
        # Count before ordering/pagination so callers get the full match count.
        total = query.count()
        if desc:
            query = query.order_by(cls.model.getter_by(orderby).desc())
        else:
            query = query.order_by(cls.model.getter_by(orderby).asc())
        data = list(query.paginate(page_number, items_per_page).dicts())
        return total, data

    @classmethod
    @DB.connection_context()
    def query_sessions_summary(cls):
        """Return ``id``, ``dialog_id``, ``name`` and earliest create time/date per group.

        Rows are grouped by ``(id, dialog_id, name)`` and ordered newest-first by the
        minimum ``create_time``. NOTE(review): if ``id`` is the primary key, each group
        holds a single row, so this lists every conversation; confirm that is intended.
        """
        query = (
            cls.model
            .select(
                cls.model.id,
                cls.model.dialog_id,
                cls.model.name,
                fn.MIN(cls.model.create_time).alias("create_time"),
                fn.MIN(cls.model.create_date).alias("create_date")
            )
            .group_by(cls.model.id, cls.model.dialog_id, cls.model.name)
            .order_by(
                fn.MIN(cls.model.create_time).desc(),
            )
        )
        return list(query.dicts())
2024-10-22 13:12:49 +08:00
2024-08-15 09:17:36 +08:00
def message_fit_in(msg, max_length=4000):
    """Trim a chat message list so its total token count fits within *max_length*.

    Strategy:
      1. If everything already fits, return the list unchanged.
      2. Otherwise keep only the system messages among all but the last message, plus
         the last message (always kept, even when it is the only one).
      3. If that still does not fit, truncate either the first message (when it holds
         more than 80% of the first+last tokens) or the last message. The truncated
         message gets whatever budget the other retained messages leave.

    Truncation rewrites ``"content"`` in place on the retained message dicts. These are
    the same dict objects as in the input list.

    Args:
        msg: list of ``{"role": ..., "content": ...}`` dicts.
        max_length: token budget.

    Returns:
        tuple: ``(token_count, messages)``. After truncation, ``token_count`` is reported
        as ``max_length``.
    """
    def count(messages):
        return sum(num_tokens_from_string(m["content"]) for m in messages)

    def truncate(text, budget):
        # Clamp at 0: a negative slice bound would keep "all but the last N" tokens.
        return encoder.decode(encoder.encode(text)[:max(budget, 0)])

    c = count(msg)
    if c < max_length:
        return c, msg

    msg_ = [m for m in msg[:-1] if m["role"] == "system"]
    if msg:
        msg_.append(msg[-1])
    msg = msg_
    c = count(msg)
    if c < max_length or not msg:
        return c, msg

    first_tks = num_tokens_from_string(msg[0]["content"])
    last_tks = num_tokens_from_string(msg[-1]["content"])
    pair_tks = first_tks + last_tks
    if len(msg) > 1 and pair_tks and first_tks / pair_tks > 0.8:
        # The first (system) message dominates: shrink it, leaving room for the rest.
        msg[0]["content"] = truncate(msg[0]["content"], max_length - (c - first_tks))
        return max_length, msg

    # Otherwise shrink the last message, leaving room for every other retained message.
    msg[-1]["content"] = truncate(msg[-1]["content"], max_length - (c - last_tks))
    return max_length, msg
def llm_id2llm_type(llm_id):
    """Look up the model type of *llm_id* in ``conf/llm_factories.json``.

    Any ``@factory`` suffix on *llm_id* is ignored. ``model_type`` is split on commas and
    the last entry is returned; presumably some entries list several types, e.g.
    "chat,image2text" (TODO confirm against the config). Returns None when the model is
    not listed.
    """
    llm_id = llm_id.split("@")[0]
    fnm = os.path.join(get_project_base_directory(), "conf")
    with open(os.path.join(fnm, "llm_factories.json"), "r", encoding="utf-8") as f:
        llm_factories = json.load(f)
    for llm_factory in llm_factories["factory_llm_infos"]:
        for llm in llm_factory["llm"]:
            if llm_id == llm["llm_name"]:
                # split(), not strip(): strip(",")[-1] yields only the last *character*,
                # so comparisons such as == "image2text" could never succeed.
                return llm["model_type"].split(",")[-1].strip()
2024-10-12 13:48:43 +08:00
2025-05-15 15:26:06 +08:00
followup_seperator = " 继续追问: "
2025-02-06 23:34:26 +08:00
# cyx 2024-12-04
# Validate and clean the input text for speech synthesis. Returns (is_valid, sanitized_text).
def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
    """Validate and clean text before sending it to speech synthesis.

    Steps:
      1. reject non-string input;
      2. strip surrounding whitespace and reject empty text;
      3. convert the full-width question mark "？" to the half-width "?";
      4. drop every character that is not CJK (U+4E00-U+9FA5), an ASCII letter or digit,
         whitespace, or common punctuation. Both half-width (``,.!?;:'"()-``) and
         full-width (``，。！？；：“”（）``) forms are allowed;
      5. reject text that ends up empty or longer than *max_length*.

    Args:
        delta_ans (str): text to check.
        max_length (int): maximum allowed length after cleaning.

    Returns:
        tuple: ``(is_valid, sanitized_text)``. ``sanitized_text`` is ``""`` when invalid.
    """
    if not isinstance(delta_ans, str):
        return False, ""
    delta_ans = delta_ans.strip()
    if not delta_ans:
        return False, ""
    delta_ans = delta_ans.replace("？", "?")
    delta_ans = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9\s,.!?;:\'"()\-，。！？；：“”（）]', '', delta_ans)
    if not delta_ans or len(delta_ans) > max_length:
        return False, ""
    return True, delta_ans
2024-08-15 09:17:36 +08:00
2025-05-15 15:26:06 +08:00
2025-02-23 09:52:30 +08:00
def _should_flush ( text_chunk , chunk_buffer , last_flush_time ) :
""" 智能判断是否需要立即生成音频 """
# 规则1: 遇到句子结束标点
if re . search ( r ' [。!?,]$ ' , text_chunk ) :
return True
if re . search ( r ' ( \ d {4} )(年|月|日|,) ' , text_chunk ) :
return False # 不刷新,继续合并
# 规则2: 达到最大缓冲长度( 200字符)
if sum ( len ( c ) for c in chunk_buffer ) > = 200 :
return True
# 规则3: 超过500ms未刷新
if time . time ( ) - last_flush_time > 0.5 :
return True
return False
2025-05-15 15:26:06 +08:00
def extract_and_parse_json(llm_ans):
    """Extract and parse the first JSON payload in an LLM reply.

    Looks for a fenced ```json ... ``` block first. If there is none, it falls back to the
    span from the first "{" to the last "}". The payload is parsed as-is. If that fails,
    it is retried after converting full-width commas "，" to "," and dropping ``//``
    comments. Trying the raw text first keeps valid JSON intact when its strings contain
    URLs or Chinese punctuation.

    Returns:
        tuple: ``(data, None)`` on success, otherwise ``(None, error_message)``.
    """
    if not isinstance(llm_ans, str) or not llm_ans:
        return None, "未检测到 JSON 数据"
    fenced = re.search(r'```json\s*(.*?)\s*```', llm_ans, re.DOTALL)
    if fenced:
        json_str = fenced.group(1)
    else:
        # The fallback pattern has no capture group, so take the whole match.
        bare = re.search(r'\{.*\}', llm_ans, re.DOTALL)
        if not bare:
            return None, "未检测到 JSON 数据"
        json_str = bare.group(0)
    json_str = json_str.strip()
    try:
        return json.loads(json_str), None
    except Exception:
        pass  # retried below with the lenient clean-up
    try:
        json_str = json_str.replace("，", ",")  # full-width (Chinese) comma
        json_str = re.sub(r'//[^\n]*', '', json_str)  # strip // comments
        return json.loads(json_str), None
    except json.JSONDecodeError as e:
        return None, f"JSON 解析失败: {str(e)}\n错误位置:第 {e.lineno} 行 第 {e.colno} 列"
    except Exception as e:
        return None, f"解析异常: {str(e)}"
import re
import json
def extract_clear_parse_json(llm_ans, clean_text=True):
    """Find JSON in *llm_ans*, optionally remove it from the text, and parse it.

    Candidates are fenced ```json ... ``` blocks. Only when none exist are flat
    (non-nested) ``{...}`` objects used instead. Each candidate is first parsed as-is. If
    that fails, it is retried after removing ``//`` comments and converting full-width
    commas "，" to ",". Parsing the raw text first keeps valid JSON intact when its
    strings contain "//" (URLs) or "，" (Chinese prose).

    Args:
        llm_ans: raw model output.
        clean_text: when True, every candidate block is removed from the returned text.

    Returns:
        tuple: ``(parsed_data, cleaned_text, error)``:
          * the first successfully parsed value, or None;
          * the stripped remaining text;
          * None, or a message describing the last failure.
    """
    def find_blocks(pattern):
        return [(m.start(), m.end(), m.group(0))
                for m in re.finditer(pattern, llm_ans, re.DOTALL)]

    matches = find_blocks(r'```json.*?```')
    if not matches:
        matches = find_blocks(r'\{[^{}]*\}')

    cleaned_text = llm_ans
    if clean_text:
        # Delete from the end backwards so earlier offsets stay valid.
        for start, end, _ in sorted(matches, key=lambda x: x[0], reverse=True):
            cleaned_text = cleaned_text[:start] + cleaned_text[end:]

    json_blocks = []
    for _, _, block in matches:
        if block.startswith('```json'):
            block = re.sub(r'^```json\s*|\s*```$', '', block, flags=re.DOTALL)
        json_blocks.append(block.strip())

    parsed_data = None
    error = None
    for json_str in json_blocks:
        try:
            parsed_data = json.loads(json_str)
            error = None
            break
        except Exception:
            pass  # retried below with the lenient clean-up
        try:
            lenient = re.sub(r'//.*', '', json_str).replace("，", ",")
            parsed_data = json.loads(lenient)
            error = None
            break
        except json.JSONDecodeError as e:
            error = f"JSON 解析失败: {e},位置:第 {e.lineno} 行"
        except Exception as e:
            error = f"解析异常: {str(e)}"
    if not json_blocks:
        error = "未检测到 JSON 数据"
    return parsed_data, cleaned_text.strip(), error
# cyx 2025-05-10: generate follow-up questions for an answer, returned as JSON.
def generate_structured_followups(chat_mdl, answer, max_questions=5):
    """Ask *chat_mdl* for follow-up questions about *answer* and parse the JSON reply.

    Args:
        chat_mdl: model object exposing ``chat(system=..., history=..., gen_conf=...)``.
        answer: the explanation text the questions should follow up on.
        max_questions: number of questions requested in the prompt.

    Returns:
        tuple: ``(json_data, status)``. ``json_data`` is the parsed JSON, or ``{}`` on
        failure. ``status`` is one of:
          * "解析正确": parsed successfully;
          * "解析错误": no parseable JSON in the reply;
          * "格式解析错误": a JSONDecodeError escaped;
          * "生成过程异常": any other failure.
    """
    system_prompt = """你是一位精通数据结构的博物馆教育专家,请完成以下任务:
1. 根据讲解内容生成{max_questions}个追问问题,格式为严格遵循的JSON
2. 为生成的JSON添加格式解释
3. JSON需包含问题分类和置信度
JSON格式要求:
{{
    "questions": [
        {{
            "text": "问题文本",
            "type": "问题类型",
            "confidence": 置信度(0-1)
        }}
    ],
    "source_analysis": {{
        "main_topics": ["主要话题"],
        "missing_areas": ["未涉及领域"]
    }}
}}"""
    user_prompt = f"""原始讲解内容:
{answer}
请按以下步骤处理:
1. 分析内容的关键知识点
2. 生成{max_questions}个延伸问题
3. 生成JSON后添加格式解释
4. 用---分隔数据和解释"""
    gen_config = {
        "temperature": 0.5,
        "max_tokens": 800
    }
    try:
        ans = chat_mdl.chat(
            system=system_prompt.format(max_questions=max_questions),
            history=[{"role": "user", "content": user_prompt}],
            gen_conf=gen_config
        )
        json_data, error = extract_and_parse_json(ans)
        if json_data is not None:
            return json_data, "解析正确"
        # Previously this path referenced an undefined `tokens` and was misreported
        # as a generation failure.
        logging.warning(f"generate_structured_followups: {error}")
        return {}, "解析错误"
    except json.JSONDecodeError as e:
        logging.error(f"JSON解析失败: {str(e)}")
        return {}, "格式解析错误"
    except Exception as e:
        logging.error(f"生成失败: {str(e)}")
        return {}, "生成过程异常"
2025-02-23 09:52:30 +08:00
MAX_BUFFER_LEN = 200  # buffered text length (characters) at which process_buffer forces a flush
FLUSH_TIMEOUT = 0.5  # forced flush interval (seconds)
2025-02-23 09:52:30 +08:00
# Find the best split point in a text (punctuation / semantic unit / phrase boundary).
def find_split_position(text):
    """Return the index just past the best split point in *text*, or None.

    Preference order:
      1. the last sentence terminator (。!?);
      2. the last pause mark (,;、);
      3. the end of the first number+年/月/日 token not followed by a digit, so dates and
         numeric phrases are not cut in half.
    """
    for pattern in (r'[。!?]', r'[,;、]'):
        last_hit = None
        for last_hit in re.finditer(pattern, text):
            pass
        if last_hit is not None:
            return last_hit.end()
    date_hit = re.search(r'\d+(年|月|日)(?!\d)', text)
    return date_hit.end() if date_hit else None
# Manage the text buffer: split off a semantically complete prefix for TTS using the rules above.
def process_buffer(chunk_buffer, force_flush=False):
    """Split buffered text into a part to synthesize now and the remaining buffer.

    Returns ``(to_tts_text, remaining_buffer)``:
      * ``("", [])`` when the buffer holds no text;
      * ``(prefix, [rest])`` when a split point exists. A split is also forced by
        ``force_flush`` or by the text reaching ``MAX_BUFFER_LEN`` characters. A forced
        split uses the natural split point only if it lies in the second half of the
        text; otherwise it cuts at ``MAX_BUFFER_LEN`` characters (or the whole text if
        shorter);
      * ``(None, chunk_buffer)`` (the same list) when nothing should be sent yet.
    """
    text = "".join(chunk_buffer)
    if not text:
        return "", []
    pos = find_split_position(text)
    if force_flush or len(text) >= MAX_BUFFER_LEN:
        natural_split_ok = pos is not None and pos >= len(text) // 2
        if not natural_split_ok:
            pos = max(pos or 0, MAX_BUFFER_LEN)
        pos = min(pos, len(text))
    if not pos:
        return None, chunk_buffer
    return text[:pos], [text[pos:]]
2025-02-23 09:52:30 +08:00
2024-08-15 09:17:36 +08:00
def chat ( dialog , messages , stream = True , * * kwargs ) :
assert messages [ - 1 ] [ " role " ] == " user " , " The last content of this conversation is not from user. "
2024-09-09 12:08:50 +08:00
st = timer ( )
2024-09-18 16:09:22 +08:00
tmp = dialog . llm_id . split ( " @ " )
fid = None
llm_id = tmp [ 0 ]
if len ( tmp ) > 1 : fid = tmp [ 1 ]
llm = LLMService . query ( llm_name = llm_id ) if not fid else LLMService . query ( llm_name = llm_id , fid = fid )
2024-08-15 09:17:36 +08:00
if not llm :
2024-09-18 16:09:22 +08:00
llm = TenantLLMService . query ( tenant_id = dialog . tenant_id , llm_name = llm_id ) if not fid else \
TenantLLMService . query ( tenant_id = dialog . tenant_id , llm_name = llm_id , llm_factory = fid )
2024-08-15 09:17:36 +08:00
if not llm :
raise LookupError ( " LLM( %s ) not found " % dialog . llm_id )
max_tokens = 8192
else :
max_tokens = llm [ 0 ] . max_tokens
kbs = KnowledgebaseService . get_by_ids ( dialog . kb_ids )
embd_nms = list ( set ( [ kb . embd_id for kb in kbs ] ) )
if len ( embd_nms ) != 1 :
yield { " answer " : " **ERROR**: Knowledge bases use different embedding models. " , " reference " : [ ] }
return { " answer " : " **ERROR**: Knowledge bases use different embedding models. " , " reference " : [ ] }
is_kg = all ( [ kb . parser_id == ParserType . KG for kb in kbs ] )
2024-11-15 17:30:56 +08:00
retr = settings . retrievaler if not is_kg else settings . kg_retrievaler
2024-08-15 09:17:36 +08:00
questions = [ m [ " content " ] for m in messages if m [ " role " ] == " user " ] [ - 3 : ]
attachments = kwargs [ " doc_ids " ] . split ( " , " ) if " doc_ids " in kwargs else None
if " doc_ids " in messages [ - 1 ] :
attachments = messages [ - 1 ] [ " doc_ids " ]
for m in messages [ : - 1 ] :
if " doc_ids " in m :
attachments . extend ( m [ " doc_ids " ] )
embd_mdl = LLMBundle ( dialog . tenant_id , LLMType . EMBEDDING , embd_nms [ 0 ] )
2024-11-05 09:29:01 +08:00
if not embd_mdl :
raise LookupError ( " Embedding model( %s ) not found " % embd_nms [ 0 ] )
2024-11-05 10:04:31 +08:00
2024-08-15 09:17:36 +08:00
if llm_id2llm_type ( dialog . llm_id ) == " image2text " :
chat_mdl = LLMBundle ( dialog . tenant_id , LLMType . IMAGE2TEXT , dialog . llm_id )
else :
chat_mdl = LLMBundle ( dialog . tenant_id , LLMType . CHAT , dialog . llm_id )
prompt_config = dialog . prompt_config
field_map = KnowledgebaseService . get_field_map ( dialog . kb_ids )
2024-09-03 19:49:14 +08:00
tts_mdl = None
2025-02-06 23:34:26 +08:00
2024-09-03 19:49:14 +08:00
if prompt_config . get ( " tts " ) :
2025-02-23 09:52:30 +08:00
if kwargs . get ( ' tts_model ' ) :
tts_mdl = LLMBundle ( dialog . tenant_id , LLMType . TTS , kwargs . get ( ' tts_model ' ) )
else :
tts_mdl = LLMBundle ( dialog . tenant_id , LLMType . TTS , dialog . tts_id )
2025-02-06 23:34:26 +08:00
2025-04-08 08:41:07 +08:00
tts_sample_rate = kwargs . get ( " tts_sample_rate " , 8000 ) # 默认为8K
tts_stream_format = kwargs . get ( " tts_stream_format " , " mp3 " ) # 默认为mp3格式
2024-08-15 09:17:36 +08:00
# try to use sql if field mapping is good to go
if field_map :
2024-11-14 17:13:48 +08:00
logging . debug ( " Use SQL to retrieval: {} " . format ( questions [ - 1 ] ) )
2024-08-15 09:17:36 +08:00
ans = use_sql ( questions [ - 1 ] , field_map , dialog . tenant_id , chat_mdl , prompt_config . get ( " quote " , True ) )
if ans :
yield ans
return
2025-02-06 23:34:26 +08:00
# logging.info(f"dialog_service--1 chat prompt_config{prompt_config['parameters']} {prompt_config}") # cyx
2024-08-15 09:17:36 +08:00
for p in prompt_config [ " parameters " ] :
if p [ " key " ] == " knowledge " :
continue
if p [ " key " ] not in kwargs and not p [ " optional " ] :
raise KeyError ( " Miss parameter: " + p [ " key " ] )
if p [ " key " ] not in kwargs :
prompt_config [ " system " ] = prompt_config [ " system " ] . replace (
" { %s } " % p [ " key " ] , " " )
2024-09-20 17:25:55 +08:00
if len ( questions ) > 1 and prompt_config . get ( " refine_multiturn " ) :
questions = [ full_question ( dialog . tenant_id , dialog . llm_id , messages ) ]
else :
questions = questions [ - 1 : ]
2024-11-05 13:39:50 +08:00
refineQ_tm = timer ( )
keyword_tm = timer ( )
2024-08-15 09:17:36 +08:00
rerank_mdl = None
if dialog . rerank_id :
rerank_mdl = LLMBundle ( dialog . tenant_id , LLMType . RERANK , dialog . rerank_id )
for _ in range ( len ( questions ) / / 2 ) :
questions . append ( questions [ - 1 ] )
if " knowledge " not in [ p [ " key " ] for p in prompt_config [ " parameters " ] ] :
kbinfos = { " total " : 0 , " chunks " : [ ] , " doc_aggs " : [ ] }
else :
if prompt_config . get ( " keyword " , False ) :
questions [ - 1 ] + = keyword_extraction ( chat_mdl , questions [ - 1 ] )
2024-11-05 13:39:50 +08:00
keyword_tm = timer ( )
2024-10-29 13:19:01 +08:00
tenant_ids = list ( set ( [ kb . tenant_id for kb in kbs ] ) )
kbinfos = retr . retrieval ( " " . join ( questions ) , embd_mdl , tenant_ids , dialog . kb_ids , 1 , dialog . top_n ,
2024-08-15 09:17:36 +08:00
dialog . similarity_threshold ,
dialog . vector_similarity_weight ,
doc_ids = attachments ,
top = dialog . top_k , aggs = False , rerank_mdl = rerank_mdl )
knowledges = [ ck [ " content_with_weight " ] for ck in kbinfos [ " chunks " ] ]
2025-04-08 08:41:07 +08:00
logging . debug ( " {} -> {} " . format ( " " . join ( questions ) , " \n -> " . join ( knowledges ) ) )
# 打印历史记录
2025-05-26 21:38:46 +08:00
# logging.info( "dale-----!!!:{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
2024-09-09 12:08:50 +08:00
retrieval_tm = timer ( )
2024-08-15 09:17:36 +08:00
if not knowledges and prompt_config . get ( " empty_response " ) :
2024-09-03 19:49:14 +08:00
empty_res = prompt_config [ " empty_response " ]
2025-04-08 08:41:07 +08:00
yield { " answer " : empty_res , " reference " : kbinfos , " audio_binary " :
tts ( tts_mdl , empty_res , sample_rate = tts_sample_rate , stream_format = tts_stream_format ) }
2024-08-15 09:17:36 +08:00
return { " answer " : prompt_config [ " empty_response " ] , " reference " : kbinfos }
2024-09-24 12:04:16 +08:00
kwargs [ " knowledge " ] = " \n \n ------ \n \n " . join ( knowledges )
2024-08-15 09:17:36 +08:00
gen_conf = dialog . llm_setting
msg = [ { " role " : " system " , " content " : prompt_config [ " system " ] . format ( * * kwargs ) } ]
2025-04-08 08:41:07 +08:00
2024-08-15 09:17:36 +08:00
msg . extend ( [ { " role " : m [ " role " ] , " content " : re . sub ( r " ## \ d+ \ $ \ $ " , " " , m [ " content " ] ) }
for m in messages if m [ " role " ] != " system " ] )
used_token_count , msg = message_fit_in ( msg , int ( max_tokens * 0.97 ) )
assert len ( msg ) > = 2 , f " message_fit_in has bug: { msg } "
2024-08-26 16:14:15 +08:00
prompt = msg [ 0 ] [ " content " ]
2024-10-08 12:53:04 +08:00
prompt + = " \n \n ### Query: \n %s " % " " . join ( questions )
2024-08-15 09:17:36 +08:00
if " max_tokens " in gen_conf :
gen_conf [ " max_tokens " ] = min (
gen_conf [ " max_tokens " ] ,
max_tokens - used_token_count )
def decorate_answer ( answer ) :
2024-09-09 12:08:50 +08:00
nonlocal prompt_config , knowledges , kwargs , kbinfos , prompt , retrieval_tm
2024-08-15 09:17:36 +08:00
refs = [ ]
if knowledges and ( prompt_config . get ( " quote " , True ) and kwargs . get ( " quote " , True ) ) :
answer , idx = retr . insert_citations ( answer ,
[ ck [ " content_ltks " ]
for ck in kbinfos [ " chunks " ] ] ,
[ ck [ " vector " ]
for ck in kbinfos [ " chunks " ] ] ,
embd_mdl ,
tkweight = 1 - dialog . vector_similarity_weight ,
vtweight = dialog . vector_similarity_weight )
2025-04-08 08:41:07 +08:00
# 上述转换过程中, 发现有时候会在answer中插入类似##0$$ ##1$$ 这样的字符串,需要去除
# cyx 20250407
answer = re . sub ( r ' ## \ d+ \ $ \ $ ' , ' ' , answer ) . strip ( ) #去除##0$$类似内容 同时去除多余空格
2024-08-15 09:17:36 +08:00
idx = set ( [ kbinfos [ " chunks " ] [ int ( i ) ] [ " doc_id " ] for i in idx ] )
recall_docs = [
d for d in kbinfos [ " doc_aggs " ] if d [ " doc_id " ] in idx ]
if not recall_docs : recall_docs = kbinfos [ " doc_aggs " ]
kbinfos [ " doc_aggs " ] = recall_docs
refs = deepcopy ( kbinfos )
for c in refs [ " chunks " ] :
if c . get ( " vector " ) :
del c [ " vector " ]
if answer . lower ( ) . find ( " invalid key " ) > = 0 or answer . lower ( ) . find ( " invalid api " ) > = 0 :
answer + = " Please set LLM API-Key in ' User Setting -> Model Providers -> API-Key ' "
2024-09-09 12:08:50 +08:00
done_tm = timer ( )
2024-11-05 13:39:50 +08:00
prompt + = " \n \n ### Elapsed \n - Refine Question: %.1f ms \n - Keywords: %.1f ms \n - Retrieval: %.1f ms \n - LLM: %.1f ms " % (
( refineQ_tm - st ) * 1000 , ( keyword_tm - refineQ_tm ) * 1000 , ( retrieval_tm - keyword_tm ) * 1000 ,
( done_tm - retrieval_tm ) * 1000 )
2025-05-15 15:26:06 +08:00
#return {"answer": answer, "prompt": prompt,"reference": refs }
# cyx 增加 20250510 生成后续追问的内容
# cyx 修改 20250422 不向前端发送prompt 和 refs ,增加发送 finished 标志
return { " answer " : answer , " finished " : True , " reference " : " " }
2024-08-15 09:17:36 +08:00
if stream :
2024-09-03 19:49:14 +08:00
last_ans = " "
2024-08-15 09:17:36 +08:00
answer = " "
2025-07-10 22:04:44 +08:00
audio_url = None
if not kwargs . get ( ' tts_disable ' ) :
2025-02-23 09:52:30 +08:00
# 创建TTS会话( 提前初始化)
2025-07-10 22:04:44 +08:00
tts_session_id = stream_manager . create_session ( tts_mdl , sample_rate = tts_sample_rate , stream_format = tts_stream_format ,
voice = kwargs . get ( ' tts_model ' ) )
tts_session = stream_manager . get_session ( tts_session_id )
audio_url = f " /tts_stream/ { tts_session_id } "
2025-05-15 15:26:06 +08:00
send_tts_url = False
2025-02-23 09:52:30 +08:00
chunk_buffer = [ ] # 新增文本缓冲
last_flush_time = time . time ( ) # 初始化时间戳
2025-04-08 08:41:07 +08:00
# 下面优先处理知识库中没有找到相关内容 cyx 20250323 修改
if not kwargs [ " knowledge " ] or kwargs [ " knowledge " ] == " " or len ( kwargs [ " knowledge " ] ) < 4 :
2025-07-10 22:04:44 +08:00
if not kwargs . get ( ' tts_disable ' ) :
stream_manager . append_text ( tts_session_id , " 未找到相关内容 " )
2025-04-08 08:41:07 +08:00
yield {
" answer " : " 未找到相关内容 " ,
" delta_ans " : " 未找到相关内容 " ,
" session_id " : tts_session_id ,
" reference " : { } ,
" audio_stream_url " : audio_url ,
" sample_rate " : tts_sample_rate ,
" stream_format " : tts_stream_format ,
}
else :
for ans in chat_mdl . chat_streamly ( prompt , msg [ 1 : ] , gen_conf ) :
answer = ans
delta_ans = ans [ len ( last_ans ) : ]
if num_tokens_from_string ( delta_ans ) < 24 :
continue
last_ans = answer
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
tts_input_is_valid , sanitized_text = validate_and_sanitize_tts_input ( delta_ans )
# cyx 2025 01 18 前端传入tts_disable 参数, 就不生成tts 音频给前端,即:没有audio_binary
if kwargs . get ( ' tts_disable ' ) :
tts_input_is_valid = False
2025-05-15 15:26:06 +08:00
if tts_input_is_valid :
2025-04-08 08:41:07 +08:00
# 缓冲文本直到遇到标点
chunk_buffer . append ( sanitized_text )
# 处理缓冲区内容
while True :
# 判断是否需要强制刷新
force = time . time ( ) - last_flush_time > FLUSH_TIMEOUT
to_send , remaining = process_buffer ( chunk_buffer , force_flush = force )
if not to_send :
break
# 发送有效内容
stream_manager . append_text ( tts_session_id , to_send )
chunk_buffer = remaining
last_flush_time = time . time ( )
"""
if tts_input_is_valid :
yield { " answer " : answer , " delta_ans " : sanitized_text , " reference " : { } , " audio_binary " : tts ( tts_mdl , sanitized_text ) }
else :
yield { " answer " : answer , " delta_ans " : sanitized_text , " reference " : { } }
"""
# 首块返回音频URL
2025-07-10 22:04:44 +08:00
if send_tts_url is False and not kwargs . get ( ' tts_disable ' ) :
if tts_session [ ' tts_chunk_data_valid ' ] is True :
yield {
" answer " : answer ,
" delta_ans " : sanitized_text ,
" session_id " : tts_session_id ,
" reference " : { } ,
" audio_stream_url " : audio_url ,
" sample_rate " : tts_sample_rate ,
" stream_format " : tts_stream_format ,
}
send_tts_url = True # 发送一次tts url 给前端即可,不能重复发送
logging . info ( f " --chat retur tts url { audio_url } " )
2025-04-08 08:41:07 +08:00
else :
yield { " answer " : answer , " delta_ans " : sanitized_text , " reference " : { } }
2025-05-15 15:26:06 +08:00
2025-04-08 08:41:07 +08:00
delta_ans = answer [ len ( last_ans ) : ]
if delta_ans :
# stream_manager.append_text(tts_session_id, delta_ans)
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 调用tts 出错
tts_input_is_valid , sanitized_text = validate_and_sanitize_tts_input ( delta_ans )
if kwargs . get ( ' tts_disable ' ) : # cyx 2025 01 18 前端传入tts_disable 参数, 就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
2025-05-15 15:26:06 +08:00
if tts_input_is_valid :
2025-04-08 08:41:07 +08:00
# 20250221 修改,在后端生成音频数据
chunk_buffer . append ( sanitized_text )
2025-05-15 15:26:06 +08:00
to_send , remaining = process_buffer ( chunk_buffer , force_flush = force )
if to_send :
stream_manager . append_text ( tts_session_id , to_send )
2025-02-06 23:34:26 +08:00
yield { " answer " : answer , " delta_ans " : sanitized_text , " reference " : { } }
2025-04-08 08:41:07 +08:00
"""
if tts_input_is_valid :
yield { " answer " : answer , " delta_ans " : sanitized_text , " reference " : { } , " audio_binary " : tts ( tts_mdl , sanitized_text ) }
else :
yield { " answer " : answer , " delta_ans " : sanitized_text , " reference " : { } }
"""
yield decorate_answer ( answer )
2025-02-06 23:34:26 +08:00
2024-08-15 09:17:36 +08:00
else :
2024-08-26 16:14:15 +08:00
answer = chat_mdl . chat ( prompt , msg [ 1 : ] , gen_conf )
2024-11-14 17:13:48 +08:00
logging . debug ( " User: {} |Assistant: {} " . format (
2024-08-15 09:17:36 +08:00
msg [ - 1 ] [ " content " ] , answer ) )
2024-09-03 19:49:14 +08:00
res = decorate_answer ( answer )
2025-02-06 23:34:26 +08:00
if kwargs . get ( ' tts_disable ' ) : # cyx 2025 01 18 前端传入tts_disable 参数, 就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
else :
2025-04-08 08:41:07 +08:00
res [ " audio_binary " ] = tts ( tts_mdl , answer , tts_sample_rate , tts_stream_format )
2024-09-03 19:49:14 +08:00
yield res
2024-08-15 09:17:36 +08:00
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
    """Answer ``question`` by having the LLM write SQL against the tenant index.

    The LLM is given the index name (``index_name(tenant_id)``) and the column
    descriptions from ``field_map``. Its reply is normalised to a single SELECT
    statement, executed with ``settings.retrievaler.sql_retrieval`` and the
    result is rendered as a markdown table. If execution reports an error the
    LLM is asked once more, with the failing SQL and error text, to fix it.

    Args:
        question: The user question(s) to translate into SQL.
        field_map: Mapping of column name -> human readable description.
        tenant_id: Tenant whose index is queried.
        chat_mdl: Chat model exposing ``chat(system, history, gen_conf)``.
        quota: Kept for backward compatibility; it has no effect (every row
            always receives a ``##N$$`` citation marker).

    Returns:
        ``None`` when no usable SQL or no rows are produced; otherwise a dict
        with ``answer`` (markdown table), ``reference`` (``chunks`` and
        ``doc_aggs``) and ``prompt`` (the system prompt used).
    """
    sys_prompt = "你是一个DBA。你需要这对以下表的字段结构，根据用户的问题列表，写出最后一个问题对应的SQL。"
    field_desc = "\n".join([f"{k}: {v}" for k, v in field_map.items()])
    user_promt = """
表名：{}；
数据库表字段说明如下：
{}

问题如下：
{}
请写出SQL, 且只要SQL，不要有其他说明及文字。
""".format(index_name(tenant_id), field_desc, question)
    tried_times = 0

    def get_table():
        """Ask the LLM for SQL, normalise it and execute it.

        Returns ``(result, sql)``, or ``(None, None)`` when the reply does not
        contain a SELECT statement. Reads the current ``user_promt`` binding.
        """
        nonlocal tried_times
        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_promt}], {"temperature": 0.06})
        logging.debug(f"{question} ==> {user_promt} get SQL: {sql}")
        # Collapse the reply to one line and drop everything before "select ".
        sql = re.sub(r"[\r\n]+", " ", sql.lower())
        sql = re.sub(r".*select ", "select ", sql.lower())
        sql = re.sub(r" +", " ", sql)
        # Cut at the first ASCII/full-width semicolon or a closing markdown fence.
        sql = re.sub(r"([;；]|```).*", "", sql)
        if not sql.startswith("select "):
            return None, None
        if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()):
            # Non-aggregate query: always select doc_id/docnm_kwd so rows can be cited.
            if not sql.startswith("select *"):
                sql = "select doc_id,docnm_kwd," + sql[6:]
            else:
                flds = []
                for k in field_map.keys():
                    if k in forbidden_select_fields4resume:
                        continue
                    if len(flds) > 11:
                        break
                    flds.append(k)
                sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:]

        logging.debug(f"{question} get SQL(refined): {sql}")
        tried_times += 1
        return settings.retrievaler.sql_retrieval(sql, format="json"), sql

    tbl, sql = get_table()
    if tbl is None:
        return None
    if tbl.get("error") and tried_times <= 2:
        user_promt = """
表名：{}；
数据库表字段说明如下：
{}

问题如下：
{}

你上一次给出的错误SQL如下：
{}

后台报错如下：
{}

请纠正SQL中的错误再写一遍，且只要SQL，不要有其他说明及文字。
""".format(index_name(tenant_id), field_desc, question, sql, tbl["error"])
        tbl, sql = get_table()
        logging.debug("TRY it again: {}".format(sql))
        # The corrected reply may again contain no SELECT; previously this
        # fell through to tbl.get(...) on None and raised AttributeError.
        if tbl is None:
            return None

    logging.debug("GET table: {}".format(tbl))
    if tbl.get("error") or len(tbl["rows"]) == 0:
        return None

    columns = tbl["columns"]
    docid_idx = {ii for ii, c in enumerate(columns) if c["name"] == "doc_id"}
    docnm_idx = {ii for ii, c in enumerate(columns) if c["name"] == "docnm_kwd"}
    clmn_idx = [ii for ii in range(len(columns)) if ii not in (docid_idx | docnm_idx)]

    # compose markdown table; header text drops "/..." suffixes and full-width
    # parenthesised notes from the column descriptions.
    clmns = "|" + "|".join(
        [re.sub(r"(/.*|（[^（）]+）)", "", field_map.get(columns[i]["name"], columns[i]["name"])) for i in clmn_idx]
    ) + ("|Source|" if docid_idx else "|")
    line = "|" + "|".join(["------" for _ in clmn_idx]) + ("|------|" if docid_idx else "")

    # Render rows and drop those that are blank once pipes/spaces are removed.
    # The source row is kept next to its markdown so that the ##N$$ citation
    # index stays aligned with reference["chunks"] built below.
    kept = []
    for r in tbl["rows"]:
        md = "|" + "|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|"
        if re.sub(r"[ |]+", "", md):
            kept.append((r, md))
    rows = "\n".join([md + f" ##{ii}$$ |" for ii, (_, md) in enumerate(kept)])
    # Strip the time part of ISO timestamps at the end of a cell.
    rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)

    if not docid_idx or not docnm_idx:
        logging.warning("SQL missing field: " + sql)
        return {
            "answer": "\n".join([clmns, line, rows]),
            "reference": {"chunks": [], "doc_aggs": []},
            "prompt": sys_prompt,
        }

    docid_idx = list(docid_idx)[0]
    docnm_idx = list(docnm_idx)[0]
    doc_aggs = {}
    for r, _ in kept:
        if r[docid_idx] not in doc_aggs:
            doc_aggs[r[docid_idx]] = {"doc_name": r[docnm_idx], "count": 0}
        doc_aggs[r[docid_idx]]["count"] += 1
    return {
        "answer": "\n".join([clmns, line, rows]),
        "reference": {
            "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r, _ in kept],
            "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
                         doc_aggs.items()],
        },
        "prompt": sys_prompt,
    }
def relevant(tenant_id, llm_id, question, contents: list):
    """Let an LLM grade whether retrieved ``contents`` are relevant to ``question``.

    The model is built first (image2text or chat, depending on ``llm_id``),
    then asked for a bare 'yes'/'no'. The user message is truncated to fit
    ``chat_mdl.max_length - 4`` tokens.

    Returns:
        True if the reply contains "yes" (case-insensitive); False otherwise,
        and False immediately when ``contents`` is empty.
    """
    mdl_type = LLMType.IMAGE2TEXT if llm_id2llm_type(llm_id) == "image2text" else LLMType.CHAT
    chat_mdl = LLMBundle(tenant_id, mdl_type, llm_id)
    prompt = """
        You are a grader assessing relevance of a retrieved document to a user question.
        It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
        If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
        Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
        No other words needed except 'yes' or 'no'.
    """
    if not contents:
        return False
    payload = f"Question: {question}\n" + ("Documents: \n" + "   - ".join(contents))
    limit = chat_mdl.max_length - 4
    if num_tokens_from_string(payload) >= limit:
        payload = encoder.decode(encoder.encode(payload)[:limit])
    reply = chat_mdl.chat(prompt, [{"role": "user", "content": payload}], {"temperature": 0.01})
    return "yes" in reply.lower()
def rewrite(tenant_id, llm_id, question):
    """Ask an LLM to paraphrase ``question`` for query expansion.

    The prompt requests five reworded versions (one of them a translation),
    listed with no other text. The raw model reply is returned unchanged.
    """
    model_type = LLMType.IMAGE2TEXT if llm_id2llm_type(llm_id) == "image2text" else LLMType.CHAT
    chat_mdl = LLMBundle(tenant_id, model_type, llm_id)
    prompt = """
        You are an expert at query expansion to generate a paraphrasing of a question.
        I can't retrieval relevant information from the knowledge base by using user's question directly.
        You need to expand or paraphrase user's question by multiple ways such as using synonyms words/phrase,
        writing the abbreviation in its entirety, adding some extra descriptions or explanations,
        changing the way of expression, translating the original question into another language (English/Chinese), etc.
        And return 5 versions of question and one is from translation.
        Just list the question. No other words are needed.
    """
    history = [{"role": "user", "content": question}]
    return chat_mdl.chat(prompt, history, {"temperature": 0.8})
def keyword_extraction(chat_mdl, content, topn=3):
    """Extract up to ``topn`` keywords/phrases from ``content`` via ``chat_mdl``.

    The prompt asks for comma-delimited keywords in the content's language.
    The conversation is trimmed with ``message_fit_in`` to the model's
    ``max_length`` first.

    Returns:
        The model reply; the first element is used if ``chat`` returns a
        tuple, and "" is returned if the reply contains "**ERROR**".
    """
    prompt = f"""
Role: You're a text analyzer.
Task: extract the most important keywords/phrases of a given piece of text content.
Requirements:
  - Summarize the text content, and give top {topn} important keywords/phrases.
  - The keywords MUST be in language of the given piece of text content.
  - The keywords are delimited by ENGLISH COMMA.
  - Keywords ONLY in output.

### Text Content
{content}

"""
    history = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": "Output: "},
    ]
    _, history = message_fit_in(history, chat_mdl.max_length)
    result = chat_mdl.chat(prompt, history[1:], {"temperature": 0.2})
    if isinstance(result, tuple):
        result = result[0]
    return "" if "**ERROR**" in result else result
def question_proposal(chat_mdl, content, topn=3):
    """Have ``chat_mdl`` propose ``topn`` questions covering ``content``.

    Questions are requested one per line, in the content's language, without
    overlapping meanings. The conversation is trimmed with ``message_fit_in``
    to the model's ``max_length`` before the call.

    Returns:
        The model reply (first tuple element if ``chat`` returns a tuple), or
        "" when the reply contains "**ERROR**".
    """
    prompt = f"""
Role: You're a text analyzer.
Task: propose {topn} questions about a given piece of text content.
Requirements:
  - Understand and summarize the text content, and propose top {topn} important questions.
  - The questions SHOULD NOT have overlapping meanings.
  - The questions SHOULD cover the main content of the text as much as possible.
  - The questions MUST be in language of the given piece of text content.
  - One question per line.
  - Question ONLY in output.

### Text Content
{content}

"""
    system_turn = {"role": "system", "content": prompt}
    user_turn = {"role": "user", "content": "Output: "}
    _, fitted = message_fit_in([system_turn, user_turn], chat_mdl.max_length)
    reply = chat_mdl.chat(prompt, fitted[1:], {"temperature": 0.2})
    reply = reply[0] if isinstance(reply, tuple) else reply
    if "**ERROR**" in reply:
        return ""
    return reply
def full_question(tenant_id, llm_id, messages):
    """Rewrite the latest user turn of ``messages`` into a self-contained question.

    Only "user"/"assistant" turns are shown to the model. Relative dates are
    to be resolved against today's date, which is injected into the prompt.

    Args:
        tenant_id: Tenant owning the model.
        llm_id: Model id; an image2text model is used if that is its type.
        messages: Conversation as a list of ``{"role", "content"}`` dicts.

    Returns:
        The refined question, or the content of the last message when the
        model reply contains "**ERROR**".
    """
    if llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = "\n".join(
        "{}: {}".format(m["role"].upper(), m["content"])
        for m in messages
        if m["role"] in ("user", "assistant")
    )
    # Snapshot the date once so today/yesterday/tomorrow stay mutually
    # consistent even if this call straddles midnight.
    current = datetime.date.today()
    today = current.isoformat()
    yesterday = (current - timedelta(days=1)).isoformat()
    tomorrow = (current + timedelta(days=1)).isoformat()
    prompt = f"""
Role: A helpful assistant

Task and steps:
    1. Generate a full user question that would follow the conversation.
    2. If the user's question involves relative date, you need to convert it into absolute date based on the current date, which is {today}. For example: 'yesterday' would be converted to {yesterday}.

Requirements & Restrictions:
  - Text generated MUST be in the same language of the original user's question.
  - If the user's latest question is completely, don't do anything, just return the original question.
  - DON'T generate anything except a refined question.

######################
-Examples-
######################

# Example 1
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT:  Fred Trump.
USER: And his mother?
###############
Output: What's the name of Donald Trump's mother?

------------
# Example 2
## Conversation
USER: What is the name of Donald Trump's father?
ASSISTANT:  Fred Trump.
USER: And his mother?
ASSISTANT:  Mary Trump.
User: What's her full name?
###############
Output: What's the full name of Donald Trump's mother Mary Trump?

------------
# Example 3
## Conversation
USER: What's the weather today in London?
ASSISTANT:  Cloudy.
USER: What's about tomorrow in Rochester?
###############
Output: What's the weather in Rochester on {tomorrow}?
######################

# Real Data
## Conversation
{conv}
###############
    """
    ans = chat_mdl.chat(prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
    # chat() may return a tuple (other helpers in this module unwrap it the
    # same way); keep the first element, assumed to be the reply text.
    if isinstance(ans, tuple):
        ans = ans[0]
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
def tts(tts_mdl, text, sample_rate=8000, stream_format="mp3"):
    """Synthesize ``text`` with ``tts_mdl`` and return the audio as a hex string.

    Args:
        tts_mdl: Model whose ``tts(text, sample_rate, stream_format)`` yields
            bytes-like audio chunks.
        text: Text to speak.
        sample_rate: Passed through to ``tts_mdl.tts``.
        stream_format: Passed through to ``tts_mdl.tts``.

    Returns:
        The concatenated audio, hex-encoded as a UTF-8 string; ``None`` when
        no model is configured or ``text`` is empty.
    """
    if not tts_mdl or not text:
        return None
    # Join once instead of repeated bytes += (quadratic); avoids shadowing bin().
    audio = b"".join(tts_mdl.tts(text, sample_rate, stream_format))
    return binascii.hexlify(audio).decode("utf-8")
def ask(question, kb_ids, tenant_id):
    """Stream an answer to ``question`` grounded in the knowledge bases ``kb_ids``.

    Chunks are retrieved with the KG retriever when every knowledge base uses
    the KG parser, otherwise with the default retriever. They are cut so that
    their token count stays within 97% of the chat model's ``max_length``,
    then substituted into a fixed instruction prompt.

    Yields:
        ``{"answer": partial, "reference": {}}`` for each streamed update, then
        a final ``{"answer": ..., "reference": ...}`` with citations inserted
        and chunk vectors removed from the reference payload.
    """
    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    tenant_ids = [kb.tenant_id for kb in kbs]
    embd_nms = list(set([kb.embd_id for kb in kbs]))
    is_kg = all(kb.parser_id == ParserType.KG for kb in kbs)
    retr = settings.kg_retrievaler if is_kg else settings.retrievaler

    # One of the distinct embedding ids is used (whichever set order yields first).
    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embd_nms[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
    max_tokens = chat_mdl.max_length
    kbinfos = retr.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False)
    knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]

    # Keep only the leading chunks that fit in 97% of the context window.
    budget = max_tokens * 0.97
    consumed = 0
    for cut, piece in enumerate(knowledges):
        consumed += num_tokens_from_string(piece)
        if consumed > budget:
            knowledges = knowledges[:cut]
            break

    prompt = """
    Role: You're a smart assistant. Your name is Miss R.
    Task: Summarize the information from knowledge bases and answer user's question.
    Requirements and restriction:
      - DO NOT make things up, especially for numbers.
      - If the information from knowledge is irrelevant with user's question, JUST SAY: Sorry, no relevant information provided.
      - Answer with markdown format text.
      - Answer in language of user's question.
      - DO NOT make things up, especially for numbers.

    ### Information from knowledge bases
    %s

    The above is information from knowledge bases.

    """ % "\n".join(knowledges)
    msg = [{"role": "user", "content": question}]

    def decorate_answer(answer):
        """Insert citations into ``answer`` and build the reference payload."""
        chunks = kbinfos["chunks"]
        answer, idx = retr.insert_citations(
            answer,
            [ck["content_ltks"] for ck in chunks],
            [ck["vector"] for ck in chunks],
            embd_mdl,
            tkweight=0.7,
            vtweight=0.3,
        )
        cited = {chunks[int(i)]["doc_id"] for i in idx}
        # Fall back to every aggregated doc when nothing was cited.
        kbinfos["doc_aggs"] = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in cited] or kbinfos["doc_aggs"]
        refs = deepcopy(kbinfos)
        for chunk in refs["chunks"]:
            if chunk.get("vector"):
                del chunk["vector"]
        lowered = answer.lower()
        if "invalid key" in lowered or "invalid api" in lowered:
            answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
        return {"answer": answer, "reference": refs}

    answer = ""
    for answer in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
        yield {"answer": answer, "reference": {}}
    yield decorate_answer(answer)