# File: ragflow_python/api/apps/sdk/dale_extra.py (527 lines, 21 KiB, Python)
#
# Warning (from the repository viewer): this file contains Unicode characters
# that might be confused with other, similar-looking characters (for example
# fullwidth punctuation). If that is intentional, the warning can be ignored.

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import array
import base64
import gzip
import io
import json
import logging
import os
import queue
import re
import time
import uuid
from io import BytesIO
from threading import Lock, Thread

from flask import Response, jsonify, request, stream_with_context
from zhipuai import ZhipuAI

from api import settings
from api.db import LLMType
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, stream_manager
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.db.services.brief_service import MesumOverviewService
from api.db.services.llm_service import LLMBundle
from api.db.services.antique_service import MesumAntiqueService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
from api.utils.file_utils import get_project_base_directory
# Models the tenant has already added (cyx 2025-01-26).
@manager.route('/get_llms', methods=['GET'])
@token_required
def my_llms(tenant_id):
    """List the tenant's configured LLMs, grouped by factory.

    An optional ``type`` query parameter keeps only entries whose
    ``model_type`` equals it. The result data maps each ``llm_factory`` to
    ``{"tags": ..., "llm": [{"type", "name", "used_token"}, ...]}``.
    """
    wanted_type = request.args.get("type")
    try:
        grouped = {}
        for llm in TenantLLMService.get_my_llms(tenant_id):
            if wanted_type is not None and llm["model_type"] != wanted_type:
                continue
            factory_name = llm["llm_factory"]
            if factory_name not in grouped:
                grouped[factory_name] = {"tags": llm["tags"], "llm": []}
            grouped[factory_name]["llm"].append({
                "type": llm["model_type"],
                "name": llm["llm_name"],
                "used_token": llm["used_tokens"],
            })
        return get_result(data=grouped)
    except Exception as e:
        return get_error_data_result(message=f"Get LLMS error {e}")
# Comma-separated antique names kept in process memory; replaced at runtime by
# POST /mesum/set_antique (mesum_set_antique).
main_antiquity = "浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"

# Fallback ZhipuAI key, used only when the ZHIPUAI_API_KEY environment variable
# is unset or empty.
# NOTE(review): this secret is committed to source control; rotate it and drop
# the fallback once every deployment sets ZHIPUAI_API_KEY.
_ZHIPUAI_FALLBACK_API_KEY = "5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3"


@manager.route('/photo/recongeText/<mesum_id>', methods=['POST'])
@token_required
def upload_file(tenant_id, mesum_id):
    """Recognise the text in an uploaded photo with ZhipuAI's glm-4v-plus model.

    Expects a multipart upload with a ``file`` field whose name passes
    ``allowed_file``. The prompt asks the model to name the antique, event or
    person the text refers to, choosing among the labels returned by
    ``get_antique_labels(mesum_id)``. It also asks for the raw recognised
    text, with both returned as one JSON object.

    Returns ``({'message': ..., 'text': <model reply>}, 200)`` on success.
    Returns ``({'error': ...}, 400)`` when the file part is missing, has an
    empty name, or has a disallowed extension. Errors raised by the ZhipuAI
    call propagate to the caller.
    """
    # Validate the upload before doing any database or model work.
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not allowed_file(file.filename):
        # Without this branch the view returned None, which Flask rejects with a 500.
        return jsonify({'error': 'File type not allowed'}), 400

    antiques_selected = ""
    if mesum_id:
        joined_string = ','.join(str(label) for label in get_antique_labels(mesum_id))
        antiques_selected = f"结果从:{joined_string} 中进行选择"
        logging.info(f"{mesum_id} {joined_string}")
    prompt = (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
              f"请识别这个图片中文字,重点识别出含在文字中的某一文物标题、某一个历史事件或某一历史人物,"
              f"你的回答有2个结果第一个结果是是从文字中识别出历史文物、历史事件、历史人物,"
              f"此回答时只给出匹配的文物、事件、人物,不需要其他多余的文字,{antiques_selected}"
              f",第二个结果是原始识别的所有文字"
              "2个结果输出以{ }的json格式给出匹配文物、事件、人物的键值为antique如果有多个请加序号如:antique1,antique2,"
              f"原始数据的键值为text输出是1个完整的JSON数据不要有多余的前置和后置内容确保前端能正确解析出JSON数据")

    # request.content_length is the size of the whole request body, not only the file.
    file_size = request.content_length
    img_base = base64.b64encode(file.read()).decode('utf-8')
    logging.info(f"recevie photo file {file.filename} {file_size} 识别中....")
    client = ZhipuAI(api_key=os.environ.get("ZHIPUAI_API_KEY") or _ZHIPUAI_FALLBACK_API_KEY)
    response = client.chat.completions.create(
        model="glm-4v-plus",
        messages=[
            {
                "role": "user",
                "content": [
                    # The bare Base64 string is sent as the image "url"; presumably
                    # the API accepts URL or Base64 here -- confirm against SDK docs.
                    {"type": "image_url", "image_url": {"url": img_base}},
                    {"type": "text", "text": prompt},
                ],
            }
        ],
    )
    message = response.choices[0].message
    logging.info(message.content)
    return jsonify({'message': 'File uploaded successfully', 'text': message.content}), 200
