Files
ragflow_python/api/apps/sdk/dale_extra.py

407 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import io
import json
import logging
import queue
import re
import time
import uuid
from threading import Lock, Thread

from flask import request, Response, jsonify
from zhipuai import ZhipuAI

from api import settings
from api.db import LLMType
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, stream_manager
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.services.brief_service import MesumOverviewService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
# Models the user has already added. cyx 2025-01-26
@manager.route('/get_llms', methods=['GET'])
@token_required
def my_llms(tenant_id):
    """List the tenant's configured LLMs grouped by factory.

    Optional query parameter ``type`` filters by model type.
    Returns ``{factory: {"tags": ..., "llm": [{"type", "name", "used_token"}, ...]}}``.
    """
    model_type = request.args.get("type")
    try:
        grouped = {}
        for item in TenantLLMService.get_my_llms(tenant_id):
            # Skip entries that don't match the optional type filter.
            if model_type is not None and item["model_type"] != model_type:
                continue
            factory_entry = grouped.setdefault(
                item["llm_factory"], {"tags": item["tags"], "llm": []}
            )
            factory_entry["llm"].append({
                "type": item["model_type"],
                "name": item["llm_name"],
                "used_token": item["used_tokens"],
            })
        return get_result(data=grouped)
    except Exception as e:
        return get_error_data_result(message=f"Get LLMS error {e}")
# Default comma-separated candidate antique names fed into the photo-recognition
# prompt (see /photo/recongeText); replaceable at runtime via /mesum/set_antique.
main_antiquity="浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"
@manager.route('/photo/recongeText', methods=['POST'])
@token_required
def upload_file(tenant_id):
    """OCR an uploaded photo with ZhipuAI GLM-4V and match it against candidate antiques.

    Expects multipart form-data with a ``file`` part (png/jpg/jpeg/gif) and an
    optional ``antique`` field (comma-separated candidate names; defaults to the
    module-level ``main_antiquity``). Returns the model's raw answer in ``text``.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not (file and allowed_file(file.filename)):
        # BUG FIX: the original fell through with no return for unsupported
        # file types, making Flask raise (implicit None return -> HTTP 500).
        return jsonify({'error': 'Unsupported file type'}), 400
    file_size = request.content_length
    img_base = base64.b64encode(file.read()).decode('utf-8')
    req_antique = request.form.get('antique', None)
    if req_antique is None:
        req_antique = main_antiquity
    logging.info(f"recevie photo file {file.filename} {file_size} 识别中....")
    # SECURITY(review): hard-coded API key checked into source — move it to
    # configuration / an environment variable and rotate the exposed key.
    client = ZhipuAI(api_key="5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3")
    response = client.chat.completions.create(
        model="glm-4v-plus",
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": img_base
                        }
                    },
                    {
                        "type": "text",
                        # Prompt asks the model to OCR the image, match the text
                        # against the candidate list, and answer as one JSON
                        # object with keys `antique` / `text`.
                        "text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
                                 f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
                                 f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
                                 f"你的回答有2个结果第一个结果是是从文字进行分析出匹配文物候选文物只能如下:{req_antique},"
                                 f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
                                 f",第二个结果是原始识别的所有文字"
                                 "2个结果输出以{ }的json格式给出匹配文物的键值为antique如果有多个请加序号如:antique1,antique2,"
                                 f"原始数据的键值为text输出是1个完整的JSON数据不要有多余的前置和后置内容确保前端能正确解析出JSON数据")
                    }
                ]
            }
        ]
    )
    message = response.choices[0].message
    logging.info(message.content)
    return jsonify({'message': 'File uploaded successfully', 'text': message.content}), 200
def allowed_file(filename):
    """Return True when *filename* carries an accepted image extension."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in {'png', 'jpg', 'jpeg', 'gif'}
