增加了博物馆展品清单数据库及前端获取展品清单、展品详情的接口;增加了QWenOmni多模态大模型的支持(主要为了测试);增加了本地部署大模型支持(主要为了测试,部署在autoDL上);修正了TTS生成和返回前端的逻辑与参数;增加了判断用户问题有没有在知识库中检索到相关片段的逻辑,如果没有则直接返回并提示未找到相关内容

This commit is contained in:
qcloud
2025-04-08 08:41:07 +08:00
parent a5e83f4d3b
commit 330976812d
12 changed files with 593 additions and 158 deletions

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import request , Response, jsonify
from flask import request , Response, jsonify,stream_with_context
from api import settings
from api.db import LLMType
from api.db import StatusEnum
@@ -23,12 +23,15 @@ from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.db.services.brief_service import MesumOverviewService
from api.db.services.llm_service import LLMBundle
from api.db.services.antique_service import MesumAntiqueService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
from api.utils.file_utils import get_project_base_directory
import logging
import base64
import queue,time,uuid
import base64, gzip
from io import BytesIO
import queue,time,uuid,os,array
from threading import Lock,Thread
from zhipuai import ZhipuAI
@@ -59,12 +62,37 @@ def my_llms(tenant_id):
main_antiquity="浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"
@manager.route('/photo/recongeText', methods=['POST'])
@manager.route('/photo/recongeText/<mesum_id>', methods=['POST'])
@token_required
def upload_file(tenant_id):
def upload_file(tenant_id,mesum_id):
if 'file' not in request.files:
return jsonify({'error': 'No file part'}), 400
antiques_selected = ""
if mesum_id:
"""
e,mesum_breif = MesumOverviewService.get_by_id(mesum_id)
if not e:
logging.info(f"没有找到匹配的博物馆信息,mesum_id={mesum_id}")
else:
antiques_selected =f"结果从:{mesum_breif.antique} 中进行选择"
"""
mesum_id_str = str(mesum_id)
antique_labels=get_antique_labels(mesum_id)
# 使用列表推导式和str()函数将所有元素转换为字符串
string_elements = [str(element) for element in antique_labels]
# 使用join()方法将字符串元素连接起来,以逗号为分隔符
joined_string = ','.join(string_elements)
antiques_selected = f"结果从:{joined_string} 中进行选择"
logging.info(f"{mesum_id} {joined_string}")
prompt = (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
f"请识别这个图片中文字,重点识别出含在文字中的某一文物标题、某一个历史事件或某一历史人物,"
f"你的回答有2个结果第一个结果是是从文字中识别出历史文物、历史事件、历史人物,"
f"此回答时只给出匹配的文物、事件、人物,不需要其他多余的文字,{antiques_selected}"
f",第二个结果是原始识别的所有文字"
"2个结果输出以{ }的json格式给出匹配文物、事件、人物的键值为antique如果有多个请加序号如:antique1,antique2,"
f"原始数据的键值为text输出是1个完整的JSON数据不要有多余的前置和后置内容确保前端能正确解析出JSON数据")
file = request.files['file']
if file.filename == '':
@@ -92,14 +120,7 @@ def upload_file(tenant_id):
},
{
"type": "text",
"text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
f"你的回答有2个结果第一个结果是是从文字进行分析出匹配文物候选文物只能如下:{req_antique},"
f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
f",第二个结果是原始识别的所有文字"
"2个结果输出以{ }的json格式给出匹配文物的键值为antique如果有多个请加序号如:antique1,antique2,"
f"原始数据的键值为text输出是1个完整的JSON数据不要有多余的前置和后置内容确保前端能正确解析出JSON数据")
"text": prompt
}
]
}
@@ -213,6 +234,17 @@ def extract_text_from_markdown(markdown_text):
return text
def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress *original_data* and return it as a Base64 string.

    The Base64 output contains no line breaks, which matches the Android
    client's Base64.NO_WRAP decoding configuration.
    """
    compressed = gzip.compress(original_data)
    return base64.b64encode(compressed).decode("utf-8")
def clean_audio_cache():
"""定时清理过期缓存"""
with cache_lock:
@@ -241,38 +273,55 @@ def start_background_cleaner():
# 应用启动时启动清理线程
start_background_cleaner()
@manager.route('/tts_stream/<session_id>')
@manager.route('/tts_stream/<session_id>',methods=['GET'])
def tts_stream(session_id):
session = stream_manager.sessions.get(session_id)
def generate():
retry_count = 0
session = None
count = 0;
path = os.path.join(get_project_base_directory(), "api", "apps/sdk/test.mp3")
fmp3 =open(path, 'rb')
finished_event = session['finished']
try:
while retry_count < 1:
session = stream_manager.sessions.get(session_id)
while not finished_event.is_set() :
if not session or not session['active']:
break
try:
chunk = session['buffer'].get(timeout=5) # 30秒超时
chunk = session['buffer'].get_nowait() #
count = count + 1
if isinstance(chunk, str) and chunk.startswith("ERROR"):
logging.info("---tts stream error!!!!")
logging.info(f"---tts stream error!!!! {chunk}")
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
break
yield chunk
if session['stream_format'] == "wav":
gzip_base64_data = encode_gzip_base64(chunk) + "\r\n"
yield gzip_base64_data
else:
yield chunk
retry_count = 0 # 成功收到数据重置重试计数器
except queue.Empty:
retry_count += 1
yield b'' # 保持连接
if session['stream_format'] == "wav":
# yield encode_gzip_base64(b'\x03\x04' * 1) + "\r\n"
pass
else:
yield b'' # 保持连接
#data = fmp3.read(1024)
#yield data
except Exception as e:
logging.info(f"tts streag get error2 {e} ")
finally:
# 确保流结束后关闭会话
if session:
# 延迟关闭会话,确保所有数据已发送
time.sleep(5) # 等待5秒确保流结束
stream_manager.close_session(session_id)
logging.info(f"Session {session_id} closed.")
# 关键响应头设置
resp = Response(generate(), mimetype="audio/mpeg")
if session['stream_format'] == "wav":
resp = Response(stream_with_context(generate()), mimetype="audio/mpeg")
else:
resp = Response(stream_with_context(generate()), mimetype="audio/wav")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
@@ -291,17 +340,23 @@ def dialog_tts_get(chat_id, audio_stream_id):
chat_id = req.get('chat_id')
text = req.get('text', "..")
model_name = req.get('model_name')
sample_rate = req.get('tts_sample_rate',8000) # 默认8K
stream_format = req.get('tts_stream_format','mp3')
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
if not dia:
return get_error_data_result(message="You do not own the chat")
tts_model_name = dia.tts_id
if model_name: tts_model_name = model_name
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
logging.info(f"dialog_tts_get {sample_rate} {stream_format}")
def stream_audio():
try:
for chunk in tts_mdl.tts(text):
yield chunk
for chunk in tts_mdl.tts(text,sample_rate=sample_rate,stream_format=stream_format):
if stream_format =='wav':
yield encode_gzip_base64(chunk) + "\r\n"
else:
yield chunk
except Exception as e:
yield ("data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e)}},
@@ -318,7 +373,10 @@ def dialog_tts_get(chat_id, audio_stream_id):
audio_stream.seek(0)
resp = Response(generate(), mimetype="audio/mpeg")
else:
resp = Response(stream_audio(), mimetype="audio/mpeg")
if stream_format == 'wav':
resp = Response(stream_audio(), mimetype="audio/wav")
else:
resp = Response(stream_audio(), mimetype="audio/mpeg")
resp.headers.add_header("Cache-Control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
@@ -328,20 +386,19 @@ def dialog_tts_get(chat_id, audio_stream_id):
return get_error_data_result(message="音频流传输失败")
finally:
# 确保资源释放
if tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
if tts_info and tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
tts_info['audio_stream'].close()
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
req = request.json
try:
req = request.json
if not req.get("text"):
return get_error_data_result(message="Please input your question.")
delay_gen_audio = req.get('delay_gen_audio', False)
# text = extract_text_from_markdown(req.get('text'))
text = req.get('text')
delay_gen_audio = req.get('delay_gen_audio', False)
model_name = req.get('model_name')
audio_stream_id = req.get('audio_stream_id', None)
if audio_stream_id is None:
@@ -355,6 +412,10 @@ def dialog_tts_post(tenant_id, chat_id):
audio_stream = None
else:
audio_stream = io.BytesIO()
tts_stream_format = req.get('tts_stream_format', "mp3")
tts_sample_rate = req.get('tts_sample_rate', 8000)
logging.info(f"tts post {tts_sample_rate} {tts_stream_format}")
# 结构化缓存数据
tts_info = {
'text': text,
@@ -364,30 +425,21 @@ def dialog_tts_post(tenant_id, chat_id):
'audio_stream': audio_stream, # 维持原有逻辑
'model_name': req.get('model_name'),
'delay_gen_audio': delay_gen_audio, # 明确存储状态
audio_stream_id: audio_stream_id
'audio_stream_id': audio_stream_id,
'tts_sample_rate':tts_sample_rate,
'tts_stream_format':tts_stream_format
}
with cache_lock:
audio_text_cache[audio_stream_id] = tts_info
if delay_gen_audio is False:
try:
"""
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
try:
if txt is None or txt.strip() == "":
continue
for chunk in tts_mdl.tts(txt):
audio_stream.write(chunk)
except Exception as e:
continue
"""
audio_stream.seek(0, io.SEEK_END)
if text is None or text.strip() == "":
audio_stream.write(b'\x00' * 100)
else:
# 确保在流的末尾写入
audio_stream.seek(0, io.SEEK_END)
for chunk in tts_mdl.tts(text):
for chunk in tts_mdl.tts(text,sample_rate=tts_sample_rate,stream_formate=tts_stream_format):
audio_stream.write(chunk)
except Exception as e:
logging.info(f"--error:{e}")
@@ -397,10 +449,79 @@ def dialog_tts_post(tenant_id, chat_id):
# 构建音频流URL
audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url} "
f"{tts_sample_rate} {tts_stream_format}")
# 返回音频流URL
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id,
"sample_rate":tts_sample_rate, "stream_format":tts_stream_format,})
except Exception as e:
logging.error(f"请求处理失败: {str(e)}", exc_info=True)
return get_error_data_result(message="服务器内部错误")
def get_antique_categories(mesum_id):
    """Return the distinct exhibit categories.

    NOTE(review): *mesum_id* is currently ignored — get_all_categories()
    returns categories across ALL museums. Confirm whether per-museum
    filtering is intended here.
    """
    res = MesumAntiqueService.get_all_categories()
    return res
def get_labels_ext(mesum_id):
    """Return this museum's exhibit labels grouped by category.

    Delegates to MesumAntiqueService.get_labels_ext, which filters by
    *mesum_id* and returns a mapping category -> [{'id', 'label'}, ...].
    """
    res = MesumAntiqueService.get_labels_ext(mesum_id)
    return res
def get_antique_labels(mesum_id):
    """Return the distinct exhibit labels.

    NOTE(review): *mesum_id* is currently ignored — get_all_labels()
    returns labels across ALL museums. Confirm whether this should
    filter by museum.
    """
    res = MesumAntiqueService.get_all_labels()
    return res
def get_all_antiques(mesum_id):
    """Return every exhibit of the museum as a list of plain dicts."""
    antiques = MesumAntiqueService.get_by_mesum_id(mesum_id)
    return [antique.to_dict() for antique in antiques]
@manager.route('/mesum/antique/<mesum_id>', methods=['GET'])
def mesum_antique_get(mesum_id):
    """GET the full exhibit list, categories and labels of one museum.

    Unauthenticated endpoint — unlike the _brief/_detail routes it has no
    @token_required decorator; confirm that is intentional.
    NOTE(review): the response key "anqituqes" is a typo for "antiques",
    but it is part of the shipped API contract — fixing it requires a
    coordinated front-end change.
    """
    try:
        data = {
            "anqituqes": get_all_antiques(mesum_id),
            "categories": get_antique_categories(mesum_id),
            "labels": get_antique_labels(mesum_id)
        }
        return get_result(data=data)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
# Return the exhibit catalogue (brief form) of the museum identified by mesum_id.
@manager.route('/mesum/antique_brief/<mesum_id>', methods=['GET'])
@token_required
def mesum_antique_get_brief(tenant_id, mesum_id):
    """GET a brief exhibit catalogue: categories plus labels grouped by category.

    Requires a valid API token; *tenant_id* is injected by @token_required
    but is not used in the lookup.
    """
    try:
        data = {
            "categories": get_antique_categories(mesum_id),
            "labels": get_labels_ext(mesum_id)
        }
        return get_result(data=data)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
@manager.route('/mesum/antique_detail/<mesum_id>/<antique_id>', methods=['GET'])
@token_required
def mesum_antique_get_full(tenant_id, mesum_id, antique_id):
    """GET the full record of a single exhibit of a museum.

    Requires a valid API token; *tenant_id* is injected by @token_required
    but is not used in the lookup. The service returns an empty list when
    no exhibit matches — callers must handle both dict and [] results.
    """
    try:
        logging.info(f"mesum_antique_get_full {mesum_id} {antique_id}")
        return get_result(data=MesumAntiqueService.get_antique_by_id(mesum_id, antique_id))
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of a 16-bit mono PCM buffer.

    The first *fade_length* samples are scaled linearly from factor 0.0
    up to (fade_length - 1) / fade_length; remaining samples are untouched.

    Args:
        audio_data: raw PCM bytes ('h' samples, host byte order); length
            must be a multiple of 2 or array.array raises ValueError.
        fade_length: number of samples to fade. Values larger than the
            buffer are clamped (the original code raised IndexError here);
            0 or negative values leave the data unchanged.

    Returns:
        bytes: the faded PCM data.
    """
    samples = array.array('h', audio_data)
    # Clamp so a fade longer than the buffer cannot index past the end.
    fade_length = min(fade_length, len(samples))
    for i in range(fade_length):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)
    return samples.tobytes()