def allowed_file(filename):
    """Return True if *filename* has an image extension (png/jpg/jpeg/gif, any case)."""
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in {'png', 'jpg', 'jpeg', 'gif'}
# List every museum overview record.
@manager.route('/mesum/list', methods=['GET'])
@token_required
def mesum_list(tenant_id):
    """Return every record from ``MesumOverviewService.get_all()`` as a list of dicts.

    ``tenant_id`` is injected by ``token_required`` and is not used for filtering.
    """
    try:
        overviews = [overview.to_dict() for overview in MesumOverviewService.get_all()]
        return get_result(data=overviews)
    except Exception as e:
        # The message names this endpoint; it previously reused my_llms' "Get LLMS error".
        return get_error_data_result(message=f"Get mesum list error {e}")
@manager.route('/mesum/set_antique', methods=['POST'])
@token_required
def mesum_set_antique(tenant_id):
    """Replace the module-level ``main_antiquity`` with the JSON body's ``antique`` value.

    A missing or empty ``antique`` leaves the current value unchanged, and the
    call still reports success.

    NOTE(review): the value lives in this process's memory, so with several
    worker processes each keeps its own copy. Confirm this is acceptable.
    """
    global main_antiquity
    try:
        # silent=True: a missing or non-JSON body yields None instead of raising.
        req_data = request.get_json(silent=True) or {}
        new_antique = req_data.get('antique', None)
        if new_antique:
            main_antiquity = new_antique
        logging.info(f"main_antiquity: {main_antiquity}")
        # Pass the payload as data=, like every other get_result call in this module.
        # Positionally it lands in get_result's first parameter, presumably `code`
        # (ragflow api_utils) -- confirm.
        return get_result(data={'statusCode': 200, 'code': 0, 'message': 'antique set successfully'})
    except Exception as e:
        return get_error_data_result(message=f"Set antique error {e}")
# Pending TTS requests keyed by audio_stream_id. Filled by dialog_tts_post,
# popped by dialog_tts_get, expired by clean_audio_cache.
audio_text_cache = dict()
# Guards every access to audio_text_cache.
cache_lock = Lock()
# Entries older than this many seconds are evicted by clean_audio_cache().
CACHE_EXPIRE_SECONDS = 10 * 60  # 10 minutes
# Map fullwidth punctuation (U+FF01..U+FF5E) to its ASCII twin. Each fullwidth
# code point is exactly 0xFEE0 above the ASCII one. Only the punctuation
# subranges are covered (! .. /, : .. @, [ .. `, { .. ~); fullwidth digits and
# letters are left unchanged.
_FULLWIDTH_PUNCT_ASCII_RANGES = ((0x21, 0x2F), (0x3A, 0x40), (0x5B, 0x60), (0x7B, 0x7E))
_FULLWIDTH_TO_HALFWIDTH_TABLE = {
    code + 0xFEE0: code
    for low, high in _FULLWIDTH_PUNCT_ASCII_RANGES
    for code in range(low, high + 1)
}


