从新提交到gitee 仓库

This commit is contained in:
qcloud
2025-02-06 23:34:26 +08:00
parent e678819f70
commit c88312a914
62 changed files with 211935 additions and 7500 deletions

View File

@@ -0,0 +1,51 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request
from api import settings
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
# 用户已经添加的模型 cyx 2025-01-26
@manager.route('/get_llms', methods=['GET'])
@token_required
def my_llms(tenant_id):
    """List the models the tenant has already added, grouped by LLM factory.

    The optional ``type`` query-string argument restricts the listing to
    models whose ``model_type`` equals it; when it is absent every model is
    returned.

    Returns:
        ``get_result`` whose data maps each factory name to
        ``{"tags": ..., "llm": [{"type", "name", "used_token"}, ...]}``, or
        ``get_error_data_result`` if anything raises while building it.
    """
    # Query-string parameters of a GET request are read through request.args.
    wanted_type = request.args.get("type")
    try:
        grouped = {}
        for item in TenantLLMService.get_my_llms(tenant_id):
            # Skip models of another type when a type filter was supplied.
            if wanted_type is not None and item["model_type"] != wanted_type:
                continue
            factory = item["llm_factory"]
            if factory not in grouped:
                grouped[factory] = {"tags": item["tags"], "llm": []}
            grouped[factory]["llm"].append({
                "type": item["model_type"],
                "name": item["llm_name"],
                "used_token": item["used_tokens"],
            })
        return get_result(data=grouped)
    except Exception as e:
        return get_error_data_result(message=f"Get LLMS error {e}")

View File

