875 lines
36 KiB
Python
875 lines
36 KiB
Python
#
|
||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
#
|
||
from flask import request , Response, jsonify,stream_with_context,send_file
|
||
from api import settings
|
||
from api.db import LLMType
|
||
from api.db import StatusEnum
|
||
from api.db.services.dialog_service import DialogService,stream_manager
|
||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||
from api.db.services.llm_service import TenantLLMService
|
||
from api.db.services.user_service import TenantService
|
||
from api.db.services.brief_service import MesumOverviewService
|
||
from api.db.services.llm_service import LLMBundle
|
||
from api.db.services.antique_service import MesumAntiqueService
|
||
from api.db.services.appinfo_service import AppInfoService
|
||
from api.utils import get_uuid
|
||
from api.utils.api_utils import get_error_data_result, token_required
|
||
from api.utils.api_utils import get_result
|
||
from api.utils.file_utils import get_project_base_directory
|
||
from rag.utils.minio_conn import RAGFlowMinio
|
||
import logging
|
||
import base64, gzip
|
||
import io, re, json
|
||
from io import BytesIO
|
||
import queue,time,uuid,os,array
|
||
from threading import Lock,Thread
|
||
from zhipuai import ZhipuAI
|
||
from openai import OpenAI
|
||
import openai
|
||
from datetime import datetime
|
||
|
||
|
||
# App upgrade check (cyx 2025-04-20).
# Dotted numeric version such as "1.0.4".
_APP_VERSION_PATTERN = re.compile(r'^\d+(\.\d+)*$')


def _version_key(version):
    """Convert a dotted numeric version ("1.0.4") into a comparable tuple of ints.

    Trailing zero components are dropped so that "1.0" and "1.0.0" compare equal.
    Returns None when *version* is not a str of the form ``N(.N)*``.
    """
    if not isinstance(version, str) or not _APP_VERSION_PATTERN.match(version):
        return None
    parts = [int(part) for part in version.split('.')]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


@manager.route('/get_apk_update/<version>', methods=['GET'])
@token_required
def get_apk_update_info(tenant_id, version):
    """Return every published app build whose version is >= the client's *version*.

    A malformed client version is treated as "1.0.0". Versions are compared
    component-wise, so "1.10.0" is newer than "1.9.9" and older than "2.0.0"
    (concatenating the digits, as before, ranked "1.10.0" (1100) above "2.0.0" (200)).
    Bookkeeping columns are removed from each row and ``upload_time`` is formatted
    as "YYYY-mm-dd HH:MM:SS".
    """
    try:
        client_key = _version_key(version) or _version_key("1.0.0")
        res = []
        for app_info in AppInfoService.get_all():
            val = app_info.to_dict()
            app_key = _version_key(val.get("app_version"))
            if app_key is None:
                # Skip rows without a usable version instead of failing the whole request.
                if val.get("app_version"):
                    logging.warning(f"skip app info with invalid version {val.get('app_version')!r}")
                continue
            if app_key < client_key:
                continue
            for field in ('create_time', 'create_date', 'update_time', 'update_date', 'description'):
                val.pop(field, None)
            upload_time = val.get('upload_time')
            if hasattr(upload_time, 'strftime'):
                val['upload_time'] = upload_time.strftime("%Y-%m-%d %H:%M:%S")
            res.append(val)
        return get_result(data=res)
    except Exception as e:
        return get_error_data_result(message=f"Get apk update info error {e}")
|
||
|
||
# Models the tenant has already added (cyx 2025-01-26).
# A parenthesised note in a model description, e.g. "(童声)" / "(成年)"; accepts ASCII
# and full-width parentheses, non-nested only.
_VOICE_NOTE_PATTERN = re.compile(r"[(\uFF08][^)\uFF09]*[)\uFF09]")


@manager.route('/get_llms', methods=['GET'])
@token_required
def my_llms(tenant_id):
    """List the tenant's configured models grouped by factory.

    Query args: ``type`` (optional exact model_type filter) and ``device_id``
    (only logged). Each factory maps to ``{"tags": ..., "llm": [...]}``.
    TTS models (2025-05-02) are listed only when their description holds exactly
    one parenthesised note; ``voice_type`` is "child" when that note contains
    "童" and "adult" otherwise.
    """
    model_type = request.args.get("type")
    device_id = request.args.get('device_id')
    logging.info(f"get llms {model_type} {device_id}")
    try:
        res = {}
        for o in TenantLLMService.get_my_llms(tenant_id):
            if model_type is not None and o["model_type"] != model_type:
                continue
            # Create the factory bucket even if the model itself is filtered out below.
            factory = res.setdefault(o["llm_factory"], {"tags": o["tags"], "llm": []})
            entry = {
                "type": o["model_type"],
                "name": o["llm_name"],
                "used_token": o["used_tokens"],
                "description": o["description"],
            }
            if o["model_type"].lower() == 'tts':
                # A None description used to raise TypeError here and fail the whole listing.
                matches = _VOICE_NOTE_PATTERN.findall(o["description"] or "")
                if len(matches) != 1:
                    continue
                entry["voice_type"] = "child" if "童" in matches[0] else "adult"
            factory["llm"].append(entry)
        return get_result(data=res)
    except Exception as e:
        return get_error_data_result(message=f"Get LLMS error {e}")