def fullwidth_to_halfwidth(s):
    """Return *s* with fullwidth ASCII punctuation replaced by its halfwidth form.

    For example ``"你好！（测试）"`` becomes ``"你好!(测试)"``. All other
    characters, including fullwidth digits and letters, are returned unchanged.
    """
    # The previous literal dict had lost its fullwidth keys (every key was ''),
    # so it converted nothing. The table is now derived from the code-point offset.
    return s.translate(_FULLWIDTH_TO_HALFWIDTH_TABLE)
# Runs of whitespace or ASCII punctuation that separate tokens (compiled once).
_PUNCTUATION_SPLIT_RE = re.compile(r'[\s,.!?;:\-\\(\)\[\]{}"\'\\\/]+')


def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* at whitespace/ASCII punctuation and pack the tokens into chunks.

    The separators are discarded. Tokens inside a chunk are joined by single
    spaces. A chunk stays within ``chunk_size`` characters unless one token is
    longer, in which case that token becomes a chunk on its own.

    Returns a list of non-empty strings; ``[]`` when *text* has no tokens.
    """
    chunks = []
    current_chunk = ''
    for token in filter(None, _PUNCTUATION_SPLIT_RE.split(text)):
        # Flush only a non-empty chunk: when the very first token was longer
        # than chunk_size, the old code appended a spurious '' chunk here.
        if current_chunk and len(current_chunk) + len(token) > chunk_size:
            chunks.append(current_chunk.strip())
            current_chunk = ''
        current_chunk += token + ' '
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
def extract_text_from_markdown(markdown_text):
    """Strip Markdown/HTML markup from *markdown_text* and return plain text.

    Removed entirely: fenced code blocks, inline code, images, links, heading
    lines ("#" to "######" followed by a space or end of line) and HTML tags.
    Bold/italic markers are removed, but the emphasised words are kept.
    Fullwidth punctuation is then folded to ASCII with fullwidth_to_halfwidth().
    Finally every whitespace run becomes one space and the result is stripped.
    """
    text = markdown_text
    # Fenced blocks before inline code: the inline pattern would otherwise eat
    # the fence's contents and leave stray backticks behind.
    text = re.sub(r'```[\s\S]*?```', '', text)
    text = re.sub(r'`[^`]+`', '', text)
    # Images before links: the link pattern matches "[alt](url)" inside
    # "![alt](url)" and would leave the "!" behind.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # Whole heading lines only. The old pattern '#\s*[^#]+' also swallowed every
    # following line up to the next '#'.
    text = re.sub(r'^[ \t]*#{1,6}(?=[ \t]|$).*$', '', text, flags=re.MULTILINE)
    # Bold/italic: drop the *, **, *** (or _) markers, keep the text between them.
    text = re.sub(r'(\*{1,3}|_{1,3})(?=\S)(.+?)(?<=\S)\1', r'\2', text)
    text = re.sub(r'<[^>]+>', '', text)
    text = fullwidth_to_halfwidth(text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress *original_data* and return it as a Base64 text string.

    Uses the standard Base64 alphabet with no line breaks, equivalent to
    Android's ``Base64.NO_WRAP``, so the Android client can decode it.
    """
    compressed_bytes = gzip.compress(original_data)
    encoded_bytes = base64.b64encode(compressed_bytes)
    return encoded_bytes.decode('utf-8')
def clean_audio_cache():
    """Evict cache entries older than CACHE_EXPIRE_SECONDS and close their audio streams.

    The whole sweep runs while holding ``cache_lock``. It is called
    periodically by the thread started in start_background_cleaner().
    """
    with cache_lock:
        now = time.time()
        stale_ids = [stream_id for stream_id, info in audio_text_cache.items()
                     if now - info['created_at'] > CACHE_EXPIRE_SECONDS]
        for stream_id in stale_ids:
            info = audio_text_cache.pop(stream_id, None)
            stream = info.get('audio_stream') if info else None
            if stream:
                stream.close()