BIN
api/apps/sdk/test.mp3 Normal file

Binary file not shown.

View File

@@ -25,7 +25,7 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
from flask_login import UserMixin
from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
from peewee import (
BigIntegerField, BooleanField, CharField,
BigIntegerField, BooleanField, CharField,AutoField,
CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
Field, Model, Metadata
)
@@ -1025,6 +1025,23 @@ class MesumOverview(DataBaseModel):
class Meta:
db_table = "mesum_overview"
# added by cyx for mesum_antique
class MesumAntique(DataBaseModel):
    """Peewee model for one museum exhibit row (table ``mesum_antique``)."""
    # Exhibit serial number.
    sn = CharField(max_length=100, null=True)
    # Exhibit title / display name.
    label = CharField(max_length=100, null=True)
    # Free-text description of the exhibit.
    description = TextField(null=True)
    # Category used for grouping exhibits in the catalogue.
    category = CharField(max_length=100, null=True)
    # Sub-group within the category.
    group = CharField(max_length=100, null=True)
    # Historical background text.
    background = TextField(null=True)
    # Cultural/historical value text.
    value = TextField(null=True)
    # Discovery information (how/where the exhibit was found).
    discovery = TextField(null=True)
    # Auto-incrementing integer primary key.
    id = AutoField(primary_key=True)
    # Owning museum id, stored as a plain string (no FK constraint).
    mesum_id = CharField(max_length=100, null=True)
    # Combined text of the fields above — presumably for search/embedding;
    # confirm where it is populated.
    combined = TextField(null=True)

    class Meta:
        db_table = 'mesum_antique'