|
||
|
||
|
||
# Default comma-separated exhibit names; replaced at runtime through POST /mesum/set_antique.
# The photo-recognition endpoint below takes its candidate titles from get_labels_with_id instead.
main_antiquity = "浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"

# MIME subtype for each extension accepted by allowed_file(), used in the image data URL.
_IMAGE_MIME_SUBTYPES = {'png': 'png', 'jpg': 'jpeg', 'jpeg': 'jpeg', 'gif': 'gif'}

# Separators the vision model may put between several matched titles (ASCII / full-width semicolon).
_MATCH_SEPARATOR_RE = re.compile(r'[;\uFF1B]')


def _build_recognition_prompt(antiques_selected):
    """Return the system prompt asking the vision model to OCR the photo and pick matching candidate titles."""
    return (
        "作为图片识别和理解助手,您的任务是:"
        "\n1. 精确识别图片中的文字内容"
        "\n2. 理解文字语义"
        "\n3. 从以下候选标题中选择最佳匹配项:"
        f"\n [{antiques_selected}]"
        "\n\n### 输出要求:"
        "\n- 以严格JSON格式输出,包含3个字段:"
        "\n • `antique`: 匹配的标题(多个用英文分号';'分割,最多匹配3个,无匹配则空字符串)"
        "\n • `text`: 识别出的完整文字"
        "\n • `match_score`: 整体匹配度(0-1的浮点数),1=完全匹配"
        "\n\n### 匹配规则:"
        "\n1. 语义匹配优先于字面匹配"
        "\n2. 考虑同义词、近义词和描述性匹配"
        "\n3. 允许部分匹配(如'青铜酒器'匹配'青铜器')"
        "\n4. 若无明确匹配项,`antique`返回空字符串"
        "\n\n### 重要:"
        "\n- 输出必须是可直接解析的JSON,无任何前置/后置文本"
        "\n- 匹配度评分需客观反映文本与候选标题的相似度"
    )


def _attach_antique_ids(parsed_json_data, labels_with_id):
    """Resolve the title(s) in parsed_json_data['antique'] to exhibit ids; mutates and returns the dict.

    One title: sets ``id`` when the title equals a known label.
    Several titles: sets ``matchedArray`` to ``[{'label': ..., 'id': ...?}, ...]``.
    Titles are whitespace-stripped and empty pieces are ignored.
    """
    raw = parsed_json_data.get('antique') or ''
    titles = [piece.strip() for piece in _MATCH_SEPARATOR_RE.split(str(raw)) if piece.strip()]
    # Later duplicates win, like the linear scans this replaces.
    id_by_label = {item['label']: item.get('id') for item in labels_with_id}
    if len(titles) == 1:
        if titles[0] in id_by_label:
            parsed_json_data['id'] = id_by_label[titles[0]]
    elif len(titles) > 1:
        matched = []
        for title in titles:
            entry = {'label': title}
            if title in id_by_label:
                entry['id'] = id_by_label[title]
            matched.append(entry)
        parsed_json_data['matchedArray'] = matched
    return parsed_json_data


