从新提交到gitee 仓库
This commit is contained in:
51
api/apps/sdk/dale_extra.py
Normal file
51
api/apps/sdk/dale_extra.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request
|
||||
from api import settings
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
|
||||
|
||||
# 用户已经添加的模型 cyx 2025-01-26
|
||||
@manager.route('/get_llms', methods=['GET'])
|
||||
@token_required
|
||||
def my_llms(tenant_id):
|
||||
# request.args.get("id") 通过request.args.get 获取GET 方法传入的参数
|
||||
model_type = request.args.get("type")
|
||||
try:
|
||||
res = {}
|
||||
for o in TenantLLMService.get_my_llms(tenant_id):
|
||||
if model_type is None or o["model_type"] == model_type: # 增加按类型的筛选
|
||||
if o["llm_factory"] not in res:
|
||||
res[o["llm_factory"]] = {
|
||||
"tags": o["tags"],
|
||||
"llm": []
|
||||
}
|
||||
|
||||
res[o["llm_factory"]]["llm"].append({
|
||||
"type": o["model_type"],
|
||||
"name": o["llm_name"],
|
||||
"used_token": o["used_tokens"]
|
||||
})
|
||||
return get_result(data=res)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get LLMS error {e}")
|
||||
@@ -13,12 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import re, io
|
||||
import json
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from uuid import uuid4
|
||||
from api.db import LLMType
|
||||
from flask import request, Response
|
||||
from flask import request, Response, jsonify
|
||||
from api.db.services.dialog_service import ask
|
||||
from agent.canvas import Canvas
|
||||
from api.db import StatusEnum
|
||||
@@ -31,11 +32,13 @@ from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result
|
||||
from api.utils.api_utils import get_result, token_required
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
import uuid
|
||||
import queue
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
|
||||
@token_required
|
||||
def create(tenant_id,chat_id):
|
||||
def create(tenant_id, chat_id):
|
||||
req = request.json
|
||||
req["dialog_id"] = chat_id
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
|
||||
@@ -77,7 +80,7 @@ def create_agent_session(tenant_id, agent_id):
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
|
||||
"user_id": req.get("usr_id", "") if isinstance(req, dict) else "",
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
@@ -88,11 +91,11 @@ def create_agent_session(tenant_id, agent_id):
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
|
||||
@token_required
|
||||
def update(tenant_id,chat_id,session_id):
|
||||
def update(tenant_id, chat_id, session_id):
|
||||
req = request.json
|
||||
req["dialog_id"] = chat_id
|
||||
conv_id = session_id
|
||||
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
|
||||
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
|
||||
if not conv:
|
||||
return get_error_data_result(message="Session does not exist")
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
@@ -110,9 +113,9 @@ def update(tenant_id,chat_id,session_id):
|
||||
|
||||
@manager.route('/chats/<chat_id>/completions', methods=['POST'])
|
||||
@token_required
|
||||
def completion(tenant_id, chat_id):
|
||||
def completion(tenant_id, chat_id): # chat_id 和 别的文件中的dialog_id 应该是一个意思? cyx 2025-01-25
|
||||
req = request.json
|
||||
if not req.get("session_id"):
|
||||
if not req.get("session_id"): # session_id 和 别的文件中的conversation_id 应该是一个意思? cyx 2025-01-25
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": chat_id,
|
||||
@@ -123,12 +126,18 @@ def completion(tenant_id, chat_id):
|
||||
return get_error_data_result(message="`name` can not be empty.")
|
||||
ConversationService.save(**conv)
|
||||
e, conv = ConversationService.get_by_id(conv["id"])
|
||||
session_id=conv.id
|
||||
session_id = conv.id
|
||||
else:
|
||||
session_id = req.get("session_id")
|
||||
if not req.get("question"):
|
||||
return get_error_data_result(message="Please input your question.")
|
||||
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
|
||||
#conv = ConversationService.query(id=session_id, dialog_id=chat_id)
|
||||
# 以下改动是为了限制从历史记录中取过多的记录
|
||||
history_limit = req.get("history_limit", None)
|
||||
if history_limit is not None:
|
||||
conv = ConversationService.query(id=session_id, dialog_id=chat_id, reverse=True, order_by="create_time")
|
||||
else:
|
||||
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
|
||||
if not conv:
|
||||
return get_error_data_result(message="Session does not exist")
|
||||
conv = conv[0]
|
||||
@@ -141,13 +150,25 @@ def completion(tenant_id, chat_id):
|
||||
"id": str(uuid4())
|
||||
}
|
||||
conv.message.append(question)
|
||||
# 第一次遍历,计算 assistant 消息的总数
|
||||
assistant_total_count = sum(1 for m in conv.message if m["role"] == "assistant")
|
||||
# 第二次遍历,按条件添加消息到 msg
|
||||
current_assistant_count = 0 # 跟踪当前添加的 assistant 消息数
|
||||
|
||||
for m in conv.message:
|
||||
if m["role"] == "system": continue
|
||||
if m["role"] == "assistant" and not msg: continue
|
||||
if m['role'] == "assistant":
|
||||
# 如果 assistant 消息超出需要保留的数量,跳过
|
||||
# 检查 history_limit 是否为 None,None 表示不限制
|
||||
if history_limit is not None and current_assistant_count < assistant_total_count - history_limit:
|
||||
current_assistant_count += 1
|
||||
continue
|
||||
msg.append(m)
|
||||
|
||||
message_id = msg[-1].get("id")
|
||||
e, dia = DialogService.get_by_id(conv.dialog_id)
|
||||
|
||||
logging.info(f"/chats/{chat_id}/completions req={req}--dale --2 history_limit={history_limit} dia {dia}") # cyx
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.message.append({"role": "assistant", "content": "", "id": message_id})
|
||||
@@ -182,19 +203,22 @@ def completion(tenant_id, chat_id):
|
||||
chunk_list.append(new_chunk)
|
||||
reference["chunks"] = chunk_list
|
||||
ans["id"] = message_id
|
||||
ans["session_id"]=session_id
|
||||
ans["session_id"] = session_id
|
||||
|
||||
def stream():
|
||||
nonlocal dia, msg, req, conv
|
||||
try:
|
||||
for ans in chat(dia, msg, **req):
|
||||
fillin_conv(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e),"reference": []}},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
logging.info(f"sessions--3 /chats/<chat_id>/completions error {e} ") # cyx
|
||||
# yield "data:" + json.dumps({"code": 500, "message": str(e),
|
||||
# "data": {"answer": "**ERROR**: " + str(e),"reference": []}},
|
||||
# ensure_ascii=False) + "\n\n"
|
||||
# cyx 2024 12 04 不把错误返回给前端
|
||||
|
||||
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
@@ -216,6 +240,185 @@ def completion(tenant_id, chat_id):
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
# 全角字符到半角字符的映射
|
||||
|
||||
|
||||
def fullwidth_to_halfwidth(s):
|
||||
full_to_half_map = {
|
||||
'!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
|
||||
'(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
|
||||
'/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
|
||||
'@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
|
||||
'{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
|
||||
'」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
|
||||
'、': '、', '・': '・', ':': ':'
|
||||
}
|
||||
return ''.join(full_to_half_map.get(char, char) for char in s)
|
||||
|
||||
|
||||
def is_dale(s):
|
||||
full_to_half_map = {
|
||||
'!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
|
||||
'(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
|
||||
'/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
|
||||
'@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
|
||||
'{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
|
||||
'」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
|
||||
'、': '、', '・': '・', ':': ':', '。': '.'
|
||||
}
|
||||
|
||||
|
||||
def extract_text_from_markdown(markdown_text):
|
||||
# 移除Markdown标题
|
||||
text = re.sub(r'#\s*[^#]+', '', markdown_text)
|
||||
# 移除内联代码块
|
||||
text = re.sub(r'`[^`]+`', '', text)
|
||||
# 移除代码块
|
||||
text = re.sub(r'```[\s\S]*?```', '', text)
|
||||
# 移除加粗和斜体
|
||||
text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
|
||||
# 移除链接
|
||||
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
||||
# 移除图片
|
||||
text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
|
||||
# 移除HTML标签
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
# 转换标点符号
|
||||
# text = re.sub(r'[^\w\s]', '', text)
|
||||
text = fullwidth_to_halfwidth(text)
|
||||
# 移除多余的空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def split_text_at_punctuation(text, chunk_size=100):
|
||||
# 使用正则表达式找到所有的标点符号和特殊字符
|
||||
punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
|
||||
tokens = re.split(punctuation_pattern, text)
|
||||
|
||||
# 移除空字符串
|
||||
tokens = [token for token in tokens if token]
|
||||
|
||||
# 存储最终的文本块
|
||||
chunks = []
|
||||
current_chunk = ''
|
||||
|
||||
for token in tokens:
|
||||
if len(current_chunk) + len(token) <= chunk_size:
|
||||
# 如果添加当前token后长度不超过chunk_size,则添加到当前块
|
||||
current_chunk += (token + ' ')
|
||||
else:
|
||||
# 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = token + ' '
|
||||
|
||||
# 添加最后一个块(如果有剩余)
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
audio_text_cache = {}
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
|
||||
def dialog_tts_get(chat_id, audio_stream_id):
|
||||
tts_info = audio_text_cache.pop(audio_stream_id, None)
|
||||
req = tts_info
|
||||
if not req:
|
||||
return get_error_data_result(message="Audio stream not found or expired.")
|
||||
audio_stream = req.get('audio_stream')
|
||||
tenant_id = req.get('tenant_id')
|
||||
chat_id = req.get('chat_id')
|
||||
text = req.get('text', "..")
|
||||
model_name = req.get('model_name')
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
||||
ensure_ascii=False)).encode('utf-8')
|
||||
|
||||
def generate():
|
||||
data = audio_stream.read(1024)
|
||||
while data:
|
||||
yield data
|
||||
data = audio_stream.read(1024)
|
||||
|
||||
if audio_stream:
|
||||
# 确保流的位置在开始处
|
||||
audio_stream.seek(0)
|
||||
resp = Response(generate(), mimetype="audio/mpeg")
|
||||
else:
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
|
||||
@token_required
|
||||
def dialog_tts_post(tenant_id, chat_id):
|
||||
req = request.json
|
||||
if not req.get("text"):
|
||||
return get_error_data_result(message="Please input your question.")
|
||||
delay_gen_audio = req.get('delay_gen_audio', False)
|
||||
# text = extract_text_from_markdown(req.get('text'))
|
||||
text = req.get('text')
|
||||
audio_stream_id = req.get('audio_stream_id')
|
||||
# logging.info(f"request tts audio url:{text} audio_stream_id:{audio_stream_id} ")
|
||||
if audio_stream_id is None:
|
||||
audio_stream_id = str(uuid.uuid4())
|
||||
# 在这里生成音频流并存储到内存中
|
||||
model_name = req.get('model_name')
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
logging.info(f"---tts {tts_model_name}")
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
if delay_gen_audio:
|
||||
audio_stream = None
|
||||
else:
|
||||
audio_stream = io.BytesIO()
|
||||
audio_text_cache[audio_stream_id] = {'text': text, 'chat_id': chat_id, "tenant_id": tenant_id,
|
||||
'audio_stream': audio_stream,'model_name':model_name} # 缓存文本以便后续生成音频流
|
||||
if delay_gen_audio is False:
|
||||
try:
|
||||
"""
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
try:
|
||||
if txt is None or txt.strip() == "":
|
||||
continue
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
continue
|
||||
"""
|
||||
if text is None or text.strip() == "":
|
||||
audio_stream.write(b'\x00' * 100)
|
||||
else:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message="get tts audio stream error.")
|
||||
|
||||
# 构建音频流URL
|
||||
audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
|
||||
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
|
||||
# 返回音频流URL
|
||||
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/completions', methods=['POST'])
|
||||
@token_required
|
||||
def agent_completion(tenant_id, agent_id):
|
||||
@@ -235,7 +438,7 @@ def agent_completion(tenant_id, agent_id):
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": req.get("user_id",""),
|
||||
"user_id": req.get("user_id", ""),
|
||||
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
|
||||
"source": "agent"
|
||||
}
|
||||
@@ -251,9 +454,9 @@ def agent_completion(tenant_id, agent_id):
|
||||
question = req.get("question")
|
||||
if not question:
|
||||
return get_error_data_result("`question` is required.")
|
||||
question={
|
||||
"role":"user",
|
||||
"content":question,
|
||||
question = {
|
||||
"role": "user",
|
||||
"content": question,
|
||||
"id": str(uuid4())
|
||||
}
|
||||
messages.append(question)
|
||||
@@ -308,6 +511,7 @@ def agent_completion(tenant_id, agent_id):
|
||||
if 'docnm_kwd' in chunk_i:
|
||||
chunk_i['doc_name'] = chunk_i['docnm_kwd']
|
||||
chunk_i.pop('docnm_kwd')
|
||||
|
||||
conv.message.append(msg[-1])
|
||||
|
||||
if not conv.reference:
|
||||
@@ -375,9 +579,26 @@ def agent_completion(tenant_id, agent_id):
|
||||
return get_result(data=result)
|
||||
|
||||
|
||||
# added by cyx
|
||||
# 打印 ConversationService.model 的表名及字段定义
|
||||
def print_table_info(service):
|
||||
model = service.model # 获取关联的模型
|
||||
if model is None:
|
||||
print("No model associated with the service.")
|
||||
return
|
||||
|
||||
# 打印表名
|
||||
logging.info(f"Table Name: {model._meta.table_name}")
|
||||
|
||||
# 打印所有字段及其定义
|
||||
logging.info("Fields and Definitions:")
|
||||
for field_name, field in model._meta.fields.items():
|
||||
print(f" {field_name}: {field}")
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
|
||||
@token_required
|
||||
def list_session(chat_id,tenant_id):
|
||||
def list_session(chat_id, tenant_id):
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||
id = request.args.get("id")
|
||||
@@ -389,7 +610,8 @@ def list_session(chat_id,tenant_id):
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
|
||||
print_table_info(ConversationService) # cyx
|
||||
convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
@@ -429,9 +651,33 @@ def list_session(chat_id,tenant_id):
|
||||
return get_result(data=convs)
|
||||
|
||||
|
||||
# added by cyx 20241201
|
||||
@manager.route('/chats/<chat_id>/sessions_summary', methods=['GET'])
|
||||
@token_required
|
||||
def sessions_summary(chat_id, tenant_id):
|
||||
# 校验用户是否拥有指定的会话助手
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
|
||||
|
||||
# 统计会话概要信息
|
||||
summaries = ConversationService.query_sessions_summary()
|
||||
|
||||
# 过滤结果,仅返回属于指定 chat_id 的记录
|
||||
filtered_summaries = [
|
||||
summary for summary in summaries if summary["dialog_id"] == chat_id
|
||||
]
|
||||
|
||||
# 如果没有符合条件的记录,返回空列表
|
||||
if not filtered_summaries:
|
||||
return get_result(data=[])
|
||||
|
||||
# 返回过滤后的概要信息
|
||||
return get_result(data=filtered_summaries)
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
|
||||
@token_required
|
||||
def delete(tenant_id,chat_id):
|
||||
def delete(tenant_id, chat_id):
|
||||
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
|
||||
return get_error_data_result(message="You don't own the chat")
|
||||
req = request.json
|
||||
@@ -439,21 +685,22 @@ def delete(tenant_id,chat_id):
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
ids=req.get("ids")
|
||||
ids = req.get("ids")
|
||||
|
||||
if not ids:
|
||||
conv_list = []
|
||||
for conv in convs:
|
||||
conv_list.append(conv.id)
|
||||
else:
|
||||
conv_list=ids
|
||||
conv_list = ids
|
||||
for id in conv_list:
|
||||
conv = ConversationService.query(id=id,dialog_id=chat_id)
|
||||
conv = ConversationService.query(id=id, dialog_id=chat_id)
|
||||
if not conv:
|
||||
return get_error_data_result(message="The chat doesn't own the session")
|
||||
ConversationService.delete_by_id(id)
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/sessions/ask', methods=['POST'])
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
@@ -462,17 +709,18 @@ def ask_about(tenant_id):
|
||||
return get_error_data_result("`question` is required.")
|
||||
if not req.get("dataset_ids"):
|
||||
return get_error_data_result("`dataset_ids` is required.")
|
||||
if not isinstance(req.get("dataset_ids"),list):
|
||||
if not isinstance(req.get("dataset_ids"), list):
|
||||
return get_error_data_result("`dataset_ids` should be a list.")
|
||||
req["kb_ids"]=req.pop("dataset_ids")
|
||||
req["kb_ids"] = req.pop("dataset_ids")
|
||||
for kb_id in req["kb_ids"]:
|
||||
if not KnowledgebaseService.accessible(kb_id,tenant_id):
|
||||
if not KnowledgebaseService.accessible(kb_id, tenant_id):
|
||||
return get_error_data_result(f"You don't own the dataset {kb_id}.")
|
||||
kbs = KnowledgebaseService.query(id=kb_id)
|
||||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
uid = tenant_id
|
||||
|
||||
def stream():
|
||||
nonlocal req, uid
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user