#-------------------------------------------
def migrate_db():
with DB.transaction():

View File

@@ -0,0 +1,88 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime
import peewee
from werkzeug.security import generate_password_hash, check_password_hash
from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant, MesumAntique
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum
class MesumAntiqueService(CommonService):
    """Database service layer for MesumAntique (museum exhibit) rows."""
    model = MesumAntique

    @classmethod
    @DB.connection_context()
    def get_by_mesum_id(cls, mesum_id):
        """Return all exhibit model instances belonging to one museum."""
        objs = cls.query(mesum_id=mesum_id)
        return objs

    @classmethod
    @DB.connection_context()
    def get_all_categories(cls):
        """Return the distinct non-empty categories across ALL museums.

        NOTE(review): not filtered by museum — confirm that is intended.
        """
        # Distinct category values, dropping NULL/empty ones.
        categories = [category.category for category in cls.model.select(cls.model.category).distinct().execute() if category.category]
        return categories

    @classmethod
    @DB.connection_context()
    def get_all_labels(cls):
        """Return the distinct non-empty exhibit labels across ALL museums.

        NOTE(review): not filtered by museum — confirm that is intended.
        """
        # Distinct label values, dropping NULL/empty ones.
        labels = [label.label for label in cls.model.select(cls.model.label).distinct().execute() if label.label]
        return labels

    @classmethod
    @DB.connection_context()
    def get_labels_ext(cls, mesum_id):
        """Return {category: [{'id': ..., 'label': ...}, ...]} for one museum.

        Rows with an empty-string category are excluded (NULL categories
        are not excluded by the != "" test — confirm whether they should be).
        Ordering by category is cosmetic; the dict grouping below does not
        rely on it.
        """
        query = cls.model.select().where(
            (cls.model.mesum_id == mesum_id) &
            (cls.model.category != "")
        ).order_by(cls.model.category)
        # Group the rows by their category value.
        grouped_data = {}
        for obj in query.dicts().execute():
            category = obj['category']
            if category not in grouped_data:
                grouped_data[category] = []
            grouped_data[category].append({
                'id': obj['id'],
                'label': obj['label']
            })
        return grouped_data

    @classmethod
    @DB.connection_context()
    def get_antique_by_id(cls, mesum_id, antique_id):
        """Return one exhibit (matched by museum id AND exhibit id) as a dict.

        NOTE(review): returns an empty list when nothing matches, but a
        dict when found — callers must handle both shapes.
        """
        query = cls.model.select().where(
            (cls.model.mesum_id == mesum_id) &
            (cls.model.id == antique_id)
        )
        data = []
        for obj in query.dicts().execute():
            data.append(obj)
        if len(data) > 0:
            data = data[0]
        return data