def start_background_cleaner():
    """Start a daemon thread that calls clean_audio_cache() every 180 seconds."""
    def cleaner_loop():
        while True:
            time.sleep(180)  # sweep every 3 minutes
            clean_audio_cache()

    Thread(target=cleaner_loop, daemon=True).start()


# Start the cleaner thread once, when this module is imported.
start_background_cleaner()
@manager.route('/tts_stream/<session_id>', methods=['GET'])
def tts_stream(session_id):
    """Stream the audio chunks produced for a ``stream_manager`` session.

    Reading from ``session['buffer']`` stops when any of these happens:
    - the session's ``finished`` event is set and the buffer is empty;
    - ``session['active']`` becomes falsy;
    - a str chunk starting with "ERROR" arrives (forwarded as a ``data:`` line).

    With ``stream_format == "wav"`` each chunk is sent as a gzip+Base64 line
    ending in CRLF (encode_gzip_base64); other formats are sent raw. The session
    is closed with ``stream_manager.close_session`` when streaming ends. An
    unknown *session_id* gets an error result.
    """
    session = stream_manager.sessions.get(session_id)
    if not session:
        # Everything below indexes into the session; an unknown id used to raise
        # TypeError and produce a 500.
        return get_error_data_result(message=f"TTS stream session {session_id} not found.")
    is_wav = session['stream_format'] == "wav"

    def generate():
        # Assumes session['buffer'] is a queue.Queue (queue.Empty signals "no data").
        buffer = session['buffer']
        finished_event = session['finished']
        try:
            while session['active']:
                try:
                    # Short blocking wait instead of get_nowait(): polling an empty
                    # queue in a tight loop kept a CPU core busy for the whole stream.
                    chunk = buffer.get(timeout=0.1)
                except queue.Empty:
                    # Stop only when the producer is done AND the queue is drained, so
                    # chunks queued just before `finished` was set are still delivered.
                    if finished_event.is_set():
                        break
                    if not is_wav:
                        yield b''  # keep the connection alive
                    continue
                if isinstance(chunk, str) and chunk.startswith("ERROR"):
                    logging.info(f"---tts stream error!!!! {chunk}")
                    yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
                    break
                if is_wav:
                    yield encode_gzip_base64(chunk) + "\r\n"
                else:
                    yield chunk
        except Exception as e:
            logging.info(f"tts streag get error2 {e} ")
        finally:
            stream_manager.close_session(session_id)
            logging.info(f"Session {session_id} closed.")

    # wav -> audio/wav, anything else -> audio/mpeg, the same mapping as
    # dialog_tts_get (the two were swapped here before).
    mimetype = "audio/wav" if is_wav else "audio/mpeg"
    resp = Response(stream_with_context(generate()), mimetype=mimetype)
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
    """Serve the audio registered by dialog_tts_post under *audio_stream_id*.

    The cache entry is removed on first access, so each URL works once.
    - If the entry carries pre-generated audio, those bytes are returned as
      ``audio/mpeg``.
    - Otherwise the text is synthesized now and streamed. ``wav`` chunks are
      sent as gzip+Base64 lines ending in CRLF.

    The ownership check uses the chat id stored in the cache entry; the
    ``chat_id`` path segment is not consulted.
    """
    with cache_lock:
        tts_info = audio_text_cache.pop(audio_stream_id, None)  # take and delete
    if not tts_info:
        return get_error_data_result(message="Audio stream not found or expired.")
    audio_stream = tts_info.get('audio_stream')
    try:
        tenant_id = tts_info.get('tenant_id')
        chat_id = tts_info.get('chat_id')
        text = tts_info.get('text', "..")
        model_name = tts_info.get('model_name')
        sample_rate = tts_info.get('tts_sample_rate', 8000)  # default 8 kHz
        stream_format = tts_info.get('tts_stream_format', 'mp3')
        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        logging.info(f"dialog_tts_get {sample_rate} {stream_format}")

        if audio_stream:
            # Copy the bytes out now. The finally below closes the stream as soon
            # as this view returns, before the server iterates the response body.
            # Reading lazily from the stream failed with "I/O operation on closed file".
            audio_stream.seek(0)
            audio_bytes = audio_stream.read()

            def generate():
                for offset in range(0, len(audio_bytes), 1024):
                    yield audio_bytes[offset:offset + 1024]

            resp = Response(generate(), mimetype="audio/mpeg")
        else:
            # The TTS model is only needed when synthesizing on demand.
            tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, model_name or dia.tts_id)

            def stream_audio():
                try:
                    for chunk in tts_mdl.tts(text, sample_rate=sample_rate, stream_format=stream_format):
                        if stream_format == 'wav':
                            yield encode_gzip_base64(chunk) + "\r\n"
                        else:
                            yield chunk
                except Exception as e:
                    yield ("data:" + json.dumps({"code": 500, "message": str(e),
                                                 "data": {"answer": "**ERROR**: " + str(e)}},
                                                ensure_ascii=False)).encode('utf-8')

            mimetype = "audio/wav" if stream_format == 'wav' else "audio/mpeg"
            resp = Response(stream_audio(), mimetype=mimetype)
        resp.headers.add_header("Cache-Control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        return resp
    except Exception as e:
        logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
        return get_error_data_result(message="音频流传输失败")
    finally:
        # Safe to close: any pre-generated bytes were already copied out above.
        if audio_stream is not None and not audio_stream.closed:
            audio_stream.close()
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
    """Register a TTS request and return the URL from which its audio is fetched.

    JSON body fields:
    - ``text`` (required)
    - ``delay_gen_audio`` (default False)
    - ``model_name`` (defaults to the dialog's ``tts_id``)
    - ``audio_stream_id`` (a UUID4 is generated if absent)
    - ``tts_stream_format`` (default "mp3")
    - ``tts_sample_rate`` (default 8000)

    Unless ``delay_gen_audio`` is truthy, the audio is synthesized now into an
    in-memory buffer; otherwise dialog_tts_get synthesizes it when fetched. The
    entry is placed in ``audio_text_cache`` only once complete, so a
    client-chosen ``audio_stream_id`` cannot be fetched half-written.
    """
    # silent=True: a missing/non-JSON body becomes {} and gets the normal validation error.
    req = request.get_json(silent=True) or {}
    try:
        text = req.get('text')
        if not text:
            return get_error_data_result(message="Please input your question.")
        # Normalise so that null/0/"" behave like False instead of leaving an empty buffer.
        delay_gen_audio = bool(req.get('delay_gen_audio', False))
        model_name = req.get('model_name')
        audio_stream_id = req.get('audio_stream_id') or str(uuid.uuid4())
        tts_stream_format = req.get('tts_stream_format', "mp3")
        tts_sample_rate = req.get('tts_sample_rate', 8000)

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, model_name or dia.tts_id)
        logging.info(f"tts post {tts_sample_rate} {tts_stream_format}")

        audio_stream = None
        if not delay_gen_audio:
            audio_stream = io.BytesIO()
            try:
                if text.strip() == "":
                    # Whitespace-only text: store 100 zero bytes instead of calling the model.
                    audio_stream.write(b'\x00' * 100)
                else:
                    # Keyword was misspelled "stream_formate" before; dialog_tts_get
                    # passes stream_format to the same tts() call.
                    for chunk in tts_mdl.tts(text, sample_rate=tts_sample_rate,
                                             stream_format=tts_stream_format):
                        audio_stream.write(chunk)
            except Exception as e:
                logging.info(f"--error:{e}")
                audio_stream.close()
                return get_error_data_result(message="get tts audio stream error.")

        tts_info = {
            'text': text,
            'tenant_id': tenant_id,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': audio_stream,  # None when generation is delayed
            'model_name': model_name,
            'delay_gen_audio': delay_gen_audio,
            'audio_stream_id': audio_stream_id,
            'tts_sample_rate': tts_sample_rate,
            'tts_stream_format': tts_stream_format,
        }
        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info

        audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
        logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url} "
                     f"{tts_sample_rate} {tts_stream_format}")
        return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id,
                        "sample_rate": tts_sample_rate, "stream_format": tts_stream_format})
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}", exc_info=True)
        return get_error_data_result(message="服务器内部错误")
