增加加了博物馆展品清单数据库及对前端获取展品清单、展品详细的接口,增加了QWenOmni多模态大模型的支 持(主要为了测试),增加了本地部署大模型支持(主要为了测试,在autoDL上),修正了TTS生成和返回前端的逻辑与参数,增加了判断用户问题有没有在知识库中检索到相关片段、如果没有则直接返回并提示未包含
This commit is contained in:
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request , Response, jsonify
|
||||
from flask import request , Response, jsonify,stream_with_context
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db import StatusEnum
|
||||
@@ -23,12 +23,15 @@ from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.brief_service import MesumOverviewService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.antique_service import MesumAntiqueService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
import logging
|
||||
import base64
|
||||
import queue,time,uuid
|
||||
import base64, gzip
|
||||
from io import BytesIO
|
||||
import queue,time,uuid,os,array
|
||||
from threading import Lock,Thread
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
@@ -59,12 +62,37 @@ def my_llms(tenant_id):
|
||||
|
||||
|
||||
main_antiquity="浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"
|
||||
@manager.route('/photo/recongeText', methods=['POST'])
|
||||
@manager.route('/photo/recongeText/<mesum_id>', methods=['POST'])
|
||||
@token_required
|
||||
def upload_file(tenant_id):
|
||||
def upload_file(tenant_id,mesum_id):
|
||||
if 'file' not in request.files:
|
||||
return jsonify({'error': 'No file part'}), 400
|
||||
antiques_selected = ""
|
||||
if mesum_id:
|
||||
"""
|
||||
e,mesum_breif = MesumOverviewService.get_by_id(mesum_id)
|
||||
if not e:
|
||||
logging.info(f"没有找到匹配的博物馆信息,mesum_id={mesum_id}")
|
||||
else:
|
||||
antiques_selected =f"结果从:{mesum_breif.antique} 中进行选择"
|
||||
"""
|
||||
mesum_id_str = str(mesum_id)
|
||||
antique_labels=get_antique_labels(mesum_id)
|
||||
# 使用列表推导式和str()函数将所有元素转换为字符串
|
||||
string_elements = [str(element) for element in antique_labels]
|
||||
# 使用join()方法将字符串元素连接起来,以逗号为分隔符
|
||||
joined_string = ','.join(string_elements)
|
||||
antiques_selected = f"结果从:{joined_string} 中进行选择"
|
||||
|
||||
logging.info(f"{mesum_id} {joined_string}")
|
||||
prompt = (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
|
||||
f"请识别这个图片中文字,重点识别出含在文字中的某一文物标题、某一个历史事件或某一历史人物,"
|
||||
f"你的回答有2个结果,第一个结果是是从文字中识别出历史文物、历史事件、历史人物,"
|
||||
f"此回答时只给出匹配的文物、事件、人物,不需要其他多余的文字,{antiques_selected}"
|
||||
f",第二个结果是原始识别的所有文字"
|
||||
"2个结果输出以{ }的json格式给出,匹配文物、事件、人物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
|
||||
f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
if file.filename == '':
|
||||
@@ -92,14 +120,7 @@ def upload_file(tenant_id):
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
|
||||
f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
|
||||
f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
|
||||
f"你的回答有2个结果,第一个结果是是从文字进行分析出匹配文物,候选文物只能如下:{req_antique},"
|
||||
f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
|
||||
f",第二个结果是原始识别的所有文字"
|
||||
"2个结果输出以{ }的json格式给出,匹配文物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
|
||||
f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")
|
||||
"text": prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -213,6 +234,17 @@ def extract_text_from_markdown(markdown_text):
|
||||
|
||||
return text
|
||||
|
||||
def encode_gzip_base64(original_data: bytes) -> str:
|
||||
"""核心编码过程:二进制数据 → Gzip压缩 → Base64编码"""
|
||||
# Step 1: Gzip 压缩
|
||||
with BytesIO() as buf:
|
||||
with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file:
|
||||
gz_file.write(original_data)
|
||||
compressed_bytes = buf.getvalue()
|
||||
|
||||
# Step 2: Base64 编码(配置与Android端匹配)
|
||||
return base64.b64encode(compressed_bytes).decode('utf-8') # 默认不带换行符(等同于Android的Base64.NO_WRAP)
|
||||
|
||||
def clean_audio_cache():
|
||||
"""定时清理过期缓存"""
|
||||
with cache_lock:
|
||||
@@ -241,38 +273,55 @@ def start_background_cleaner():
|
||||
# 应用启动时启动清理线程
|
||||
start_background_cleaner()
|
||||
|
||||
@manager.route('/tts_stream/<session_id>')
|
||||
@manager.route('/tts_stream/<session_id>',methods=['GET'])
|
||||
def tts_stream(session_id):
|
||||
session = stream_manager.sessions.get(session_id)
|
||||
def generate():
|
||||
retry_count = 0
|
||||
session = None
|
||||
count = 0;
|
||||
path = os.path.join(get_project_base_directory(), "api", "apps/sdk/test.mp3")
|
||||
fmp3 =open(path, 'rb')
|
||||
finished_event = session['finished']
|
||||
try:
|
||||
while retry_count < 1:
|
||||
session = stream_manager.sessions.get(session_id)
|
||||
while not finished_event.is_set() :
|
||||
if not session or not session['active']:
|
||||
break
|
||||
try:
|
||||
chunk = session['buffer'].get(timeout=5) # 30秒超时
|
||||
chunk = session['buffer'].get_nowait() #
|
||||
count = count + 1
|
||||
if isinstance(chunk, str) and chunk.startswith("ERROR"):
|
||||
logging.info("---tts stream error!!!!")
|
||||
logging.info(f"---tts stream error!!!! {chunk}")
|
||||
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
|
||||
break
|
||||
yield chunk
|
||||
if session['stream_format'] == "wav":
|
||||
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
|
||||
yield gzip_base64_data
|
||||
else:
|
||||
yield chunk
|
||||
retry_count = 0 # 成功收到数据重置重试计数器
|
||||
except queue.Empty:
|
||||
retry_count += 1
|
||||
yield b'' # 保持连接
|
||||
if session['stream_format'] == "wav":
|
||||
# yield encode_gzip_base64(b'\x03\x04' * 1) + "\r\n"
|
||||
pass
|
||||
else:
|
||||
yield b'' # 保持连接
|
||||
#data = fmp3.read(1024)
|
||||
#yield data
|
||||
except Exception as e:
|
||||
logging.info(f"tts streag get error2 {e} ")
|
||||
|
||||
|
||||
finally:
|
||||
# 确保流结束后关闭会话
|
||||
if session:
|
||||
# 延迟关闭会话,确保所有数据已发送
|
||||
time.sleep(5) # 等待5秒确保流结束
|
||||
stream_manager.close_session(session_id)
|
||||
logging.info(f"Session {session_id} closed.")
|
||||
# 关键响应头设置
|
||||
|
||||
resp = Response(generate(), mimetype="audio/mpeg")
|
||||
if session['stream_format'] == "wav":
|
||||
resp = Response(stream_with_context(generate()), mimetype="audio/mpeg")
|
||||
else:
|
||||
resp = Response(stream_with_context(generate()), mimetype="audio/wav")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
@@ -291,17 +340,23 @@ def dialog_tts_get(chat_id, audio_stream_id):
|
||||
chat_id = req.get('chat_id')
|
||||
text = req.get('text', "..")
|
||||
model_name = req.get('model_name')
|
||||
sample_rate = req.get('tts_sample_rate',8000) # 默认8K
|
||||
stream_format = req.get('tts_stream_format','mp3')
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
logging.info(f"dialog_tts_get {sample_rate} {stream_format}")
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
yield chunk
|
||||
for chunk in tts_mdl.tts(text,sample_rate=sample_rate,stream_format=stream_format):
|
||||
if stream_format =='wav':
|
||||
yield encode_gzip_base64(chunk) + "\r\n"
|
||||
else:
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
||||
@@ -318,7 +373,10 @@ def dialog_tts_get(chat_id, audio_stream_id):
|
||||
audio_stream.seek(0)
|
||||
resp = Response(generate(), mimetype="audio/mpeg")
|
||||
else:
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
if stream_format == 'wav':
|
||||
resp = Response(stream_audio(), mimetype="audio/wav")
|
||||
else:
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
@@ -328,20 +386,19 @@ def dialog_tts_get(chat_id, audio_stream_id):
|
||||
return get_error_data_result(message="音频流传输失败")
|
||||
finally:
|
||||
# 确保资源释放
|
||||
if tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
|
||||
if tts_info and tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
|
||||
tts_info['audio_stream'].close()
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
|
||||
@token_required
|
||||
def dialog_tts_post(tenant_id, chat_id):
|
||||
req = request.json
|
||||
try:
|
||||
req = request.json
|
||||
if not req.get("text"):
|
||||
return get_error_data_result(message="Please input your question.")
|
||||
delay_gen_audio = req.get('delay_gen_audio', False)
|
||||
# text = extract_text_from_markdown(req.get('text'))
|
||||
text = req.get('text')
|
||||
delay_gen_audio = req.get('delay_gen_audio', False)
|
||||
model_name = req.get('model_name')
|
||||
audio_stream_id = req.get('audio_stream_id', None)
|
||||
if audio_stream_id is None:
|
||||
@@ -355,6 +412,10 @@ def dialog_tts_post(tenant_id, chat_id):
|
||||
audio_stream = None
|
||||
else:
|
||||
audio_stream = io.BytesIO()
|
||||
|
||||
tts_stream_format = req.get('tts_stream_format', "mp3")
|
||||
tts_sample_rate = req.get('tts_sample_rate', 8000)
|
||||
logging.info(f"tts post {tts_sample_rate} {tts_stream_format}")
|
||||
# 结构化缓存数据
|
||||
tts_info = {
|
||||
'text': text,
|
||||
@@ -364,30 +425,21 @@ def dialog_tts_post(tenant_id, chat_id):
|
||||
'audio_stream': audio_stream, # 维持原有逻辑
|
||||
'model_name': req.get('model_name'),
|
||||
'delay_gen_audio': delay_gen_audio, # 明确存储状态
|
||||
audio_stream_id: audio_stream_id
|
||||
'audio_stream_id': audio_stream_id,
|
||||
'tts_sample_rate':tts_sample_rate,
|
||||
'tts_stream_format':tts_stream_format
|
||||
}
|
||||
|
||||
with cache_lock:
|
||||
audio_text_cache[audio_stream_id] = tts_info
|
||||
|
||||
if delay_gen_audio is False:
|
||||
try:
|
||||
"""
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
try:
|
||||
if txt is None or txt.strip() == "":
|
||||
continue
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
continue
|
||||
"""
|
||||
audio_stream.seek(0, io.SEEK_END)
|
||||
if text is None or text.strip() == "":
|
||||
audio_stream.write(b'\x00' * 100)
|
||||
else:
|
||||
# 确保在流的末尾写入
|
||||
audio_stream.seek(0, io.SEEK_END)
|
||||
for chunk in tts_mdl.tts(text):
|
||||
for chunk in tts_mdl.tts(text,sample_rate=tts_sample_rate,stream_formate=tts_stream_format):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
logging.info(f"--error:{e}")
|
||||
@@ -397,10 +449,79 @@ def dialog_tts_post(tenant_id, chat_id):
|
||||
|
||||
# 构建音频流URL
|
||||
audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
|
||||
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
|
||||
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url} "
|
||||
f"{tts_sample_rate} {tts_stream_format}")
|
||||
# 返回音频流URL
|
||||
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
|
||||
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id,
|
||||
"sample_rate":tts_sample_rate, "stream_format":tts_stream_format,})
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"请求处理失败: {str(e)}", exc_info=True)
|
||||
return get_error_data_result(message="服务器内部错误")
|
||||
|
||||
def get_antique_categories(mesum_id):
|
||||
res = MesumAntiqueService.get_all_categories()
|
||||
return res
|
||||
|
||||
def get_labels_ext(mesum_id):
|
||||
res = MesumAntiqueService.get_labels_ext(mesum_id)
|
||||
return res
|
||||
|
||||
def get_antique_labels(mesum_id):
|
||||
res = MesumAntiqueService.get_all_labels()
|
||||
return res
|
||||
|
||||
def get_all_antiques(mesum_id):
|
||||
res =[]
|
||||
antiques=MesumAntiqueService.get_by_mesum_id(mesum_id)
|
||||
for o in antiques:
|
||||
res.append(o.to_dict())
|
||||
return res
|
||||
|
||||
|
||||
@manager.route('/mesum/antique/<mesum_id>', methods=['GET'])
|
||||
def mesum_antique_get(mesum_id):
|
||||
try:
|
||||
data = {
|
||||
"anqituqes":get_all_antiques(mesum_id),
|
||||
"categories":get_antique_categories(mesum_id),
|
||||
"labels":get_antique_labels(mesum_id)
|
||||
}
|
||||
return get_result(data=data)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get mesum antique error {e}")
|
||||
|
||||
# 按照mesum_id 获得此博物馆的展品清单
|
||||
@manager.route('/mesum/antique_brief/<mesum_id>', methods=['GET'])
|
||||
@token_required
|
||||
def mesum_antique_get_brief(tenant_id,mesum_id):
|
||||
try:
|
||||
data = {
|
||||
"categories":get_antique_categories(mesum_id),
|
||||
"labels":get_labels_ext(mesum_id)
|
||||
}
|
||||
return get_result(data=data)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get mesum antique error {e}")
|
||||
|
||||
@manager.route('/mesum/antique_detail/<mesum_id>/<antique_id>', methods=['GET'])
|
||||
@token_required
|
||||
def mesum_antique_get_full(tenant_id,mesum_id,antique_id):
|
||||
try:
|
||||
logging.info(f"mesum_antique_get_full {mesum_id} {antique_id}")
|
||||
return get_result(data=MesumAntiqueService.get_antique_by_id(mesum_id,antique_id))
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get mesum antique error {e}")
|
||||
|
||||
def audio_fade_in(audio_data, fade_length):
|
||||
# 假设音频数据是16位单声道PCM
|
||||
# 将二进制数据转换为整数数组
|
||||
samples = array.array('h', audio_data)
|
||||
|
||||
# 对前fade_length个样本进行淡入处理
|
||||
for i in range(fade_length):
|
||||
fade_factor = i / fade_length
|
||||
samples[i] = int(samples[i] * fade_factor)
|
||||
|
||||
# 将整数数组转换回二进制数据
|
||||
return samples.tobytes()
|
||||
BIN
api/apps/sdk/test.mp3
Normal file
BIN
api/apps/sdk/test.mp3
Normal file
Binary file not shown.
Reference in New Issue
Block a user