@manager.route('/photo/recongeText/<mesum_id>', methods=['POST'])
@token_required
def upload_file(tenant_id, mesum_id):
    """Recognise the text in an uploaded exhibit photo and match it to the museum's exhibit labels.

    Form field ``file`` must be a png/jpg/jpeg/gif image. It is sent to ``qwen-vl-max-latest``
    through DashScope's OpenAI-compatible API. The parsed JSON answer
    (``antique``/``text``/``match_score`` plus ``id`` or ``matchedArray``) is returned under
    ``data``, and the raw model reply under ``text``.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not allowed_file(file.filename):
        # Previously this case fell through and returned None, which Flask reports as a 500.
        return jsonify({'error': 'File type not allowed'}), 400

    labels_with_id = get_labels_with_id(mesum_id) if mesum_id else []
    antiques_selected = ','.join(item['label'] for item in labels_with_id)
    prompt = _build_recognition_prompt(antiques_selected)

    img_base = base64.b64encode(file.read()).decode('utf-8')
    mime_subtype = _IMAGE_MIME_SUBTYPES.get(file.filename.rsplit('.', 1)[1].lower(), 'png')
    logging.info(f"recevie photo file {file.filename} {request.content_length} 识别中.... ")

    # SECURITY: the key used to be hard-coded only. Configure DASHSCOPE_API_KEY; the literal
    # fallback just keeps existing deployments working. Remove it and rotate the key.
    api_key = os.environ.get("DASHSCOPE_API_KEY", "sk-a47a3fb5f4a94f66bbaf713779101c75")
    try:
        client = OpenAI(
            api_key=api_key,
            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        response = client.chat.completions.create(
            model="qwen-vl-max-latest",
            messages=[
                {"role": "system", "content": [{"type": "text", "text": prompt}]},
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/{mime_subtype};base64,{img_base}"},
                        }
                    ],
                },
            ],
        )
    except Exception as e:
        logging.exception("photo recognition request failed")
        return jsonify({'error': f'recognition service error: {e}'}), 502

    message = response.choices[0].message
    parsed_json_data = {"antique": "", "text": "", "match_score": 0}
    parsed_json_res = parse_markdown_json(message.content or '')
    if parsed_json_res.get('success') is True and isinstance(parsed_json_res.get('data'), dict):
        parsed_json_data = _attach_antique_ids(parsed_json_res['data'], labels_with_id)
    logging.info(f"{parsed_json_data}")
    return jsonify({'message': 'File uploaded successfully', 'text': message.content,
                    'data': parsed_json_data}), 200
|
||
|
||
def allowed_file(filename):
    """Return True when *filename* ends in a png/jpg/jpeg/gif extension (case-insensitive)."""
    _, separator, extension = filename.rpartition('.')
    return bool(separator) and extension.lower() in {'png', 'jpg', 'jpeg', 'gif'}
|
||
|
||
|
||
# get_all
@manager.route('/mesum/list', methods=['GET'])
@token_required
def mesum_list(tenant_id):
    """Return every museum overview record (MesumOverviewService.get_all) as a list of dicts."""
    try:
        res = [overview.to_dict() for overview in MesumOverviewService.get_all()]
        return get_result(data=res)
    except Exception as e:
        # The message used to read "Get LLMS error", copied from /get_llms.
        return get_error_data_result(message=f"Get mesum list error {e}")
|
||
|
||
@manager.route('/mesum/set_antique', methods=['POST'])
@token_required
def mesum_set_antique(tenant_id):
    """Replace the module-level default exhibit list ``main_antiquity`` with JSON field ``antique``.

    An absent or empty ``antique`` leaves the current value unchanged; the call still succeeds.
    NOTE(review): this is process-global, in-memory state shared by all tenants and not
    persisted; with several worker processes each keeps its own copy.
    """
    global main_antiquity
    try:
        req_data = request.get_json(silent=True)
        req_data_antique = req_data.get('antique', None) if isinstance(req_data, dict) else None
        if req_data_antique:
            main_antiquity = req_data_antique
        logging.info(f"main_antiquity: {main_antiquity}")
        # Pass the payload as data=, like every other endpoint in this module. Passed positionally,
        # it presumably bound to get_result's first parameter (the status code in upstream RAGFlow).
        return get_result(data={'statusCode': 200, 'code': 0, 'message': 'antique set successfully'})
    except Exception as e:
        return get_error_data_result(message=f"Set antique error {e}")
|
||
|
||
# Pending TTS requests keyed by audio_stream_id; values are the tts_info dicts registered
# by POST /chats/<chat_id>/tts and read by GET /chats/<chat_id>/tts/<audio_stream_id>.
audio_text_cache: dict = dict()
# Lock protecting audio_text_cache (request threads and the background cleaner share it).
cache_lock = Lock()
# Entries older than this many seconds are evicted by clean_audio_cache().
CACHE_EXPIRE_SECONDS = 10 * 60  # 10 minutes
|
||
# Full-width -> half-width punctuation translation table.
# Built from code points rather than a literal dict of look-alike characters: the old literal
# had duplicate keys (later entries silently overrode earlier ones) and is fragile under
# editors/tools that normalise full-width characters.
_FULL_TO_HALF_TABLE = {
    code_point: code_point - 0xFEE0
    # U+FF01..U+FF5E mirror ASCII U+0021..U+007E; convert punctuation only,
    # leaving full-width letters and digits untouched.
    for code_point in range(0xFF01, 0xFF5F)
    if not chr(code_point - 0xFEE0).isalnum()
}
_FULL_TO_HALF_TABLE.update({
    0xFF5F: 0x2985,    # full-width white left parenthesis  -> U+2985
    0xFF60: 0x2986,    # full-width white right parenthesis -> U+2986
    0xFF61: ord('.'),  # half-width ideographic full stop
    0xFF62: 0x300C,    # half-width left corner bracket  -> U+300C
    0xFF63: 0x300D,    # half-width right corner bracket -> U+300D
    0xFF64: ord(','),  # half-width ideographic comma
    0xFF65: ord('.'),  # half-width katakana middle dot
    0xFF70: ord('-'),  # half-width katakana prolonged sound mark
    0x3002: ord('.'),  # ideographic full stop
})
# NOTE(review): the CJK extras above reconstruct the intent of the original table; the
# ideographic comma U+3001 and middle dot U+30FB are deliberately left unchanged, as in the
# effective (last-key-wins) original mapping — confirm this is what the TTS pipeline wants.


def fullwidth_to_halfwidth(s):
    """Return *s* with full-width (and a few half-width CJK) punctuation replaced by ASCII forms.

    Letters, digits and all other characters are returned unchanged.
    """
    return s.translate(_FULL_TO_HALF_TABLE)
|
||
|
||
# Runs of whitespace and ASCII punctuation (plus em dash) that separate words.
_SPLIT_PUNCTUATION_RE = re.compile(r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+')


def split_text_at_punctuation(text, chunk_size=100):
    """Split *text* into words at whitespace/punctuation and pack them into space-joined chunks.

    Punctuation is discarded. A chunk never exceeds *chunk_size* characters unless a single
    word is longer than *chunk_size*, in which case that word becomes its own chunk.
    Returns an empty list for text with no words.
    """
    tokens = [token for token in _SPLIT_PUNCTUATION_RE.split(text) if token]

    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += token + ' '
        else:
            # Guard: an oversized first token used to emit an empty '' chunk here.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
|
||
|
||
def extract_text_from_markdown(markdown_text):
    """Strip common Markdown/HTML markup from *markdown_text* and return single-spaced plain text.

    The following are removed entirely: fenced and inline code, ATX heading lines, images,
    links (including their text) and HTML tags. The words inside bold/italic markers are kept.
    Full-width punctuation is converted via fullwidth_to_halfwidth, and whitespace runs are
    collapsed to a single space.
    """
    text = markdown_text
    # Fenced code blocks first: the inline-code pattern would otherwise consume their back-ticks.
    text = re.sub(r'```[\s\S]*?```', '', text)
    text = re.sub(r'`[^`]+`', '', text)
    # ATX heading lines, one line at a time. The old pattern `#\s*[^#]+` crossed newlines
    # and deleted everything after the first '#'.
    text = re.sub(r'^[ \t]{0,3}#{1,6}(?:[ \t].*)?$', '', text, flags=re.MULTILINE)
    # Images before links; otherwise the link pattern leaves the image's leading '!'.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # Bold/italic: drop the markers, keep the emphasised words (they used to be deleted too).
    text = re.sub(r'(\*{1,3})(?=\S)(.+?)(?<=\S)\1', r'\2', text)
    # Underscore emphasis only at word boundaries, so snake_case identifiers survive.
    text = re.sub(r'(?<!\w)(_{1,3})(?=\S)(.+?)(?<=\S)\1(?!\w)', r'\2', text)
    text = re.sub(r'<[^>]+>', '', text)
    text = fullwidth_to_halfwidth(text)
    return re.sub(r'\s+', ' ', text).strip()
|
||
|
||
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress *original_data* and return it as a Base64 ``str``.

    The Base64 text contains no line breaks, which matches Android's ``Base64.NO_WRAP``.
    """
    buffer = BytesIO()
    with buffer:
        compressor = gzip.GzipFile(fileobj=buffer, mode='wb')
        with compressor:
            compressor.write(original_data)
        # Read the bytes only after the GzipFile is closed, so the gzip trailer is included.
        payload = buffer.getvalue()
    encoded = base64.b64encode(payload)
    return encoded.decode('utf-8')