def get_antique_categories(mesum_id):
    """Return the result of ``MesumAntiqueService.get_all_categories()``.

    NOTE(review): ``mesum_id`` is accepted but not passed on, so the categories
    are not filtered by museum. Confirm this is intended.
    """
    return MesumAntiqueService.get_all_categories()
def get_labels_ext(mesum_id):
    """Return ``MesumAntiqueService.get_labels_ext(mesum_id)`` unchanged."""
    return MesumAntiqueService.get_labels_ext(mesum_id)
def get_antique_labels(mesum_id):
    """Return the result of ``MesumAntiqueService.get_all_labels()``.

    NOTE(review): ``mesum_id`` is accepted but not passed on, so the labels are
    not filtered by museum. Confirm this is intended.
    """
    return MesumAntiqueService.get_all_labels()
def get_all_antiques(mesum_id):
    """Return the antiques of museum *mesum_id* as a list of ``to_dict()`` dicts."""
    return [antique.to_dict() for antique in MesumAntiqueService.get_by_mesum_id(mesum_id)]
@manager.route('/mesum/antique/<mesum_id>', methods=['GET'])
def mesum_antique_get(mesum_id):
    """Return the antiques, categories and labels for *mesum_id* in one payload.

    The response key "anqituqes" is misspelled but kept as-is, since clients
    may rely on it. NOTE(review): unlike the neighbouring antique routes, this
    one has no ``token_required``. Confirm public access is intended.
    """
    try:
        antiques = get_all_antiques(mesum_id)
        categories = get_antique_categories(mesum_id)
        labels = get_antique_labels(mesum_id)
        return get_result(data={"anqituqes": antiques,
                                "categories": categories,
                                "labels": labels})
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
# Exhibit overview (categories and extended labels) for the museum given by mesum_id.
@manager.route('/mesum/antique_brief/<mesum_id>', methods=['GET'])
@token_required
def mesum_antique_get_brief(tenant_id, mesum_id):
    """Return ``{"categories": ..., "labels": ...}`` for *mesum_id*.

    Categories come from get_antique_categories and labels from get_labels_ext.
    """
    try:
        categories = get_antique_categories(mesum_id)
        labels = get_labels_ext(mesum_id)
        return get_result(data={"categories": categories, "labels": labels})
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
@manager.route('/mesum/antique_detail/<mesum_id>/<antique_id>', methods=['GET'])
@token_required
def mesum_antique_get_full(tenant_id, mesum_id, antique_id):
    """Return ``MesumAntiqueService.get_antique_by_id(mesum_id, antique_id)`` as result data."""
    try:
        logging.info(f"mesum_antique_get_full {mesum_id} {antique_id}")
        detail = MesumAntiqueService.get_antique_by_id(mesum_id, antique_id)
        return get_result(data=detail)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the first *fade_length* samples of *audio_data*.

    *audio_data* is read as signed 16-bit samples in the machine's native byte
    order (``array('h')``). The original comment assumes mono PCM; confirm for
    multi-channel input. Sample ``i`` (``i < fade_length``) is scaled by
    ``i / fade_length`` and truncated to int.

    If *fade_length* exceeds the number of samples, only the available samples
    are faded and the ramp slope is unchanged. Non-positive *fade_length*
    leaves the data untouched.

    Returns the processed samples as bytes. Raises ValueError if
    ``len(audio_data)`` is not a multiple of the 2-byte sample size.
    """
    samples = array.array('h', audio_data)
    # Clamp so short buffers no longer raise IndexError.
    ramp_end = min(fade_length, len(samples))
    for i in range(ramp_end):
        samples[i] = int(samples[i] * (i / fade_length))
    return samples.tobytes()