@@ -13,12 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import re, io
import json
import logging
from copy import deepcopy
from uuid import uuid4
from api.db import LLMType
from flask import request, Response
from flask import request, Response, jsonify
from api.db.services.dialog_service import ask
from agent.canvas import Canvas
from api.db import StatusEnum
@@ -31,11 +32,13 @@ from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result
from api.utils.api_utils import get_result, token_required
from api.db.services.llm_service import LLMBundle
import uuid
import queue
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
@token_required
def create(tenant_id,chat_id):
def create(tenant_id, chat_id):
req = request.json
req["dialog_id"] = chat_id
dia = DialogService.query(tenant_id=tenant_id, id=req["dialog_id"], status=StatusEnum.VALID.value)
@@ -77,7 +80,7 @@ def create_agent_session(tenant_id, agent_id):
conv = {
"id": get_uuid(),
"dialog_id": cvs.id,
"user_id": req.get("usr_id","") if isinstance(req, dict) else "",
"user_id": req.get("usr_id", "") if isinstance(req, dict) else "",
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
@@ -88,11 +91,11 @@ def create_agent_session(tenant_id, agent_id):
@manager.route('/chats/<chat_id>/sessions/<session_id>', methods=['PUT'])
@token_required
def update(tenant_id,chat_id,session_id):
def update(tenant_id, chat_id, session_id):
req = request.json
req["dialog_id"] = chat_id
conv_id = session_id
conv = ConversationService.query(id=conv_id,dialog_id=chat_id)
conv = ConversationService.query(id=conv_id, dialog_id=chat_id)
if not conv:
return get_error_data_result(message="Session does not exist")
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
@@ -110,9 +113,9 @@ def update(tenant_id,chat_id,session_id):
@manager.route('/chats/<chat_id>/completions', methods=['POST'])
@token_required
def completion(tenant_id, chat_id):
def completion(tenant_id, chat_id): # chat_id 和 别的文件中的dialog_id 应该是一个意思? cyx 2025-01-25
req = request.json
if not req.get("session_id"):
if not req.get("session_id"): # session_id 和 别的文件中的conversation_id 应该是一个意思? cyx 2025-01-25
conv = {
"id": get_uuid(),
"dialog_id": chat_id,
@@ -123,12 +126,18 @@ def completion(tenant_id, chat_id):
return get_error_data_result(message="`name` can not be empty.")
ConversationService.save(**conv)
e, conv = ConversationService.get_by_id(conv["id"])
session_id=conv.id
session_id = conv.id
else:
session_id = req.get("session_id")
if not req.get("question"):
return get_error_data_result(message="Please input your question.")
conv = ConversationService.query(id=session_id,dialog_id=chat_id)
#conv = ConversationService.query(id=session_id, dialog_id=chat_id)
# 以下改动是为了限制从历史记录中取过多的记录
history_limit = req.get("history_limit", None)
if history_limit is not None:
conv = ConversationService.query(id=session_id, dialog_id=chat_id, reverse=True, order_by="create_time")
else:
conv = ConversationService.query(id=session_id, dialog_id=chat_id)
if not conv:
return get_error_data_result(message="Session does not exist")
conv = conv[0]
@@ -141,13 +150,25 @@ def completion(tenant_id, chat_id):
"id": str(uuid4())
}
conv.message.append(question)
# 第一次遍历,计算 assistant 消息的总数
assistant_total_count = sum(1 for m in conv.message if m["role"] == "assistant")
# 第二次遍历,按条件添加消息到 msg
current_assistant_count = 0 # 跟踪当前添加的 assistant 消息数
for m in conv.message:
if m["role"] == "system": continue
if m["role"] == "assistant" and not msg: continue
if m['role'] == "assistant":
# 如果 assistant 消息超出需要保留的数量,跳过
# 检查 history_limit 是否为 NoneNone 表示不限制
if history_limit is not None and current_assistant_count < assistant_total_count - history_limit:
current_assistant_count += 1
continue
msg.append(m)
message_id = msg[-1].get("id")
e, dia = DialogService.get_by_id(conv.dialog_id)
logging.info(f"/chats/{chat_id}/completions req={req}--dale --2 history_limit={history_limit} dia {dia}") # cyx
if not conv.reference:
conv.reference = []
conv.message.append({"role": "assistant", "content": "", "id": message_id})
@@ -182,19 +203,22 @@ def completion(tenant_id, chat_id):
chunk_list.append(new_chunk)
reference["chunks"] = chunk_list
ans["id"] = message_id
ans["session_id"]=session_id
ans["session_id"] = session_id
def stream():
nonlocal dia, msg, req, conv
try:
for ans in chat(dia, msg, **req):
fillin_conv(ans)
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
ConversationService.update_by_id(conv.id, conv.to_dict())
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e),"reference": []}},
ensure_ascii=False) + "\n\n"
logging.info(f"sessions--3 /chats/<chat_id>/completions error {e} ") # cyx
# yield "data:" + json.dumps({"code": 500, "message": str(e),
# "data": {"answer": "**ERROR**: " + str(e),"reference": []}},
# ensure_ascii=False) + "\n\n"
# cyx 2024 12 04 不把错误返回给前端
yield "data:" + json.dumps({"code": 0, "data": True}, ensure_ascii=False) + "\n\n"
if req.get("stream", True):
@@ -216,6 +240,185 @@ def completion(tenant_id, chat_id):
return get_result(data=answer)
# 全角字符到半角字符的映射
def fullwidth_to_halfwidth(s):
    """Return *s* with full-width punctuation replaced by half-width ASCII.

    Only punctuation is converted; full-width letters and digits are left
    untouched, as are all characters that have no entry in the table.

    Args:
        s: The text to normalise.

    Returns:
        A new string of the same content with the mapped characters replaced
        (bracket-like CJK quotes are removed entirely).
    """
    table = {}
    # The full-width forms U+FF01..U+FF5E mirror ASCII 0x21..0x7E at a fixed
    # offset of 0xFEE0, so the punctuation mapping can be derived instead of
    # being spelled out with literal (encoding-fragile) characters.
    for code in range(0xFF01, 0xFF5F):
        ascii_char = chr(code - 0xFEE0)
        if not ascii_char.isalnum():
            table[code] = ascii_char
    # NOTE(review): the non-ASCII keys of the original literal table were
    # unreadable; the entries below were reconstructed from the surviving
    # values and should be confirmed. Two further entries (mapping to '-' and
    # '.') could not be recovered and are intentionally not guessed here.
    table.update({
        0x3001: ',',   # ideographic comma
        0x3002: '.',   # ideographic full stop
        0x300C: '',    # left corner bracket
        0x300D: '',    # right corner bracket
        0x300E: '',    # left white corner bracket
        0x300F: '',    # right white corner bracket
        0xFF5F: '',    # full-width left white parenthesis
        0xFF60: '',    # full-width right white parenthesis
    })
    return s.translate(table)
def is_dale(s):
    """Placeholder check on *s*; performs no test and always returns ``None``.

    The previous body only built a local punctuation table that was never
    read, so the function had no observable effect. The unused table has been
    dropped and the ``None`` result made explicit.

    Args:
        s: The text to inspect (currently ignored).

    Returns:
        Always ``None``.
    """
    # NOTE(review): the intended predicate is not visible anywhere in this
    # file section — implement it or remove the function once confirmed unused.
    return None