#get_all
@manager.route('/mesum/list', methods=['GET'])
@token_required
def mesum_list(tenant_id):
    """Return every museum overview record as a list of plain dicts."""
    try:
        res = [o.to_dict() for o in MesumOverviewService.get_all()]
        return get_result(data=res)
    except Exception as e:
        # BUG FIX: error message previously said "Get LLMS error" — a
        # copy-paste leftover from my_llms that misleads operators reading logs.
        return get_error_data_result(message=f"Get mesum list error {e}")
@manager.route('/mesum/set_antique', methods=['POST'])
@token_required
def mesum_set_antique(tenant_id):
    """Replace the global candidate-antique list used by photo recognition.

    JSON body: ``{"antique": "name1,name2,..."}``.
    """
    global main_antiquity
    req_data = request.json
    req_data_antique = req_data.get('antique', None)
    try:
        if req_data_antique:
            main_antiquity = req_data_antique
            logging.info(main_antiquity)  # was print(); route through app logging
            return get_result({'statusCode': 200, 'code': 0, 'message': 'antique set successfully'})
        # BUG FIX: the original implicitly returned None (HTTP 500) when
        # 'antique' was missing or empty; report the problem explicitly.
        return get_error_data_result(message="Field 'antique' is required.")
    except Exception as e:
        # BUG FIX: message previously said "Get LLMS error" (copy-paste).
        return get_error_data_result(message=f"Set antique error {e}")
# In-memory cache of pending TTS jobs keyed by audio_stream_id (see
# dialog_tts_post / dialog_tts_get below); entries are one-shot and expire.
audio_text_cache = {}
# Guards every read/write of audio_text_cache across request threads.
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # cache entries expire after 10 minutes
# Fullwidth-to-halfwidth character conversion.
def fullwidth_to_halfwidth(s):
    """Convert fullwidth/CJK punctuation in *s* to halfwidth ASCII.

    Fullwidth ASCII variants (U+FF01..U+FF5E) are shifted to their ASCII
    counterparts, the ideographic space (U+3000) becomes a plain space, and a
    small set of CJK punctuation marks is mapped (or dropped) explicitly.
    Any other character passes through unchanged.

    NOTE(review): the original literal mapping table was corrupted by a
    lossy encoding round trip (every key rendered as an empty string, which
    collapses a dict to a single bogus entry). This table is a best-effort
    reconstruction of the visible value column — confirm the set of
    characters the original deliberately dropped.
    """
    # CJK punctuation with no U+FFxx fullwidth form.
    cjk_punct = {
        '、': ',',   # ideographic comma
        '。': '.',   # ideographic full stop
        '—': '-',    # em dash
        '·': '.',    # middle dot
        '…': '',     # ellipsis — dropped, matching the original's '' values
        '“': '', '”': '', '‘': '', '’': '',  # curly quotes — dropped
        '《': '', '》': '',                    # title brackets — dropped
    }
    out = []
    for char in s:
        code = ord(char)
        if code == 0x3000:               # ideographic space
            out.append(' ')
        elif 0xFF01 <= code <= 0xFF5E:   # fullwidth ASCII block: fixed offset
            out.append(chr(code - 0xFEE0))
        else:
            out.append(cjk_punct.get(char, char))
    return ''.join(out)