View File

@@ -33,9 +33,21 @@ from rag.nlp.search import index_name
from rag.utils import rmSpace, num_tokens_from_string, encoder
from api.utils.file_utils import get_project_base_directory
from peewee import fn
import threading, queue,uuid,time
import threading, queue,uuid,time,array
from concurrent.futures import ThreadPoolExecutor
def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the start of a 16-bit mono PCM buffer.

    The first *fade_length* samples are scaled linearly from factor 0.0
    up to (fade_length - 1) / fade_length; remaining samples are untouched.

    Args:
        audio_data: raw PCM bytes ('h' samples, host byte order); length
            must be a multiple of 2 or array.array raises ValueError.
        fade_length: number of samples to fade. Values larger than the
            buffer are clamped (the original code raised IndexError here);
            0 or negative values leave the data unchanged.

    Returns:
        bytes: the faded PCM data.
    """
    samples = array.array('h', audio_data)
    # Clamp so a fade longer than the buffer cannot index past the end.
    fade_length = min(fade_length, len(samples))
    for i in range(fade_length):
        fade_factor = i / fade_length
        samples[i] = int(samples[i] * fade_factor)
    return samples.tobytes()
class StreamSessionManager:
@@ -44,17 +56,20 @@ class StreamSessionManager:
self.lock = threading.Lock()
self.executor = ThreadPoolExecutor(max_workers=30) # 固定大小线程池
self.gc_interval = 300 # 5分钟清理一次
def create_session(self, tts_model):
self.gc_tts = 3 # 3s
def create_session(self, tts_model,sample_rate =8000, stream_format='mp3'):
session_id = str(uuid.uuid4())
with self.lock:
self.sessions[session_id] = {
'tts_model': tts_model,
'buffer': queue.Queue(maxsize=100), # 线程安全队列
'buffer': queue.Queue(maxsize=300), # 线程安全队列
'task_queue': queue.Queue(),
'active': True,
'last_active': time.time(),
'audio_chunk_count':0
'audio_chunk_count':0,
'finished': threading.Event(), # 添加事件对象
'sample_rate':sample_rate,
'stream_format':stream_format
}
# 启动任务处理线程
threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
@@ -98,6 +113,9 @@ class StreamSessionManager:
if time.time() - session['last_active'] > self.gc_interval:
self.close_session(session_id)
break
if time.time() - session['last_active'] > self.gc_tts:
session['finished'].set()
break
except Exception as e:
logging.error(f"Task processing error: {str(e)}")
@@ -107,22 +125,36 @@ class StreamSessionManager:
session = self.sessions.get(session_id)
if not session: return
# logging.info(f"_generate_audio:{text}")
first_chunk = True
try:
for chunk in session['tts_model'].tts(text):
session['buffer'].put(chunk)
for chunk in session['tts_model'].tts(text,session['sample_rate'],session['stream_format']):
if session['stream_format'] == 'wav':
if first_chunk:
chunk_len = len(chunk)
if chunk_len > 2048:
session['buffer'].put(audio_fade_in(chunk,1024))
else:
session['buffer'].put(audio_fade_in(chunk, chunk_len))
first_chunk = False
else:
session['buffer'].put(chunk)
else:
session['buffer'].put(chunk)
session['last_active'] = time.time()
session['audio_chunk_count'] = session['audio_chunk_count'] + 1
logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
except Exception as e:
session['buffer'].put(f"ERROR:{str(e)}")
logging.info(f"--_generate_audio--error {str(e)}")
def close_session(self, session_id):
with self.lock:
if session_id in self.sessions:
# 标记会话为不活跃
self.sessions[session_id]['active'] = False
# 延迟30秒后清理资源
threading.Timer(10, self._clean_session, args=[session_id]).start()
# 延迟2秒后清理资源
threading.Timer(1, self._clean_session, args=[session_id]).start()
def _clean_session(self, session_id):
with self.lock:
@@ -297,7 +329,6 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
# 5. 检查长度
if len(delta_ans) == 0 or len(delta_ans) > max_length:
return False, ""
# 如果通过所有检查,返回有效标志和修正后的文本
return True, delta_ans
@@ -339,12 +370,6 @@ def find_split_position(text):
if date_pattern:
return date_pattern.end()
# 避免拆分常见短语
for phrase in ["青少年", "博物馆", "参观"]:
idx = text.rfind(phrase)
if idx != -1 and idx + len(phrase) <= len(text):
return idx + len(phrase)
return None
# 管理文本缓冲区,根据语义规则动态分割并返回待处理内容,分割出语义完整的部分
@@ -353,9 +378,7 @@ def process_buffer(chunk_buffer, force_flush=False):
current_text = "".join(chunk_buffer)
if not current_text:
return "", []
split_pos = find_split_position(current_text)
# 强制刷新逻辑
if force_flush or len(current_text) >= MAX_BUFFER_LEN:
# 即使强制刷新也要尽量找合适的分割点
@@ -366,7 +389,7 @@ def process_buffer(chunk_buffer, force_flush=False):
if split_pos is not None and split_pos > 0:
return current_text[:split_pos], [current_text[split_pos:]]
return "", chunk_buffer
return None, chunk_buffer
def chat(dialog, messages, stream=True, **kwargs):
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@@ -421,6 +444,8 @@ def chat(dialog, messages, stream=True, **kwargs):
else:
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id)
tts_sample_rate = kwargs.get("tts_sample_rate",8000) # 默认为8K
tts_stream_format = kwargs.get("tts_stream_format","mp3") # 默认为mp3格式
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
@@ -465,20 +490,21 @@ def chat(dialog, messages, stream=True, **kwargs):
doc_ids=attachments,
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
logging.debug(
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
logging.debug( "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
# 打印历史记录
#logging.info( "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
retrieval_tm = timer()
if not knowledges and prompt_config.get("empty_response"):
empty_res = prompt_config["empty_response"]
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
yield {"answer": empty_res, "reference": kbinfos, "audio_binary":
tts(tts_mdl, empty_res,sample_rate=tts_sample_rate,stream_format=tts_stream_format)}
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
kwargs["knowledge"] = "\n\n------\n\n".join(knowledges)
gen_conf = dialog.llm_setting
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
for m in messages if m["role"] != "system"])
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
@@ -502,6 +528,9 @@ def chat(dialog, messages, stream=True, **kwargs):
embd_mdl,
tkweight=1 - dialog.vector_similarity_weight,
vtweight=dialog.vector_similarity_weight)
# 上述转换过程中发现有时候会在answer中插入类似##0$$ ##1$$ 这样的字符串,需要去除
# cyx 20250407
answer = re.sub(r'##\d+\$\$', '', answer).strip() #去除##0$$类似内容 同时去除多余空格
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
recall_docs = [
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -512,7 +541,6 @@ def chat(dialog, messages, stream=True, **kwargs):
for c in refs["chunks"]:
if c.get("vector"):
del c["vector"]
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
done_tm = timer()
@@ -525,83 +553,92 @@ def chat(dialog, messages, stream=True, **kwargs):
last_ans = ""
answer = ""
# 创建TTS会话提前初始化
tts_session_id = stream_manager.create_session(tts_mdl)
tts_session_id = stream_manager.create_session(tts_mdl,sample_rate=tts_sample_rate,stream_format=tts_stream_format)
audio_url = f"/tts_stream/{tts_session_id}"
first_chunk = True
chunk_buffer = [] # 新增文本缓冲
last_flush_time = time.time() # 初始化时间戳
# 下面优先处理知识库中没有找到相关内容 cyx 20250323 修改
if not kwargs["knowledge"] or kwargs["knowledge"] =="" or len(kwargs["knowledge"]) < 4:
stream_manager.append_text(tts_session_id, "未找到相关内容")
yield {
"answer": "未找到相关内容",
"delta_ans": "未找到相关内容",
"session_id": tts_session_id,
"reference": {},
"audio_stream_url": audio_url,
"sample_rate": tts_sample_rate,
"stream_format": tts_stream_format,
}
else:
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 24:
continue
last_ans = answer
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
# cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
if kwargs.get('tts_disable'):
tts_input_is_valid =False
if tts_input_is_valid:
# 缓冲文本直到遇到标点
chunk_buffer.append(sanitized_text)
# 处理缓冲区内容
while True:
# 判断是否需要强制刷新
force = time.time() - last_flush_time > FLUSH_TIMEOUT
to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
if not to_send:
break
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
# 发送有效内容
stream_manager.append_text(tts_session_id, to_send)
chunk_buffer = remaining
last_flush_time = time.time()
"""
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
# 首块返回音频URL
if first_chunk:
yield {
"answer": answer,
"delta_ans": sanitized_text,
"session_id": tts_session_id,
"reference": {},
"audio_stream_url": audio_url,
"sample_rate":tts_sample_rate,
"stream_format":tts_stream_format,
}
first_chunk = False
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
answer = ans
delta_ans = ans[len(last_ans):]
if num_tokens_from_string(delta_ans) < 24:
continue
last_ans = answer
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
# cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
if kwargs.get('tts_disable'):
tts_input_is_valid =False
if tts_input_is_valid:
# 缓冲文本直到遇到标点
chunk_buffer.append(sanitized_text)
# 处理缓冲区内容
while True:
# 判断是否需要强制刷新
force = time.time() - last_flush_time > FLUSH_TIMEOUT
to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
if not to_send:
break
# 发送有效内容
stream_manager.append_text(tts_session_id, to_send)
chunk_buffer = remaining
last_flush_time = time.time()
"""
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
delta_ans = answer[len(last_ans):]
if delta_ans:
# stream_manager.append_text(tts_session_id, delta_ans)
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
if tts_input_is_valid:
# 20250221 修改,在后端生成音频数据
chunk_buffer.append(sanitized_text)
stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
# 首块返回音频URL
if first_chunk:
yield {
"answer": answer,
"delta_ans": sanitized_text,
"audio_stream_url": audio_url,
"session_id": tts_session_id,
"reference": {}
}
first_chunk = False
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
delta_ans = answer[len(last_ans):]
if delta_ans:
# stream_manager.append_text(tts_session_id, delta_ans)
# yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
# cyx 2024 12 04 修正delta_ans 为空 调用tts 出错
tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
if tts_input_is_valid:
# 20250221 修改,在后端生成音频数据
chunk_buffer.append(sanitized_text)
stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
"""
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
"""
yield decorate_answer(answer)
"""
if tts_input_is_valid:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
else:
yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}}
"""
yield decorate_answer(answer)
else:
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
@@ -611,7 +648,7 @@ def chat(dialog, messages, stream=True, **kwargs):
if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数就不生成tts 音频给前端,即:没有audio_binary
tts_input_is_valid = False
else:
res["audio_binary"] = tts(tts_mdl, answer)
res["audio_binary"] = tts(tts_mdl, answer,tts_sample_rate,tts_stream_format)
yield res
@@ -899,10 +936,10 @@ Output: What's the weather in Rochester on {tomorrow}?
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
def tts(tts_mdl, text):
def tts(tts_mdl, text,sample_rate=8000,stream_format = "mp3"):
if not tts_mdl or not text: return
bin = b""
for chunk in tts_mdl.tts(text):
for chunk in tts_mdl.tts(text,sample_rate,stream_format):
bin += chunk
return binascii.hexlify(bin).decode("utf-8")

View File

@@ -248,8 +248,8 @@ class LLMBundle(object):
"LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
return txt
def tts(self, text): # tts 调用 cyx
for chunk in self.mdl.tts(text):
def tts(self, text, sample_rate=8000, stream_format='mp3'): # tts 调用 cyx
for chunk in self.mdl.tts(text, sample_rate=sample_rate,stream_format = stream_format):
if isinstance(chunk,int):
if not TenantLLMService.increase_usage(
self.tenant_id, self.llm_type, chunk, self.llm_name):

View File

@@ -167,6 +167,40 @@
}
]
},
{
"name": "Qianwen-Omni",
"logo": "",
"tags": "LLM,IMAGE2TEXT,MODERATION",
"status": "1",
"llm": [
{
"llm_name": "qwen-omni-turbo",
"tags": "LLM,CHAT,32K",
"max_tokens": 30768,
"model_type": "chat"
},
{
"llm_name": "qwen-omni-turbo-latest",
"tags": "LLM,CHAT,IMAGE2TEXT",
"max_tokens": 30768,
"model_type": "image2text"
}
]
},
{
"name": "LOCAL-LLM",
"logo": "",
"tags": "chat",
"status": "1",
"llm": [
{
"llm_name": "chat",
"tags": "LLM,CHAT,32K",
"max_tokens": 12768,
"model_type": "chat"
}
]
},
{
"name": "ZHIPU-AI",
"logo": "",

View File

@@ -76,6 +76,8 @@ ChatModel = {
"Azure-OpenAI": AzureChat,
"ZHIPU-AI": ZhipuChat,
"Tongyi-Qianwen": QWenChat,
"Qianwen-Omni": QWenOmniChat,
"LOCAL-LLM":LocalLLMChat,
"Ollama": OllamaChat,
"LocalAI": LocalAIChat,
"Xinference": XinferenceChat,

View File

@@ -127,6 +127,98 @@ class DeepSeekChat(Base):
super().__init__(key, model_name, base_url)
class LocalLLMChat(Base):
    """OpenAI-compatible chat client for a locally deployed LLM endpoint
    (used for testing, e.g. a model served on autoDL)."""

    def __init__(self, key, model_name="Qwen2.5-7B", base_url="http://106.52.71.204:9483/v1"):
        # An empty/None base_url falls back to the same default endpoint.
        super().__init__(key, model_name, base_url or "http://106.52.71.204:9483/v1")
class QWenOmniChat(Base):
    """Chat client for Alibaba's Qwen-Omni multimodal models via the
    DashScope OpenAI-compatible endpoint.

    Both chat() and chat_streamly() call the API with stream=True — per the
    original author's note, omni models reject non-streaming requests.
    """

    def __init__(self, key, model_name="qwen-omni-turbo", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"):
        # Fall back to the DashScope compatible-mode endpoint when the
        # caller passes an empty/None base_url.
        if not base_url: base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1"
        super().__init__(key, model_name, base_url)

    def chat(self, system, history, gen_conf):
        """Blocking chat: consume the stream fully and return
        (answer_text, total_tokens)."""
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model = self.model_name,
                messages=history,
                stream=True,
                # Output modalities currently supported: ["text","audio"] or ["text"].
                # modalities=["text", "audio"],
                # audio={"voice": "Cherry", "format": "wav"},
                # stream must be True for omni models, otherwise the API errors.
                # stream_options={"include_usage": True},
                **gen_conf
            )
            for resp in response:
                if not resp.choices: continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                # Token accounting: prefer the usage reported by the API;
                # fall back to counting the delta text ourselves.
                if not hasattr(resp, "usage") or not resp.usage:
                    total_tokens = (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                elif isinstance(resp.usage, dict):
                    total_tokens = resp.usage.get("total_tokens", total_tokens)
                else: total_tokens = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    break  # stop once the length limit is hit
        except openai.APIError as e:
            ans= ans + "\n**ERROR**: " + str(e)
        return ans, total_tokens

    def chat_streamly(self, system, history, gen_conf):
        """Streaming chat: yield the cumulative answer text after each
        delta, then finally yield the total token count (an int)."""
        # logging.info(f"chat_streamly :{gen_conf}")
        if system :
            # The system prompt is wrapped in the multimodal content-part format.
            history.insert(0, {"role": "system", "content":
                [{"type":"text","text": system}]})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                # NOTE(review): gen_conf is not forwarded here, unlike in
                # chat() above — confirm whether that is intentional.
                #**gen_conf
            )
            for resp in response:
                if not resp.choices: continue
                if not resp.choices[0].delta.content:
                    resp.choices[0].delta.content = ""
                ans += resp.choices[0].delta.content
                # Token accounting: prefer API-reported usage; otherwise
                # count the delta text ourselves.
                if not hasattr(resp, "usage") or not resp.usage:
                    total_tokens = (
                        total_tokens
                        + num_tokens_from_string(resp.choices[0].delta.content)
                    )
                elif isinstance(resp.usage, dict):
                    total_tokens = resp.usage.get("total_tokens", total_tokens)
                else: total_tokens = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)
        yield total_tokens
class AzureChat(Base):
def __init__(self, key, model_name, **kwargs):
api_key = json.loads(key).get('api_key', '')

View File

@@ -203,11 +203,14 @@ class ZhipuEmbed(Base):
def encode(self, texts: list, batch_size=16):
arr = []
tks_num = 0
for txt in texts:
res = self.client.embeddings.create(input=txt,
model=self.model_name)
arr.append(res.data[0].embedding)
tks_num += res.usage.total_tokens
try:
for txt in texts:
res = self.client.embeddings.create(input=txt,
model=self.model_name)
arr.append(res.data[0].embedding)
tks_num += res.usage.total_tokens
except Exception as error:
logging.info(f"!!!ZhipuEmbed embedding error {error}")
return np.array(arr), tks_num
def encode_queries(self, text):

View File

@@ -133,7 +133,8 @@ class QwenTTS(Base):
if parts[0] == 'cosyvoice-v1':
self.is_cosyvoice = True
self.voice = parts[1]
def tts(self, text):
# 参数stream_format 为产生的tts 音频数据格式, mp3 wav pcm
def tts(self, text, sample_rate=8000,stream_format="mp3"):
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
if self.is_cosyvoice is False:
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
@@ -143,7 +144,7 @@ class QwenTTS(Base):
from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat #, SpeechSynthesisResult
from dashscope.audio.tts import SpeechSynthesisResult
from collections import deque
print(f"--QwenTTS--tts_stream begin-- {text}") # cyx
# print(f"--QwenTTS--tts_stream begin-- {text}") # cyx
class Callback(ResultCallback):
def __init__(self) -> None:
self.dque = deque()
@@ -237,12 +238,52 @@ class QwenTTS(Base):
format="mp3")
else:
self.callback = Callback_v2()
print(f"--tts {sample_rate} {stream_format}")
if sample_rate == 8000:
if stream_format == 'mp3':
format = AudioFormat.MP3_8000HZ_MONO_128KBPS
elif stream_format == 'pcm':
format = AudioFormat.PCM_8000HZ_MONO_16BIT
elif stream_format == 'wav':
format = AudioFormat.WAV_8000HZ_MONO_16BIT
else:
format = AudioFormat.MP3_8000HZ_MONO_128KBPS
elif sample_rate == 16000:
if stream_format == 'mp3':
format = AudioFormat.MP3_16000HZ_MONO_128KBPS
elif stream_format == 'pcm':
format = AudioFormat.PCM_16000HZ_MONO_16BIT
elif stream_format == 'wav':
format = AudioFormat.WAV_16000HZ_MONO_16BIT
else:
format = AudioFormat.MP3_16000HZ_MONO_128KBPS
elif sample_rate == 22050:
if stream_format == 'mp3':
format = AudioFormat.MP3_22050HZ_MONO_256KBPS
elif stream_format == 'pcm':
format = AudioFormat.PCM_22050HZ_MONO_16BIT
elif stream_format == 'wav':
format = AudioFormat.WAV_22050HZ_MONO_16BIT
else:
format = AudioFormat.MP3_22050HZ_MONO_256KBPS
elif sample_rate == 44100:
if stream_format == 'mp3':
format = AudioFormat.MP3_44100HZ_MONO_256KBPS
elif stream_format == 'pcm':
format = AudioFormat.PCM_44100HZ_MONO_16BIT
elif stream_format == 'wav':
format = AudioFormat.WAV_44100HZ_MONO_16BIT
else:
format = AudioFormat.MP3_44100HZ_MONO_256KBPS
# format=AudioFormat.MP3_44100HZ_MONO_256KBPS
else:
format = AudioFormat.MP3_44100HZ_MONO_256KBPS
self.synthesizer = SpeechSynthesizer(
model='cosyvoice-v1',
# voice="longyuan", #"longfei",
voice = self.voice,
callback=self.callback,
format=AudioFormat.MP3_44100HZ_MONO_256KBPS
format=format
)
self.synthesizer.call(text)

View File

@@ -317,7 +317,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
vects = (title_w * tts + (1 - title_w) *
cnts) if len(tts) == len(cnts) else cnts
assert len(vects) == len(docs)
# assert len(vects) == len(docs)
vector_size = 0
for i, d in enumerate(docs):
v = vects[i].tolist()