def extract_text_from_markdown(markdown_text):
    """Strip Markdown and HTML markup from *markdown_text*.

    Markdown elements are dropped together with their content (heading lines,
    code, emphasised spans, links and images); for HTML only the tags are
    removed. The remainder is passed through ``fullwidth_to_halfwidth`` and
    runs of whitespace are collapsed to single spaces.

    Args:
        markdown_text: The Markdown source.

    Returns:
        The remaining plain text, stripped of leading/trailing whitespace.
    """
    # Fenced code blocks must go before inline code: otherwise the inline
    # pattern eats the interior of a fence and leaves stray backticks behind.
    text = re.sub(r'```[\s\S]*?```', '', markdown_text)
    # Inline code spans.
    text = re.sub(r'`[^`]+`', '', text)
    # Heading lines. The pattern is anchored to the start of a line and stops
    # at the newline, so the body text following a heading is preserved.
    text = re.sub(r'^[ \t]*#{1,6}[^\n]*', '', text, flags=re.MULTILINE)
    # Bold and italic spans.
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Images must go before links: the link pattern also matches the
    # "[alt](url)" part of an image and would leave a dangling "!".
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Links.
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalise punctuation.
    text = fullwidth_to_halfwidth(text)
    # Collapse redundant whitespace.
    return re.sub(r'\s+', ' ', text).strip()