def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* into chunks of at most ~chunk_size characters.

    The text is first tokenized at whitespace/punctuation; tokens are then
    packed greedily into space-joined chunks.
    """
    delimiter_pattern = r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+'
    # Tokenize and drop the empty strings re.split leaves at the edges.
    words = [w for w in re.split(delimiter_pattern, text) if w]
    chunks = []
    buf = ''
    for word in words:
        if len(buf) + len(word) <= chunk_size:
            # Still room in the current chunk — append with a trailing space.
            buf += word + ' '
        else:
            # Chunk is full: flush it and start a new one with this word.
            chunks.append(buf.strip())
            buf = word + ' '
    if buf:
        # Flush whatever remains.
        chunks.append(buf.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    """Strip Markdown/HTML markup from *markdown_text*, keeping the plain text.

    Also normalizes fullwidth punctuation to halfwidth and collapses runs of
    whitespace. Intended to produce clean input for TTS.
    """
    # Remove images BEFORE links: the link pattern also matches the
    # `[alt](url)` tail of an image, and removing it first left a stray '!'.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', markdown_text)
    # Remove links.
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # Remove fenced code blocks, then inline code spans.
    text = re.sub(r'```[\s\S]*?```', '', text)
    text = re.sub(r'`[^`]+`', '', text)
    # BUG FIX: strip only the '#' heading markers; the original pattern
    # (r'#\s*[^#]+') deleted the heading text itself.
    text = re.sub(r'^#{1,6}\s*', '', text, flags=re.MULTILINE)
    # BUG FIX: unwrap bold/italic instead of deleting the emphasized text —
    # the original sub('') removed the content along with the markers.
    text = re.sub(r'([*_]{1,3})(\S.*?\S|\S)\1', r'\2', text)
    # Remove HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalize fullwidth punctuation to halfwidth.
    text = fullwidth_to_halfwidth(text)
    # Collapse whitespace runs.
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def clean_audio_cache():
    """Evict cache entries older than CACHE_EXPIRE_SECONDS, closing their streams."""
    with cache_lock:
        cutoff = time.time() - CACHE_EXPIRE_SECONDS
        stale_keys = [key for key, info in audio_text_cache.items()
                      if info['created_at'] < cutoff]
        for key in stale_keys:
            info = audio_text_cache.pop(key, None)
            if info is None:
                continue
            stream = info.get('audio_stream')
            if stream:
                # Free the in-memory audio buffer held by the expired entry.
                stream.close()
def start_background_cleaner():
    """Spawn a daemon thread that purges expired audio-cache entries every 3 minutes."""
    def _cleaner_loop():
        while True:
            time.sleep(180)  # sweep every 3 minutes
            clean_audio_cache()

    Thread(target=_cleaner_loop, daemon=True).start()


# Kick off the cleaner as soon as the module is imported.
start_background_cleaner()
@manager.route('/tts_stream/<session_id>')
def tts_stream(session_id):
    """Stream audio chunks for *session_id* from stream_manager's buffer as audio/mpeg.

    The generator pulls from the session's queue until the session is gone,
    marked inactive, an ERROR chunk arrives, or one 5-second poll times out.
    The session is closed (after a grace delay) when the stream ends.
    """
    def generate():
        retry_count = 0
        session = None
        count = 0;
        try:
            # retry_count < 1 means a single queue timeout ends the stream;
            # it is reset to 0 after every successfully delivered chunk.
            while retry_count < 1:
                session = stream_manager.sessions.get(session_id)
                if not session or not session['active']:
                    break
                try:
                    chunk = session['buffer'].get(timeout=5)  # block up to 5s for the next chunk
                    count = count + 1
                    # Producer signals failure with a string "ERROR:<detail>".
                    if isinstance(chunk, str) and chunk.startswith("ERROR"):
                        logging.info("---tts stream error!!!!")
                        # NOTE(review): this yields an SSE-style text line into an
                        # audio/mpeg stream, and the single quotes are not valid
                        # JSON — confirm how the client consumes error frames.
                        yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
                        break
                    yield chunk
                    retry_count = 0  # got data: reset the retry counter
                except queue.Empty:
                    retry_count += 1
                    yield b''  # keep the HTTP connection alive
        finally:
            # Make sure the session is closed once the stream ends.
            if session:
                # Delay closing so already-queued data can drain.
                time.sleep(5)
                stream_manager.close_session(session_id)
                logging.info(f"Session {session_id} closed.")
    resp = Response(generate(), mimetype="audio/mpeg")
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")  # disable nginx buffering
    return resp
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
    """Stream the TTS audio registered under *audio_stream_id*.

    The cache entry is popped on first access, so each URL is one-shot.
    If the entry holds a pre-rendered BytesIO, it is replayed; otherwise the
    audio is synthesized on the fly from the cached text.
    """
    with cache_lock:
        tts_info = audio_text_cache.pop(audio_stream_id, None)  # pop: single use
    try:
        if not tts_info:
            return get_error_data_result(message="Audio stream not found or expired.")
        audio_stream = tts_info.get('audio_stream')
        tenant_id = tts_info.get('tenant_id')
        chat_id = tts_info.get('chat_id')  # cached value overrides the URL path segment
        text = tts_info.get('text', "..")
        model_name = tts_info.get('model_name')
        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        # An explicit model_name in the cached request overrides the dialog's TTS model.
        tts_model_name = model_name if model_name else dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

        def stream_audio():
            # Lazy path: synthesize and yield chunks as the model produces them.
            try:
                for chunk in tts_mdl.tts(text):
                    yield chunk
            except Exception as e:
                yield ("data:" + json.dumps({"code": 500, "message": str(e),
                                             "data": {"answer": "**ERROR**: " + str(e)}},
                                            ensure_ascii=False)).encode('utf-8')

        def generate():
            # Replay path: stream the pre-rendered buffer in 1 KiB chunks.
            data = audio_stream.read(1024)
            while data:
                yield data
                data = audio_stream.read(1024)

        if audio_stream:
            audio_stream.seek(0)  # rewind before replaying
            resp = Response(generate(), mimetype="audio/mpeg")
        else:
            resp = Response(stream_audio(), mimetype="audio/mpeg")
        resp.headers.add_header("Cache-Control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        return resp
    except Exception as e:
        logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
        return get_error_data_result(message="音频流传输失败")
    finally:
        # BUG FIX: guard against tts_info being None (cache miss) — the original
        # did tts_info.get(...) unconditionally, so the clean "not found"
        # response above was clobbered by an AttributeError -> HTTP 500.
        # NOTE(review): this closes the BytesIO when the handler returns, i.e.
        # before Flask iterates generate(); the replay path may read a closed
        # buffer — confirm whether delay_gen_audio is the only path used.
        if tts_info and tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
            tts_info['audio_stream'].close()
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
    """Register text for TTS and return a one-shot streaming URL.

    JSON body: ``text`` (required), ``model_name`` (optional override of the
    dialog's TTS model), ``audio_stream_id`` (optional; generated when absent),
    ``delay_gen_audio`` (when true, synthesis is deferred to the GET endpoint).
    """
    try:
        req = request.json
        if not req.get("text"):
            return get_error_data_result(message="Please input your question.")
        delay_gen_audio = req.get('delay_gen_audio', False)
        text = req.get('text')
        model_name = req.get('model_name')
        audio_stream_id = req.get('audio_stream_id', None)
        if audio_stream_id is None:
            audio_stream_id = str(uuid.uuid4())
        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            # BUG FIX: the original dereferenced dia.tts_id without checking,
            # raising AttributeError (HTTP 500) for an unknown chat id.
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name if model_name else dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)
        # Eager mode renders into an in-memory buffer now; delayed mode stores None.
        audio_stream = None if delay_gen_audio else io.BytesIO()
        tts_info = {
            'text': text,
            'tenant_id': tenant_id,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': audio_stream,
            'model_name': model_name,
            'delay_gen_audio': delay_gen_audio,
            # BUG FIX: the original wrote `audio_stream_id: audio_stream_id`,
            # using the id's VALUE as the dict key instead of the literal
            # string, so consumers could never look up 'audio_stream_id'.
            'audio_stream_id': audio_stream_id,
        }
        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info
        if not delay_gen_audio:
            try:
                if text is None or text.strip() == "":
                    audio_stream.write(b'\x00' * 100)  # emit a short silence for empty text
                else:
                    audio_stream.seek(0, io.SEEK_END)  # always append at the end
                    for chunk in tts_mdl.tts(text):
                        audio_stream.write(chunk)
            except Exception as e:
                logging.info(f"--error:{e}")
                with cache_lock:
                    audio_text_cache.pop(audio_stream_id, None)
                return get_error_data_result(message="get tts audio stream error.")
        audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
        logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
        return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}", exc_info=True)
        return get_error_data_result(message="服务器内部错误")