|
||
|
||
def clean_audio_cache():
    """Evict audio_text_cache entries older than CACHE_EXPIRE_SECONDS and close their audio buffers.

    The whole sweep runs while holding cache_lock.
    """
    with cache_lock:
        now = time.time()
        stale_ids = [
            stream_id
            for stream_id, info in audio_text_cache.items()
            if now - info['created_at'] > CACHE_EXPIRE_SECONDS
        ]
        for stream_id in stale_ids:
            removed = audio_text_cache.pop(stream_id, None)
            buffer = removed.get('audio_stream') if removed else None
            if buffer:
                buffer.close()
|
||
|
||
|
||
def start_background_cleaner(interval_seconds=180):
    """Start a daemon thread that calls clean_audio_cache() every *interval_seconds* (default: 3 minutes).

    Returns the started Thread.
    """

    def cleaner_loop():
        while True:
            time.sleep(interval_seconds)
            try:
                clean_audio_cache()
            except Exception:
                # An uncaught error used to end this thread silently, after which
                # the cache was never cleaned again.
                logging.exception("audio cache cleanup failed")

    cleaner_thread = Thread(target=cleaner_loop, daemon=True, name="audio-cache-cleaner")
    cleaner_thread.start()
    return cleaner_thread


# Start the cleaner thread when this module is loaded.
start_background_cleaner()
|
||
|
||
# While the LLM dialog text is returned, its TTS audio is generated alongside it and buffered
# by the StreamSessionManager (stream_manager) from dialog_service. Session ids come from
# stream_manager.create_session(tts_model, sample_rate=8000, stream_format='mp3'),
# presumably as str(uuid.uuid4()).
@manager.route('/tts_stream/<session_id>', methods=['GET'])
def tts_stream(session_id):
    """Stream the TTS audio buffered by stream_manager for *session_id*.

    wav audio is sent as gzip+Base64 text lines ending in CRLF; other formats are sent as raw
    bytes. A str chunk starting with "ERROR" ends the stream after a ``data:{"error": ...}``
    line. The session is closed via stream_manager.close_session when streaming ends.

    NOTE(review): this route has no @token_required; knowing the session id is sufficient.
    """
    session = stream_manager.sessions.get(session_id)
    logging.info(f"--tts_stream {session}")
    if session is None:
        return get_error_data_result(message="Audio stream not found or expired.")

    def generate():
        total_audio_stream_length = 0
        finished_event = session['finished']
        try:
            while session['active']:
                try:
                    # Block briefly instead of busy-polling get_nowait(), which pinned a CPU core.
                    chunk = session['buffer'].get(timeout=0.1)
                except queue.Empty:
                    # Stop only when the producer has finished AND the buffer is drained.
                    # The old `while not finished` loop dropped chunks still queued at that moment.
                    if finished_event.is_set():
                        break
                    continue

                if isinstance(chunk, str) and chunk.startswith("ERROR"):
                    logging.info(f"---tts stream error!!!! {chunk}")
                    # json.dumps yields valid JSON (the previous payload used single quotes).
                    yield "data:" + json.dumps({'error': chunk[6:]}, ensure_ascii=False) + "\n\n"
                    break

                try:
                    if session['stream_format'] == "wav":
                        payload = encode_gzip_base64(chunk) + "\r\n"
                    else:
                        total_audio_stream_length += len(chunk)
                        payload = chunk
                except Exception as e:
                    # Skip a chunk that cannot be encoded rather than aborting the stream.
                    logging.info(f"tts streag get error2 {e} ")
                    continue
                yield payload
        finally:
            logging.debug(f"tts_stream {session_id} done, {total_audio_stream_length} raw audio bytes sent")
            stream_manager.close_session(session_id)

    mimetype = "audio/wav" if session['stream_format'] == "wav" else "audio/mpeg"
    resp = Response(stream_with_context(generate()), mimetype=mimetype)
    # Streaming-friendly headers; X-Accel-Buffering disables nginx response buffering.
    resp.headers.add_header("Accept-Ranges", "bytes")
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp
|
||
|
||
def generate_mp3_header(bitrate_kbps=128, padding=0):
    """Build a 4-byte MPEG-1 Layer III frame header (44.1 kHz, mono, no CRC).

    Args:
        bitrate_kbps: an MPEG-1 Layer III bitrate (32, 40, 48, 56, 64, 80, 96, 112,
            128, 160, 192, 224, 256 or 320).
        padding: padding bit; only its lowest bit is used.

    Returns:
        The header as 4 big-endian bytes (FF FB 90 C0 for the defaults).

    Raises:
        KeyError: if *bitrate_kbps* is not in the table above.
    """
    sync = 0b11111111111   # frame sync: 11 bits, all set
    version = 0b11         # MPEG-1 (2 bits)
    layer = 0b01           # Layer III (2 bits)
    # Protection bit: 1 means "not protected, no CRC follows". The previous value 0
    # announced a 16-bit CRC that is never written, contradicting the "no CRC" intent.
    protection = 0b1
    bitrate_index = {      # MPEG-1 Layer III bitrate index (4 bits)
        32: 0b0001, 40: 0b0010, 48: 0b0011, 56: 0b0100,
        64: 0b0101, 80: 0b0110, 96: 0b0111, 112: 0b1000,
        128: 0b1001, 160: 0b1010, 192: 0b1011, 224: 0b1100,
        256: 0b1101, 320: 0b1110,
    }[bitrate_kbps]
    sampling_rate = 0b00   # 44.1 kHz for MPEG-1 (2 bits)
    padding_bit = padding & 0b1  # mask so out-of-range values cannot corrupt neighbouring fields
    private = 0b0          # private bit
    mode = 0b11            # single channel (2 bits)
    mode_ext = 0b00        # mode extension (2 bits)
    copyright_bit = 0b0    # not copyrighted
    original = 0b0         # copy, not original
    emphasis = 0b00        # no emphasis (2 bits)

    header = (
        (sync << 21)
        | (version << 19)
        | (layer << 17)
        | (protection << 16)
        | (bitrate_index << 12)
        | (sampling_rate << 10)
        | (padding_bit << 9)
        | (private << 8)
        | (mode << 6)
        | (mode_ext << 4)
        | (copyright_bit << 3)
        | (original << 2)
        | emphasis
    )
    return header.to_bytes(4, byteorder='big')
|
||
|
||
@manager.route('/chats/<chat_id>/audio/pcm_mp3', methods=['GET'])
def audio_test_mp3_stream(chat_id):
    """Test endpoint: stream the bundled ``api/apps/sdk/test.mp3`` in 1 KiB pieces.

    Streaming stops at end of file, or as soon as the running read counter first
    exceeds 240000 bytes, whichever happens first.
    """
    logging.info(f"--audio_test_mp3_stream--{chat_id}")
    mp3_path = os.path.join(get_project_base_directory(), "api", "apps/sdk/test.mp3")
    # Raises FileNotFoundError before any streaming response is started.
    os.path.getsize(mp3_path)
    piece_size = 1024

    def generate_test_mp3():
        window = 0
        streamed = 0
        with open(mp3_path, 'rb') as source:
            piece = source.read(piece_size)
            window += piece_size
            while piece:
                yield piece
                piece = source.read(piece_size)
                window += piece_size
                if window <= 240000:
                    continue
                logging.info(f"sleep 5s {window}")
                streamed += window
                window = 0
                if streamed > 130000:
                    print("end")
                    break

    return Response(generate_test_mp3(), mimetype="audio/mpeg", headers={'Accept-Ranges': 'bytes'})
|
||
|
||
|
||
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
    """Serve the audio for a TTS request registered by POST /chats/<chat_id>/tts.

    Pre-generated requests (an ``audio_stream`` buffer is cached) are streamed from memory.
    The cache entry is dropped once the buffer has been fully sent.
    Delayed requests (``audio_stream`` is None) are synthesised on the fly. For mp3 a
    generate_mp3_header() prefix is sent first; wav chunks are sent as gzip+Base64 CRLF lines.

    NOTE(review): no @token_required; the audio_stream_id acts as the access token. Two
    concurrent GETs for the same id would share one BytesIO position.
    """
    with cache_lock:
        tts_info = audio_text_cache.get(audio_stream_id)
    try:
        if not tts_info:
            return get_error_data_result(message="Audio stream not found or expired.")
        audio_stream = tts_info.get('audio_stream')
        tenant_id = tts_info.get('tenant_id')
        chat_id = tts_info.get('chat_id')  # the chat id recorded at POST time is authoritative
        text = tts_info.get('text', "..")
        model_name = tts_info.get('model_name')
        sample_rate = tts_info.get('tts_sample_rate', 8000)  # default 8 kHz
        stream_format = tts_info.get('tts_stream_format', 'mp3')

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name or dia.tts_id

        if audio_stream is not None:
            # Measure the real buffer size. tts_info['chunk_size'] could disagree with the body
            # (e.g. 0 for the blank-text placeholder), producing an invalid Content-Length.
            audio_stream.seek(0, io.SEEK_END)
            content_length = audio_stream.tell()
            audio_stream.seek(0)

            def generate():
                data = audio_stream.read(1024)
                while data:
                    yield data
                    data = audio_stream.read(1024)
                logging.info("audio stream end")
                # Locked pop: a bare `del` raced the cleaner and raised KeyError if it got there first.
                with cache_lock:
                    audio_text_cache.pop(audio_stream_id, None)

            if stream_format == 'wav':
                resp = Response(generate(), mimetype="audio/wav")
            else:
                resp = Response(generate(), headers={
                    'Content-Type': 'audio/mpeg',
                    'Content-Length': str(content_length),
                })
        else:
            # Only the delayed mode needs a TTS model.
            tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

            def stream_audio():
                if stream_format == 'mp3':
                    yield generate_mp3_header()
                try:
                    for chunk in tts_mdl.tts(text, sample_rate=sample_rate, stream_format=stream_format):
                        if stream_format == 'wav':
                            yield encode_gzip_base64(chunk) + "\r\n"
                        else:
                            yield chunk
                except Exception as e:
                    yield ("data:" + json.dumps({"code": 500, "message": str(e),
                                                 "data": {"answer": "**ERROR**: " + str(e)}},
                                                ensure_ascii=False)).encode('utf-8')

            mimetype = "audio/wav" if stream_format == 'wav' else "audio/mpeg"
            resp = Response(stream_audio(), mimetype=mimetype)

        # Assign rather than add_header so Accept-Ranges is not emitted twice.
        resp.headers["Accept-Ranges"] = "bytes"
        resp.headers["Cache-Control"] = "no-cache"
        resp.headers["Connection"] = "keep-alive"
        resp.headers["X-Accel-Buffering"] = "no"
        return resp
    except Exception as e:
        logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
        return get_error_data_result(message="音频流传输失败")
|
||
|
||
|
||
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
    """Register a TTS request and return the URL from which its audio can be fetched.

    JSON body: ``text`` (required), ``model_name`` (overrides the dialog's tts_id, e.g.
    "cosyvoice-v1/longyuan@Tongyi-Qianwen"), ``audio_stream_id`` (generated when absent),
    ``delay_gen_audio``, ``tts_stream_format`` (default "mp3") and ``tts_sample_rate``
    (default 8000). Unless ``delay_gen_audio`` is truthy, the audio is synthesised now into
    an in-memory buffer. Otherwise it is synthesised when the returned URL is requested.
    """
    try:
        req = request.get_json(silent=True)
        if not isinstance(req, dict):
            req = {}
        text = req.get('text')
        if not text:
            return get_error_data_result(message="Please input your question.")
        delay_gen_audio = req.get('delay_gen_audio', False)
        model_name = req.get('model_name')
        audio_stream_id = req.get('audio_stream_id') or str(uuid.uuid4())

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            # Used to surface as an AttributeError -> generic "server error".
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name or dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

        tts_stream_format = req.get('tts_stream_format', "mp3")
        tts_sample_rate = req.get('tts_sample_rate', 8000)
        logging.info(f"tts post {tts_sample_rate} {tts_stream_format}")

        audio_stream = None
        total_chunk_size = 0
        # One truthiness test decides both "allocate" and "fill". Before, a falsy non-False
        # value (0, "") created a buffer that was never filled.
        if not delay_gen_audio:
            audio_stream = io.BytesIO()
            try:
                if text.strip() == "":
                    # Whitespace-only text: store a 100-byte zero placeholder instead of calling TTS.
                    audio_stream.write(b'\x00' * 100)
                else:
                    for chunk in tts_mdl.tts(text, sample_rate=tts_sample_rate, stream_format=tts_stream_format):
                        audio_stream.write(chunk)
                # Record the real size (the placeholder used to be counted as 0 bytes).
                total_chunk_size = audio_stream.tell()
                audio_stream.seek(0)
            except Exception as e:
                logging.info(f"--error:{e}")
                with cache_lock:
                    audio_text_cache.pop(audio_stream_id, None)
                return get_error_data_result(message="get tts audio stream error.")

        tts_info = {
            'text': text,
            'tenant_id': tenant_id,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': audio_stream,
            'model_name': model_name,
            'delay_gen_audio': delay_gen_audio,
            'audio_stream_id': audio_stream_id,
            'tts_sample_rate': tts_sample_rate,
            'tts_stream_format': tts_stream_format,
            'chunk_size': total_chunk_size,
        }
        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info

        audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
        logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url} "
                     f"{tts_sample_rate} {tts_stream_format}")
        return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id,
                        "sample_rate": tts_sample_rate, "stream_format": tts_stream_format,
                        "ws_url": audio_stream_url})
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}", exc_info=True)
        return get_error_data_result(message="服务器内部错误")
|
||
|
||
def get_antique_categories(mesum_id):
    """Return the antique categories of museum *mesum_id* (MesumAntiqueService.get_all_categories)."""
    return MesumAntiqueService.get_all_categories(mesum_id)
|
||
|
||
|
||
def get_labels_ext(mesum_id):
    """Return the extended antique labels of museum *mesum_id* (MesumAntiqueService.get_labels_ext)."""
    return MesumAntiqueService.get_labels_ext(mesum_id)
|
||
|
||
|
||
def get_labels_with_id(mesum_id):
    """Return the antique labels of museum *mesum_id* with their ids (MesumAntiqueService.get_labels_with_id)."""
    return MesumAntiqueService.get_labels_with_id(mesum_id)
|
||
|
||
|
||
def get_antique_labels(mesum_id):
    """Return antique labels from MesumAntiqueService.get_all_labels().

    NOTE(review): *mesum_id* is accepted but not forwarded, so the labels are presumably not
    filtered by museum; confirm whether get_all_labels supports such a filter.
    """
    return MesumAntiqueService.get_all_labels()
|
||
|
||
|
||
def get_all_antiques(mesum_id):
    """Return every antique of museum *mesum_id* as a list of dicts (record.to_dict())."""
    return [antique.to_dict() for antique in MesumAntiqueService.get_by_mesum_id(mesum_id)]
|
||
|
||
|
||
@manager.route('/mesum/antique/<mesum_id>', methods=['GET'])
def mesum_antique_get(mesum_id):
    """Return all antiques, categories and labels of museum *mesum_id*.

    The misspelt key "anqituqes" is part of the response contract and is kept as-is.
    """
    try:
        antiques = get_all_antiques(mesum_id)
        categories = get_antique_categories(mesum_id)
        labels = get_antique_labels(mesum_id)
        return get_result(data={"anqituqes": antiques, "categories": categories, "labels": labels})
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
|
||
|
||
# Exhibit list (categories + extended labels) of the museum identified by mesum_id.
@manager.route('/mesum/antique_brief/<mesum_id>', methods=['GET'])
@token_required
def mesum_antique_get_brief(tenant_id, mesum_id):
    """Return the categories and extended labels of museum *mesum_id*."""
    try:
        categories = get_antique_categories(mesum_id)
        labels = get_labels_ext(mesum_id)
        return get_result(data={"categories": categories, "labels": labels})
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
|
||
|
||
@manager.route('/mesum/antique_detail/<mesum_id>/<antique_id>', methods=['GET'])
@token_required
def mesum_antique_get_full(tenant_id, mesum_id, antique_id):
    """Return the full record of antique *antique_id* in museum *mesum_id*.

    When the record has ``ttsUrl_adult``, ``ttsUrl`` is set to it, so clients get the
    adult-voice pre-generated TTS file by default.
    NOTE(review): the front end may need a different variant (e.g. a child voice) — confirm.
    """
    try:
        logging.info(f"mesum_antique_get_full {mesum_id} {antique_id}")
        antique_detail = MesumAntiqueService.get_antique_by_id(mesum_id, antique_id)
        if antique_detail is None:
            # Previously surfaced as "'NoneType' object has no attribute 'get'".
            return get_error_data_result(message=f"Get mesum antique error: antique {antique_id} not found")
        if antique_detail.get('ttsUrl_adult'):
            antique_detail['ttsUrl'] = antique_detail['ttsUrl_adult']
        return get_result(data=antique_detail)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
|
||
|
||
@manager.route('/mesum/antique/<action>/<mesum_id>/<antique_id>', methods=['POST'])
@token_required
def mesum_antique_action(tenant_id, action, mesum_id, antique_id):
    """Delete, update or insert an antique record.

    ``rm`` deletes *antique_id* and needs no body. ``update`` applies the JSON object body
    to *antique_id*. ``insert`` creates a record from the JSON object body (*mesum_id* and
    *antique_id* from the URL are not used). Unsupported actions return a success envelope
    carrying an ``error`` field, which is kept for client compatibility.
    """
    action_name = action.lower()
    try:
        logging.info(f"mesum_antique_action {action} {mesum_id} {antique_id}")
        if action_name not in ('rm', 'update', 'insert'):
            return get_result(data={"error": f"{action} not supported"})

        if action_name == "rm":
            res = MesumAntiqueService.delete_by_id(antique_id)
            return get_result(data={"rm": res})

        # Read the body only for actions that need it. On recent Flask versions an eager
        # request.json raises for body-less (e.g. rm) requests.
        record = request.get_json(silent=True)
        if not isinstance(record, dict):
            return get_error_data_result(message=f"antique {action} error: a JSON object body is required")
        # NOTE(review): the client dict is passed straight to the service (mass assignment);
        # consider whitelisting the writable columns.
        logging.info(f"mesum_antique_action {action} {record}")

        if action_name == "update":
            res = MesumAntiqueService.update_by_id(antique_id, record)
            logging.info(f"mesum_antique_action return {action} {record}")
            return get_result(data={"update": res})

        res = MesumAntiqueService.insert(**record)
        return get_result(data={"insert": res})
    except Exception as e:
        return get_error_data_result(message=f"antique {action} error {e}")
|
||
|
||
# MinIO object-storage endpoints (added 2025-04-28); all share this client.
minio_client = RAGFlowMinio()


@manager.route('/minio/check', methods=['POST'])
@token_required
def minio_check_obj(tenant_id):
    """Report whether object ``file_name`` exists in ``bucket`` (JSON body fields).

    The response key "is_exits" is misspelt but kept for client compatibility.
    """
    req = request.json
    try:
        bucket = req.get('bucket')
        file_name = req.get('file_name')
        exists = minio_client.obj_exist(bucket, file_name)
        return get_result(data={"is_exits": exists})
    except Exception as e:
        return get_error_data_result(message=f"minio check object error {e}")
|
||
|
||
@manager.route('/minio/get/<bucket>/<file_name>', methods=['GET'])
@token_required
def minio_get_obj(tenant_id, bucket, file_name):
    """Fetch object *file_name* from *bucket*; binary content is returned Base64-encoded in ``data.binary``.

    jsonify cannot serialise raw bytes, so returning them directly turned every successful
    fetch into an error response.
    NOTE(review): assumes RAGFlowMinio.get returns bytes, or None when the object is missing — confirm.
    """
    try:
        res = minio_client.get(bucket, file_name)
        if res is None:
            return get_error_data_result(message=f"minio get object error {bucket}/{file_name} not found")
        if isinstance(res, (bytes, bytearray)):
            res = base64.b64encode(res).decode('ascii')
        return get_result(data={"binary": res})
    except Exception as e:
        return get_error_data_result(message=f"minio get object error {e}")
|
||
|
||
@manager.route('/minio/rm', methods=['POST'])
@token_required
def minio_rm_obj(tenant_id):
    """Delete object ``file_name`` from ``bucket`` (JSON body fields).

    On success, ``data.rm`` echoes the bucket and file name concatenated without a separator.
    """
    req = request.json
    try:
        bucket = req.get('bucket')
        file_name = req.get('file_name')
        minio_client.rm(bucket, file_name)
        return get_result(data={"rm": f"{bucket}{file_name}"})
    except Exception as e:
        return get_error_data_result(message=f"minio rm object error {e}")
|
||
|
||
# Public base URL under which uploaded objects are reachable. It used to be hard-coded;
# set MINIO_PUBLIC_URL to override (the default keeps the previous address).
_MINIO_PUBLIC_URL = os.environ.get("MINIO_PUBLIC_URL", "http://1.13.185.116:9000").rstrip('/')


@manager.route('/minio/put', methods=['POST'])
@token_required
def minio_put_obj(tenant_id):
    """Upload form file ``file`` to ``bucket`` as ``file_name``; returns the put result and public URL.

    ``file_name`` defaults to the uploaded file's own name when the form omits it.
    """
    req = request.form
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    try:
        file = request.files['file']
        bucket = req.get('bucket')
        file_name = req.get('file_name') or file.filename
        if not bucket or not file_name:
            return get_error_data_result(message="minio put object error: bucket and file_name are required")
        binary = file.read()
        res = minio_client.put(bucket, file_name, binary)
        return get_result(data={"put": f"{res}", 'url': f"{_MINIO_PUBLIC_URL}/{bucket}/{file_name}"})
    except Exception as e:
        return get_error_data_result(message=f"minio put object error {e}")
|
||
|
||
@manager.route('/minio/list/<bucket>', methods=['GET'], defaults={'prefix': ''})
@manager.route('/minio/list/<bucket>/<prefix>', methods=['GET'])
@token_required
def list_objects(tenant_id, bucket: str, prefix: str = "", recursive: bool = True):
    """List objects in *bucket* whose names start with *prefix*.

    ``/minio/list/<bucket>`` lists the whole bucket; the route never supplies *recursive*,
    so it keeps its default of True (now honoured instead of being ignored).
    """
    try:
        result = minio_client.list_objects(bucket, prefix, recursive)
        return get_result(data=result)
    except Exception as e:
        return get_error_data_result(message=f"minio list objects error {e}")
|
||
|
||
# ------------------------------------------------
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the first *fade_length* samples of 16-bit PCM *audio_data*.

    Sample i (0-based) is scaled by i / fade_length and truncated toward zero. Audio shorter
    than *fade_length* gets the corresponding prefix of the ramp; it used to raise IndexError.
    A trailing odd byte is passed through unchanged; array('h', ...) used to raise ValueError.
    Samples use the machine's native byte order (array module).
    NOTE(review): assumes native order matches the PCM (little-endian on x86/ARM) — confirm.
    """
    usable = len(audio_data) - len(audio_data) % 2
    samples = array.array('h', audio_data[:usable])
    for i in range(min(fade_length, len(samples))):
        samples[i] = int(samples[i] * (i / fade_length))
    return samples.tobytes() + bytes(audio_data[usable:])
|
||
|
||
|
||
# A fenced code block, optionally tagged "json" (any case); captures the fence body.
_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*(.*?)\s*```', re.DOTALL | re.IGNORECASE)


def parse_markdown_json(json_string):
    """Parse JSON from an LLM reply, either wrapped in a Markdown code fence or given bare.

    Returns ``{'success': True, 'data': <parsed value>}`` on success. On failure it returns
    ``{'success': False, 'data': <message>}``. The message is the JSONDecodeError text when
    a fence was found but its body is invalid JSON. Otherwise it is
    'not a valid markdown json string'.
    """
    if not isinstance(json_string, str) or not json_string.strip():
        return {'success': False, 'data': 'not a valid markdown json string'}

    match = _JSON_FENCE_RE.search(json_string)
    # Fall back to the whole reply: prompts often ask for bare JSON without a fence.
    candidate = match.group(1) if match else json_string.strip()
    try:
        return {'success': True, 'data': json.loads(candidate)}
    except json.JSONDecodeError as e:
        if match:
            return {'success': False, 'data': str(e)}
        return {'success': False, 'data': 'not a valid markdown json string'}