def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* into space-joined chunks of roughly *chunk_size* characters.

    The text is first tokenised on whitespace and common ASCII punctuation
    (the separators themselves are discarded). Tokens are then packed greedily
    into chunks; a token is never split, so a single token longer than
    *chunk_size* becomes a chunk of its own.

    Args:
        text: The text to split.
        chunk_size: Size budget per chunk. Each token already in the chunk
            counts as its length plus one separator; a token is added while
            that running total plus the new token's length stays within the
            budget.

    Returns:
        A list of non-empty chunk strings; empty when *text* has no tokens.
    """
    punctuation_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    tokens = [token for token in re.split(punctuation_pattern, text) if token]

    chunks = []
    words = []
    used = 0  # characters consumed by the current chunk, separators included
    for token in tokens:
        # Flush only a non-empty chunk: an oversized first token must not
        # produce a leading empty string in the result.
        if words and used + len(token) > chunk_size:
            chunks.append(' '.join(words))
            words, used = [], 0
        words.append(token)
        used += len(token) + 1
    if words:
        chunks.append(' '.join(words))
    return chunks
# In-memory registry of pending TTS requests, keyed by audio_stream_id.
# dialog_tts_post stores an entry with the keys 'text', 'chat_id', 'tenant_id',
# 'audio_stream' and 'model_name'; dialog_tts_get pops it when the audio is fetched.
# NOTE(review): nothing visible here evicts entries that are never fetched, and a
# plain dict is local to one process — confirm this is acceptable for the deployment.
audio_text_cache = {}
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
    """Stream the audio registered under *audio_stream_id* as ``audio/mpeg``.

    The entry is taken out of ``audio_text_cache`` (one-shot). If it carries a
    pre-generated in-memory stream, that stream is replayed; otherwise the
    cached text is synthesised on the fly with the dialog's TTS model (or the
    cached ``model_name`` override).

    Args:
        chat_id: Chat id from the URL; must equal the chat id stored with the
            cache entry.
        audio_stream_id: Key of the pending TTS request.

    Returns:
        A streaming ``Response``, or ``get_error_data_result`` when the entry
        is missing / belongs to another chat or the dialog cannot be found.
    """
    # NOTE(review): this route is not wrapped in @token_required; access rests
    # on knowing the audio_stream_id — confirm that is intended.
    tts_info = audio_text_cache.pop(audio_stream_id, None)
    if not tts_info:
        return get_error_data_result(message="Audio stream not found or expired.")
    if tts_info.get('chat_id') != chat_id:
        # The URL names a different chat than the one the audio was created
        # for: refuse, and put the entry back so the rightful URL still works.
        audio_text_cache.setdefault(audio_stream_id, tts_info)
        return get_error_data_result(message="Audio stream not found or expired.")

    audio_stream = tts_info.get('audio_stream')
    tenant_id = tts_info.get('tenant_id')
    text = tts_info.get('text', "..")
    model_name = tts_info.get('model_name')

    dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
    if not dia:
        return get_error_data_result(message="You do not own the chat")

    if audio_stream is not None:
        # Pre-generated audio: rewind and replay it in 1 KiB pieces.
        audio_stream.seek(0)

        def generate():
            data = audio_stream.read(1024)
            while data:
                yield data
                data = audio_stream.read(1024)

        resp = Response(generate(), mimetype="audio/mpeg")
    else:
        # Deferred generation: the model is only built when it is needed.
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, model_name or dia.tts_id)

        def stream_audio():
            try:
                for chunk in tts_mdl.tts(text):
                    yield chunk
            except Exception as e:
                logging.exception("TTS streaming failed for audio stream %s", audio_stream_id)
                yield ("data:" + json.dumps({"code": 500, "message": str(e),
                                             "data": {"answer": "**ERROR**: " + str(e)}},
                                            ensure_ascii=False)).encode('utf-8')

        resp = Response(stream_audio(), mimetype="audio/mpeg")

    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
    """Register a TTS request and return the URL its audio can be fetched from.

    JSON body fields read here: ``text`` (required), ``delay_gen_audio``
    (truthy = synthesise later, when the URL is fetched), ``audio_stream_id``
    (optional; a UUID4 is generated when absent) and ``model_name`` (optional
    override of the dialog's ``tts_id``).

    Unless generation is delayed, the audio is synthesised immediately into an
    in-memory buffer. The request is recorded in ``audio_text_cache`` only
    after that succeeded, so a failed synthesis leaves nothing behind.

    Returns:
        ``jsonify({"tts_url", "audio_stream_id"})`` on success, otherwise
        ``get_error_data_result``.
    """
    req = request.json
    if not req.get("text"):
        return get_error_data_result(message="Please input your question.")
    # Normalised once so that every later branch agrees on the mode, even for
    # falsy non-boolean values such as 0 or null.
    delay_gen_audio = bool(req.get('delay_gen_audio', False))
    text = req.get('text')
    audio_stream_id = req.get('audio_stream_id')
    if audio_stream_id is None:
        audio_stream_id = str(uuid.uuid4())
    model_name = req.get('model_name')

    dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
    # NOTE(review): whether DialogService.get returns None or raises for a
    # missing row is not visible here; this guard covers the former.
    if not dia:
        return get_error_data_result(message="You do not own the chat")
    tts_model_name = model_name or dia.tts_id
    logging.info(f"---tts {tts_model_name}")
    # Built in both modes, as before, so a bad model surfaces on this request.
    tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

    audio_stream = None
    if not delay_gen_audio:
        audio_stream = io.BytesIO()
        try:
            if not text.strip():
                # Whitespace-only text: store a short run of zero bytes
                # instead of calling the TTS model.
                audio_stream.write(b'\x00' * 100)
            else:
                for chunk in tts_mdl.tts(text):
                    audio_stream.write(chunk)
        except Exception:
            logging.exception("TTS generation failed for audio stream %s", audio_stream_id)
            return get_error_data_result(message="get tts audio stream error.")

    # Keep the request so the GET endpoint can replay or synthesise the audio.
    audio_text_cache[audio_stream_id] = {'text': text, 'chat_id': chat_id, "tenant_id": tenant_id,
                                         'audio_stream': audio_stream, 'model_name': model_name}

    audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
    logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
    return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
@manager.route('/agents/<agent_id>/completions', methods=['POST'])
@token_required
def agent_completion(tenant_id, agent_id):
@@ -235,7 +438,7 @@ def agent_completion(tenant_id, agent_id):
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": req.get("user_id",""),
"user_id": req.get("user_id", ""),
"message": [{"role": "assistant", "content": canvas.get_prologue()}],
"source": "agent"
}
@@ -251,9 +454,9 @@ def agent_completion(tenant_id, agent_id):
question = req.get("question")
if not question:
return get_error_data_result("`question` is required.")
question={
"role":"user",
"content":question,
question = {
"role": "user",
"content": question,
"id": str(uuid4())
}
messages.append(question)
@@ -308,6 +511,7 @@ def agent_completion(tenant_id, agent_id):
if 'docnm_kwd' in chunk_i:
chunk_i['doc_name'] = chunk_i['docnm_kwd']
chunk_i.pop('docnm_kwd')
conv.message.append(msg[-1])
if not conv.reference:
@@ -375,9 +579,26 @@ def agent_completion(tenant_id, agent_id):
return get_result(data=result)
# added by cyx
# 打印 ConversationService.model 的表名及字段定义
def print_table_info(service):
    """Log the table name and field definitions of the model behind *service*.

    Debugging helper. All output goes through ``logging`` at INFO level, so it
    ends up in one place instead of being split between the log and stdout.

    Args:
        service: An object exposing a ``model`` attribute whose ``_meta`` has
            ``table_name`` and a ``fields`` mapping.

    Returns:
        None.
    """
    model = service.model
    if model is None:
        logging.info("No model associated with the service.")
        return

    logging.info(f"Table Name: {model._meta.table_name}")
    logging.info("Fields and Definitions:")
    for field_name, field in model._meta.fields.items():
        logging.info(f" {field_name}: {field}")
@manager.route('/chats/<chat_id>/sessions', methods=['GET'])
@token_required
def list_session(chat_id,tenant_id):
def list_session(chat_id, tenant_id):
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
return get_error_data_result(message=f"You don't own the assistant {chat_id}.")
id = request.args.get("id")
@@ -389,7 +610,8 @@ def list_session(chat_id,tenant_id):
desc = False
else:
desc = True
convs = ConversationService.get_list(chat_id,page_number,items_per_page,orderby,desc,id,name)
print_table_info(ConversationService) # cyx
convs = ConversationService.get_list(chat_id, page_number, items_per_page, orderby, desc, id, name)
if not convs:
return get_result(data=[])
for conv in convs:
@@ -429,9 +651,33 @@ def list_session(chat_id,tenant_id):
return get_result(data=convs)
# added by cyx 20241201
@manager.route('/chats/<chat_id>/sessions_summary', methods=['GET'])
@token_required
def sessions_summary(chat_id, tenant_id):
    """Return the session summaries recorded for one chat assistant.

    The caller must own a valid dialog with id *chat_id*. All summaries are
    fetched from ``ConversationService.query_sessions_summary`` and then
    narrowed to those whose ``dialog_id`` equals *chat_id*.

    Returns:
        ``get_result`` with the (possibly empty) list of matching summaries,
        or ``get_error_data_result`` when the tenant does not own the chat.
    """
    owned = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
    if not owned:
        return get_error_data_result(message=f"You don't own the assistant {chat_id}.")

    every_summary = ConversationService.query_sessions_summary()
    # An empty selection is returned as an empty list, same as any other.
    matching = [entry for entry in every_summary if entry["dialog_id"] == chat_id]
    return get_result(data=matching)
@manager.route('/chats/<chat_id>/sessions', methods=["DELETE"])
@token_required
def delete(tenant_id,chat_id):
def delete(tenant_id, chat_id):
if not DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value):
return get_error_data_result(message="You don't own the chat")
req = request.json
@@ -439,21 +685,22 @@ def delete(tenant_id,chat_id):
if not req:
ids = None
else:
ids=req.get("ids")
ids = req.get("ids")
if not ids:
conv_list = []
for conv in convs:
conv_list.append(conv.id)
else:
conv_list=ids
conv_list = ids
for id in conv_list:
conv = ConversationService.query(id=id,dialog_id=chat_id)
conv = ConversationService.query(id=id, dialog_id=chat_id)
if not conv:
return get_error_data_result(message="The chat doesn't own the session")
ConversationService.delete_by_id(id)
return get_result()
@manager.route('/sessions/ask', methods=['POST'])
@token_required
def ask_about(tenant_id):
@@ -462,17 +709,18 @@ def ask_about(tenant_id):
return get_error_data_result("`question` is required.")
if not req.get("dataset_ids"):
return get_error_data_result("`dataset_ids` is required.")
if not isinstance(req.get("dataset_ids"),list):
if not isinstance(req.get("dataset_ids"), list):
return get_error_data_result("`dataset_ids` should be a list.")
req["kb_ids"]=req.pop("dataset_ids")
req["kb_ids"] = req.pop("dataset_ids")
for kb_id in req["kb_ids"]:
if not KnowledgebaseService.accessible(kb_id,tenant_id):
if not KnowledgebaseService.accessible(kb_id, tenant_id):
return get_error_data_result(f"You don't own the dataset {kb_id}.")
kbs = KnowledgebaseService.query(id=kb_id)
kb = kbs[0]
if kb.chunk_num == 0:
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
uid = tenant_id
def stream():
nonlocal req, uid
try: