From a5e83f4d3b465120c6691b0717c514aa4d9d76af Mon Sep 17 00:00:00 2001 From: qcloud Date: Sun, 23 Feb 2025 09:52:30 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9C=A8=E7=94=9F=E6=88=90=E5=AF=B9=E8=AF=9D?= =?UTF-8?q?=E6=96=87=E5=AD=97=E6=97=B6=EF=BC=8C=E5=90=8C=E6=97=B6=E5=9C=A8?= =?UTF-8?q?=E5=90=8E=E5=8F=B0=E7=94=9F=E6=88=90tts=E9=9F=B3=E9=A2=91?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=8A=A0=E6=9C=97=E8=AF=BB=E9=9F=B3=E8=89=B2?= =?UTF-8?q?=E9=80=89=E6=8B=A9=EF=BC=8C=E5=A2=9E=E5=8A=A0=E5=8D=9A=E7=89=A9?= =?UTF-8?q?=E9=A6=86=E7=9A=84=E6=A6=82=E5=86=B5=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/apps/sdk/dale_extra.py | 363 +++++++++++++++++++++++++++++- api/apps/sdk/session.py | 188 +--------------- api/db/db_models.py | 37 +++ api/db/services/brief_service.py | 31 +++ api/db/services/dialog_service.py | 247 +++++++++++++++++--- conf/llm_factories.json | 5 +- rag/llm/tts_model.py | 6 +- 7 files changed, 653 insertions(+), 224 deletions(-) create mode 100644 api/db/services/brief_service.py diff --git a/api/apps/sdk/dale_extra.py b/api/apps/sdk/dale_extra.py index e6726234..b05119ec 100644 --- a/api/apps/sdk/dale_extra.py +++ b/api/apps/sdk/dale_extra.py @@ -13,17 +13,24 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from flask import request +from flask import request , Response, jsonify from api import settings +from api.db import LLMType from api.db import StatusEnum -from api.db.services.dialog_service import DialogService +from api.db.services.dialog_service import DialogService,stream_manager from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import TenantLLMService from api.db.services.user_service import TenantService +from api.db.services.brief_service import MesumOverviewService +from api.db.services.llm_service import LLMBundle from api.utils import get_uuid from api.utils.api_utils import get_error_data_result, token_required from api.utils.api_utils import get_result - +import logging +import base64 +import queue,time,uuid +from threading import Lock,Thread +from zhipuai import ZhipuAI # 用户已经添加的模型 cyx 2025-01-26 @manager.route('/get_llms', methods=['GET']) @@ -48,4 +55,352 @@ def my_llms(tenant_id): }) return get_result(data=res) except Exception as e: - return get_error_data_result(message=f"Get LLMS error {e}") \ No newline at end of file + return get_error_data_result(message=f"Get LLMS error {e}") + + +main_antiquity="浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链" +@manager.route('/photo/recongeText', methods=['POST']) +@token_required +def upload_file(tenant_id): + if 'file' not in request.files: + return jsonify({'error': 'No file part'}), 400 + + file = request.files['file'] + + if file.filename == '': + return jsonify({'error': 'No selected file'}), 400 + + if file and allowed_file(file.filename): + file_size = request.content_length + img_base = base64.b64encode(file.read()).decode('utf-8') + req_antique = request.form.get('antique',None) + if req_antique is None: + req_antique = main_antiquity + logging.info(f"recevie photo file {file.filename} {file_size} 识别中....") + client = ZhipuAI(api_key="5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3") # 填写您自己的APIKey + response = client.chat.completions.create( + model="glm-4v-plus", # 填写需要调用的模型名称 + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": img_base + } + }, + { + "type": "text", + "text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家," + f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称," + f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍" + f"你的回答有2个结果,第一个结果是是从文字进行分析出匹配文物,候选文物只能如下:{req_antique}," + f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出," + f",第二个结果是原始识别的所有文字" + "2个结果输出以{ }的json格式给出,匹配文物的键值为antique,如果有多个请加序号,如:antique1,antique2," + f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据") + } + ] + } + ] + ) + + message = response.choices[0].message + logging.info(message.content) + return jsonify({'message': 'File uploaded successfully','text':message.content }), 200 + +def allowed_file(filename): + return '.' in filename and \ + filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'} + + +#get_all + +@manager.route('/mesum/list', methods=['GET']) +@token_required +def mesum_list(tenant_id): + # request.args.get("id") 通过request.args.get 获取GET 方法传入的参数 + # model_type = request.args.get("type") + try: + res = [] + overviews=MesumOverviewService.get_all() + for o in overviews: + res.append(o.to_dict()) + return get_result(data=res) + except Exception as e: + return get_error_data_result(message=f"Get LLMS error {e}") + +@manager.route('/mesum/set_antique', methods=['POST']) +@token_required +def mesum_set_antique(tenant_id): + global main_antiquity + # request.args.get("id") 通过request.args.get 获取GET 方法传入的参数 + req_data = request.json + req_data_antique=req_data.get('antique',None) + try: + if req_data_antique: + main_antiquity = req_data_antique + print(main_antiquity) + return get_result({'statusCode':200,'code':0,'message': 'antique set successfully'}) + except Exception as e: + return get_error_data_result(message=f"Get LLMS error {e}") + +audio_text_cache = {} +cache_lock = Lock() +CACHE_EXPIRE_SECONDS = 600 # 10分钟过期 +# 全角字符到半角字符的映射 +def fullwidth_to_halfwidth(s): + full_to_half_map = { + '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'", + '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.', + '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?', + '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`', + '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「', + '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」', + '、': '、', '・': '・', ':': ':' + } + return ''.join(full_to_half_map.get(char, char) for char in s) + +def split_text_at_punctuation(text, chunk_size=100): + # 使用正则表达式找到所有的标点符号和特殊字符 + punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+' + tokens = re.split(punctuation_pattern, text) + + # 移除空字符串 + tokens = [token for token in tokens if token] + + # 存储最终的文本块 + chunks = [] + current_chunk = '' + + for token in tokens: + if len(current_chunk) + len(token) <= chunk_size: + # 如果添加当前token后长度不超过chunk_size,则添加到当前块 + current_chunk += (token + ' ') + else: + # 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块 + chunks.append(current_chunk.strip()) + current_chunk = token + ' ' + + # 添加最后一个块(如果有剩余) + if current_chunk: + chunks.append(current_chunk.strip()) + + return chunks + + +def extract_text_from_markdown(markdown_text): + # 移除Markdown标题 + text = re.sub(r'#\s*[^#]+', '', markdown_text) + # 移除内联代码块 + text = re.sub(r'`[^`]+`', '', text) + # 移除代码块 + text = re.sub(r'```[\s\S]*?```', '', text) + # 移除加粗和斜体 + text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text) + # 移除链接 + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + # 移除图片 + text = re.sub(r'!\[.*?\]\(.*?\)', '', text) + # 移除HTML标签 + text = re.sub(r'<[^>]+>', '', text) + # 转换标点符号 + # text = re.sub(r'[^\w\s]', '', text) + text = fullwidth_to_halfwidth(text) + # 移除多余的空格 + text = re.sub(r'\s+', ' ', text).strip() + + return text + +def clean_audio_cache(): + """定时清理过期缓存""" + with cache_lock: + now = time.time() + expired_keys = [ + k for k, v in audio_text_cache.items() + if now - v['created_at'] > CACHE_EXPIRE_SECONDS + ] + for k in expired_keys: + entry = audio_text_cache.pop(k, None) + if entry and entry.get('audio_stream'): + entry['audio_stream'].close() + + +def start_background_cleaner(): + """启动后台清理线程""" + + def cleaner_loop(): + while True: + time.sleep(180) # 每3分钟清理一次 + clean_audio_cache() + + cleaner_thread = Thread(target=cleaner_loop, daemon=True) + cleaner_thread.start() + +# 应用启动时启动清理线程 +start_background_cleaner() + +@manager.route('/tts_stream/') +def tts_stream(session_id): + def generate(): + retry_count = 0 + session = None + count = 0; + try: + while retry_count < 1: + session = stream_manager.sessions.get(session_id) + if not session or not session['active']: + break + try: + chunk = session['buffer'].get(timeout=5) # 30秒超时 + count = count + 1 + if isinstance(chunk, str) and chunk.startswith("ERROR"): + logging.info("---tts stream error!!!!") + yield f"data:{{'error':'{chunk[6:]}'}}\n\n" + break + yield chunk + retry_count = 0 # 成功收到数据重置重试计数器 + except queue.Empty: + retry_count += 1 + yield b'' # 保持连接 + finally: + # 确保流结束后关闭会话 + if session: + # 延迟关闭会话,确保所有数据已发送 + time.sleep(5) # 等待5秒确保流结束 + stream_manager.close_session(session_id) + logging.info(f"Session {session_id} closed.") + + resp = Response(generate(), mimetype="audio/mpeg") + resp.headers.add_header("Cache-Control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + return resp + +@manager.route('/chats//tts/', methods=['GET']) +def dialog_tts_get(chat_id, audio_stream_id): + with cache_lock: + tts_info = audio_text_cache.pop(audio_stream_id, None) # 取出即删除 + try: + req = tts_info + if not req: + return get_error_data_result(message="Audio stream not found or expired.") + audio_stream = req.get('audio_stream') + tenant_id = req.get('tenant_id') + chat_id = req.get('chat_id') + text = req.get('text', "..") + model_name = req.get('model_name') + dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) + if not dia: + return get_error_data_result(message="You do not own the chat") + tts_model_name = dia.tts_id + if model_name: tts_model_name = model_name + tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id) + + def stream_audio(): + try: + for chunk in tts_mdl.tts(text): + yield chunk + except Exception as e: + yield ("data:" + json.dumps({"code": 500, "message": str(e), + "data": {"answer": "**ERROR**: " + str(e)}}, + ensure_ascii=False)).encode('utf-8') + + def generate(): + data = audio_stream.read(1024) + while data: + yield data + data = audio_stream.read(1024) + + if audio_stream: + # 确保流的位置在开始处 + audio_stream.seek(0) + resp = Response(generate(), mimetype="audio/mpeg") + else: + resp = Response(stream_audio(), mimetype="audio/mpeg") + resp.headers.add_header("Cache-Control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + return resp + except Exception as e: + logging.error(f"音频流传输错误: {str(e)}", exc_info=True) + return get_error_data_result(message="音频流传输失败") + finally: + # 确保资源释放 + if tts_info.get('audio_stream') and not tts_info['audio_stream'].closed: + tts_info['audio_stream'].close() + + +@manager.route('/chats//tts', methods=['POST']) +@token_required +def dialog_tts_post(tenant_id, chat_id): + try: + req = request.json + if not req.get("text"): + return get_error_data_result(message="Please input your question.") + delay_gen_audio = req.get('delay_gen_audio', False) + # text = extract_text_from_markdown(req.get('text')) + text = req.get('text') + model_name = req.get('model_name') + audio_stream_id = req.get('audio_stream_id', None) + if audio_stream_id is None: + audio_stream_id = str(uuid.uuid4()) + # 在这里生成音频流并存储到内存中 + dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) + tts_model_name = dia.tts_id + if model_name: tts_model_name = model_name + tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id) + if delay_gen_audio: + audio_stream = None + else: + audio_stream = io.BytesIO() + # 结构化缓存数据 + tts_info = { + 'text': text, + 'tenant_id': tenant_id, + 'chat_id': chat_id, + 'created_at': time.time(), + 'audio_stream': audio_stream, # 维持原有逻辑 + 'model_name': req.get('model_name'), + 'delay_gen_audio': delay_gen_audio, # 明确存储状态 + audio_stream_id: audio_stream_id + } + + with cache_lock: + audio_text_cache[audio_stream_id] = tts_info + + if delay_gen_audio is False: + try: + """ + for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): + try: + if txt is None or txt.strip() == "": + continue + for chunk in tts_mdl.tts(txt): + audio_stream.write(chunk) + except Exception as e: + continue + """ + if text is None or text.strip() == "": + audio_stream.write(b'\x00' * 100) + else: + # 确保在流的末尾写入 + audio_stream.seek(0, io.SEEK_END) + for chunk in tts_mdl.tts(text): + audio_stream.write(chunk) + except Exception as e: + logging.info(f"--error:{e}") + with cache_lock: + audio_text_cache.pop(audio_stream_id, None) + return get_error_data_result(message="get tts audio stream error.") + + # 构建音频流URL + audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}" + logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}") + # 返回音频流URL + return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id}) + + except Exception as e: + logging.error(f"请求处理失败: {str(e)}", exc_info=True) + return get_error_data_result(message="服务器内部错误") diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 619c10e4..94a83eb4 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -19,22 +19,22 @@ import logging from copy import deepcopy from uuid import uuid4 from api.db import LLMType -from flask import request, Response, jsonify +from flask import request, Response, jsonify, stream_with_context from api.db.services.dialog_service import ask from agent.canvas import Canvas from api.db import StatusEnum from api.db.db_models import API4Conversation from api.db.services.api_service import API4ConversationService from api.db.services.canvas_service import UserCanvasService -from api.db.services.dialog_service import DialogService, ConversationService, chat +from api.db.services.dialog_service import DialogService, ConversationService, chat,stream_manager from api.db.services.knowledgebase_service import KnowledgebaseService from api.utils import get_uuid from api.utils.api_utils import get_error_data_result from api.utils.api_utils import get_result, token_required from api.db.services.llm_service import LLMBundle import uuid -import queue - +import queue,time +from threading import Lock,Thread @manager.route('/chats//sessions', methods=['POST']) @token_required @@ -239,186 +239,6 @@ def completion(tenant_id, chat_id): # chat_id 和 别的文件中的dialog_id break return get_result(data=answer) - -# 全角字符到半角字符的映射 - - -def fullwidth_to_halfwidth(s): - full_to_half_map = { - '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'", - '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.', - '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?', - '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`', - '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「', - '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」', - '、': '、', '・': '・', ':': ':' - } - return ''.join(full_to_half_map.get(char, char) for char in s) - - -def is_dale(s): - full_to_half_map = { - '!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'", - '(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.', - '/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?', - '@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`', - '{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「', - '」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」', - '、': '、', '・': '・', ':': ':', '。': '.' - } - - -def extract_text_from_markdown(markdown_text): - # 移除Markdown标题 - text = re.sub(r'#\s*[^#]+', '', markdown_text) - # 移除内联代码块 - text = re.sub(r'`[^`]+`', '', text) - # 移除代码块 - text = re.sub(r'```[\s\S]*?```', '', text) - # 移除加粗和斜体 - text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text) - # 移除链接 - text = re.sub(r'\[.*?\]\(.*?\)', '', text) - # 移除图片 - text = re.sub(r'!\[.*?\]\(.*?\)', '', text) - # 移除HTML标签 - text = re.sub(r'<[^>]+>', '', text) - # 转换标点符号 - # text = re.sub(r'[^\w\s]', '', text) - text = fullwidth_to_halfwidth(text) - # 移除多余的空格 - text = re.sub(r'\s+', ' ', text).strip() - - return text - - -def split_text_at_punctuation(text, chunk_size=100): - # 使用正则表达式找到所有的标点符号和特殊字符 - punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+' - tokens = re.split(punctuation_pattern, text) - - # 移除空字符串 - tokens = [token for token in tokens if token] - - # 存储最终的文本块 - chunks = [] - current_chunk = '' - - for token in tokens: - if len(current_chunk) + len(token) <= chunk_size: - # 如果添加当前token后长度不超过chunk_size,则添加到当前块 - current_chunk += (token + ' ') - else: - # 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块 - chunks.append(current_chunk.strip()) - current_chunk = token + ' ' - - # 添加最后一个块(如果有剩余) - if current_chunk: - chunks.append(current_chunk.strip()) - - return chunks - -audio_text_cache = {} - -@manager.route('/chats//tts/', methods=['GET']) -def dialog_tts_get(chat_id, audio_stream_id): - tts_info = audio_text_cache.pop(audio_stream_id, None) - req = tts_info - if not req: - return get_error_data_result(message="Audio stream not found or expired.") - audio_stream = req.get('audio_stream') - tenant_id = req.get('tenant_id') - chat_id = req.get('chat_id') - text = req.get('text', "..") - model_name = req.get('model_name') - dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) - if not dia: - return get_error_data_result(message="You do not own the chat") - tts_model_name = dia.tts_id - if model_name: tts_model_name = model_name - tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id) - - def stream_audio(): - try: - for chunk in tts_mdl.tts(text): - yield chunk - except Exception as e: - yield ("data:" + json.dumps({"code": 500, "message": str(e), - "data": {"answer": "**ERROR**: " + str(e)}}, - ensure_ascii=False)).encode('utf-8') - - def generate(): - data = audio_stream.read(1024) - while data: - yield data - data = audio_stream.read(1024) - - if audio_stream: - # 确保流的位置在开始处 - audio_stream.seek(0) - resp = Response(generate(), mimetype="audio/mpeg") - else: - resp = Response(stream_audio(), mimetype="audio/mpeg") - resp.headers.add_header("Cache-Control", "no-cache") - resp.headers.add_header("Connection", "keep-alive") - resp.headers.add_header("X-Accel-Buffering", "no") - return resp - - -@manager.route('/chats//tts', methods=['POST']) -@token_required -def dialog_tts_post(tenant_id, chat_id): - req = request.json - if not req.get("text"): - return get_error_data_result(message="Please input your question.") - delay_gen_audio = req.get('delay_gen_audio', False) - # text = extract_text_from_markdown(req.get('text')) - text = req.get('text') - audio_stream_id = req.get('audio_stream_id') - # logging.info(f"request tts audio url:{text} audio_stream_id:{audio_stream_id} ") - if audio_stream_id is None: - audio_stream_id = str(uuid.uuid4()) - # 在这里生成音频流并存储到内存中 - model_name = req.get('model_name') - dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) - tts_model_name = dia.tts_id - if model_name: tts_model_name = model_name - logging.info(f"---tts {tts_model_name}") - tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id) - if delay_gen_audio: - audio_stream = None - else: - audio_stream = io.BytesIO() - audio_text_cache[audio_stream_id] = {'text': text, 'chat_id': chat_id, "tenant_id": tenant_id, - 'audio_stream': audio_stream,'model_name':model_name} # 缓存文本以便后续生成音频流 - if delay_gen_audio is False: - try: - """ - for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): - try: - if txt is None or txt.strip() == "": - continue - for chunk in tts_mdl.tts(txt): - audio_stream.write(chunk) - except Exception as e: - continue - """ - if text is None or text.strip() == "": - audio_stream.write(b'\x00' * 100) - else: - for chunk in tts_mdl.tts(text): - audio_stream.write(chunk) - except Exception as e: - return get_error_data_result(message="get tts audio stream error.") - - # 构建音频流URL - audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}" - logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}") - # 返回音频流URL - return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id}) - - @manager.route('/agents//completions', methods=['POST']) @token_required def agent_completion(tenant_id, agent_id): diff --git a/api/db/db_models.py b/api/db/db_models.py index f97e1b4e..26d59279 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -988,7 +988,44 @@ class CanvasTemplate(DataBaseModel): class Meta: db_table = "canvas_template" +# ------------added by cyx for mesum overview +class MesumOverview(DataBaseModel): + name = CharField( + max_length=128, + null=False, + help_text="mesum name", + primary_key=False) + longitude = CharField( + max_length=40, + null=True, + help_text="Longitude", + index=False) + + latitude = CharField( + max_length=40, + null=True, + help_text="latitude", + index=False) + + antique=CharField( + max_length=1024, + null=True, + help_text="antique", + index=False) + + brief = CharField( + max_length=1024, + null=True, + help_text="brief", + index=False) + def __str__(self): + return self.name + + class Meta: + db_table = "mesum_overview" + +#------------------------------------------- def migrate_db(): with DB.transaction(): migrator = DatabaseMigrator[settings.DATABASE_TYPE.upper()].value(DB) diff --git a/api/db/services/brief_service.py b/api/db/services/brief_service.py new file mode 100644 index 00000000..507bea58 --- /dev/null +++ b/api/db/services/brief_service.py @@ -0,0 +1,31 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from datetime import datetime + +import peewee +from werkzeug.security import generate_password_hash, check_password_hash + +from api.db import UserTenantRole +from api.db.db_models import DB, UserTenant +from api.db.db_models import User, Tenant, MesumOverview +from api.db.services.common_service import CommonService +from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format +from api.db import StatusEnum + + +class MesumOverviewService(CommonService): + model = MesumOverview + diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index aa518f67..e9f7e974 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -33,37 +33,103 @@ from rag.nlp.search import index_name from rag.utils import rmSpace, num_tokens_from_string, encoder from api.utils.file_utils import get_project_base_directory from peewee import fn -import threading, queue +import threading, queue,uuid,time +from concurrent.futures import ThreadPoolExecutor -# 创建一个 TTS 生成线程 -class TTSWorker(threading.Thread): - def __init__(self, tenant_id, tts_id, tts_text_queue, tts_audio_queue): - super().__init__() - self.tts_mdl = LLMBundle(tenant_id, LLMType.TTS, tts_id) - self.tts_text_queue = tts_text_queue - self.tts_audio_queue = tts_audio_queue - self.daemon = True # 设置为守护线程,主线程退出时,子线程也会自动退出 - def run(self): +class StreamSessionManager: + def __init__(self): + self.sessions = {} # {session_id: {'tts_model': obj, 'buffer': queue, 'task_queue': Queue}} + self.lock = threading.Lock() + self.executor = ThreadPoolExecutor(max_workers=30) # 固定大小线程池 + self.gc_interval = 300 # 5分钟清理一次 + + def create_session(self, tts_model): + session_id = str(uuid.uuid4()) + with self.lock: + self.sessions[session_id] = { + 'tts_model': tts_model, + 'buffer': queue.Queue(maxsize=100), # 线程安全队列 + 'task_queue': queue.Queue(), + 'active': True, + 'last_active': time.time(), + 'audio_chunk_count':0 + } + # 启动任务处理线程 + threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start() + return session_id + + def append_text(self, session_id, text): + with self.lock: + session = self.sessions.get(session_id) + if not session: return + # 将文本放入任务队列(非阻塞) + try: + session['task_queue'].put(text, block=False) + except queue.Full: + logging.warning(f"Session {session_id} task queue full") + + def _process_tasks(self, session_id): + """任务处理线程(每个会话独立)""" while True: - # 从队列中获取数据 - delta_ans = self.tts_text_queue.get() - if delta_ans is None: # 如果队列中没有数据,退出线程 + session = self.sessions.get(session_id) + if not session or not session['active']: break try: - # 调用 TTS 生成音频数据 - tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans) - if tts_input_is_valid: - logging.info(f"--tts threading {delta_ans} {tts_input_is_valid} {sanitized_text}") - bin = b"" - for chunk in self.tts_mdl.tts(sanitized_text): - bin += chunk - # 将生成的音频数据存储到队列中或直接处理 - self.tts_audio_queue.put(bin) - except Exception as e: - logging.error(f"Error generating TTS for text '{delta_ans}': {e}") + # 合并多个文本块(最多等待50ms) + texts = [] + while len(texts) < 5: # 最大合并5个文本块 + try: + text = session['task_queue'].get(timeout=0.05) + texts.append(text) + except queue.Empty: + break + if texts: + # 提交到线程池处理 + future=self.executor.submit( + self._generate_audio, + session_id, + ' '.join(texts) # 合并文本减少请求次数 + ) + future.result() # 等待转换任务执行完毕 + # 会话超时检查 + if time.time() - session['last_active'] > self.gc_interval: + self.close_session(session_id) + break + + except Exception as e: + logging.error(f"Task processing error: {str(e)}") + + def _generate_audio(self, session_id, text): + """实际生成音频(线程池执行)""" + session = self.sessions.get(session_id) + if not session: return + # logging.info(f"_generate_audio:{text}") + try: + for chunk in session['tts_model'].tts(text): + session['buffer'].put(chunk) + session['last_active'] = time.time() + session['audio_chunk_count'] = session['audio_chunk_count'] + 1 + logging.info(f"转换结束!!! {session['audio_chunk_count'] }") + except Exception as e: + session['buffer'].put(f"ERROR:{str(e)}") + + def close_session(self, session_id): + with self.lock: + if session_id in self.sessions: + # 标记会话为不活跃 + self.sessions[session_id]['active'] = False + # 延迟30秒后清理资源 + threading.Timer(10, self._clean_session, args=[session_id]).start() + + def _clean_session(self, session_id): + with self.lock: + if session_id in self.sessions: + del self.sessions[session_id] + +stream_manager = StreamSessionManager() class DialogService(CommonService): model = Dialog @@ -235,6 +301,73 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000): # 如果通过所有检查,返回有效标志和修正后的文本 return True, delta_ans +def _should_flush(text_chunk,chunk_buffer,last_flush_time): + """智能判断是否需要立即生成音频""" + # 规则1:遇到句子结束标点 + if re.search(r'[。!?,]$', text_chunk): + return True + + if re.search(r'(\d{4})(年|月|日|,)', text_chunk): + return False # 不刷新,继续合并 + # 规则2:达到最大缓冲长度(200字符) + if sum(len(c) for c in chunk_buffer) >= 200: + return True + # 规则3:超过500ms未刷新 + if time.time() - last_flush_time > 0.5: + return True + return False + + +MAX_BUFFER_LEN = 200 # 最大缓冲长度 +FLUSH_TIMEOUT = 0.5 # 强制刷新时间(秒) + +# 智能查找文本最佳分割点(标点/语义单位/短语边界) +def find_split_position(text): + """智能查找最佳分割位置""" + # 优先查找句子结束符 + sentence_end = list(re.finditer(r'[。!?]', text)) + if sentence_end: + return sentence_end[-1].end() + + # 其次查找自然停顿符 + pause_mark = list(re.finditer(r'[,;、]', text)) + if pause_mark: + return pause_mark[-1].end() + + # 防止截断日期/数字短语 + date_pattern = re.search(r'\d+(年|月|日)(?!\d)', text) + if date_pattern: + return date_pattern.end() + + # 避免拆分常见短语 + for phrase in ["青少年", "博物馆", "参观"]: + idx = text.rfind(phrase) + if idx != -1 and idx + len(phrase) <= len(text): + return idx + len(phrase) + + return None + +# 管理文本缓冲区,根据语义规则动态分割并返回待处理内容,分割出语义完整的部分 +def process_buffer(chunk_buffer, force_flush=False): + """处理文本缓冲区,返回待发送文本和剩余缓冲区""" + current_text = "".join(chunk_buffer) + if not current_text: + return "", [] + + split_pos = find_split_position(current_text) + + # 强制刷新逻辑 + if force_flush or len(current_text) >= MAX_BUFFER_LEN: + # 即使强制刷新也要尽量找合适的分割点 + if split_pos is None or split_pos < len(current_text) // 2: + split_pos = max(split_pos or 0, MAX_BUFFER_LEN) + split_pos = min(split_pos, len(current_text)) + + if split_pos is not None and split_pos > 0: + return current_text[:split_pos], [current_text[split_pos:]] + + return "", chunk_buffer + def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." st = timer() @@ -283,7 +416,10 @@ def chat(dialog, messages, stream=True, **kwargs): tts_mdl = None if prompt_config.get("tts"): - tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,dialog.tts_id) + if kwargs.get('tts_model'): + tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS,kwargs.get('tts_model')) + else: + tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id) # try to use sql if field mapping is good to go if field_map: @@ -388,34 +524,83 @@ def chat(dialog, messages, stream=True, **kwargs): if stream: last_ans = "" answer = "" + # 创建TTS会话(提前初始化) + tts_session_id = stream_manager.create_session(tts_mdl) + audio_url = f"/tts_stream/{tts_session_id}" + first_chunk = True + chunk_buffer = [] # 新增文本缓冲 + last_flush_time = time.time() # 初始化时间戳 + for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf): answer = ans delta_ans = ans[len(last_ans):] - if num_tokens_from_string(delta_ans) < 16: + if num_tokens_from_string(delta_ans) < 24: continue + last_ans = answer # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} # cyx 2024 12 04 修正delta_ans 为空 ,调用tts 出错 tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans) - #if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary - tts_input_is_valid = False + # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary + if kwargs.get('tts_disable'): + tts_input_is_valid =False + if tts_input_is_valid: + # 缓冲文本直到遇到标点 + chunk_buffer.append(sanitized_text) + # 处理缓冲区内容 + while True: + # 判断是否需要强制刷新 + force = time.time() - last_flush_time > FLUSH_TIMEOUT + to_send, remaining = process_buffer(chunk_buffer, force_flush=force) + + if not to_send: + break + + # 发送有效内容 + stream_manager.append_text(tts_session_id, to_send) + chunk_buffer = remaining + last_flush_time = time.time() + """ if tts_input_is_valid: yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)} else: yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}} + """ + + # 首块返回音频URL + if first_chunk: + yield { + "answer": answer, + "delta_ans": sanitized_text, + "audio_stream_url": audio_url, + "session_id": tts_session_id, + "reference": {} + } + first_chunk = False + else: + yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}} delta_ans = answer[len(last_ans):] if delta_ans: + # stream_manager.append_text(tts_session_id, delta_ans) # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} - # cyx 2024 12 04 修正delta_ans 为空调用tts 出错 + # cyx 2024 12 04 修正delta_ans 为空 调用tts 出错 tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans) - #if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary - tts_input_is_valid = False + if kwargs.get('tts_disable'): # cyx 2025 01 18 前端传入tts_disable 参数,就不生成tts 音频给前端,即:没有audio_binary + tts_input_is_valid = False + if tts_input_is_valid: + # 20250221 修改,在后端生成音频数据 + chunk_buffer.append(sanitized_text) + stream_manager.append_text(tts_session_id, ''.join(chunk_buffer)) + yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}} + """ if tts_input_is_valid: yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)} else: yield {"answer": answer, "delta_ans": sanitized_text,"reference": {}} + """ + yield decorate_answer(answer) else: diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 866bde03..c98c1388 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -401,10 +401,11 @@ "max_tokens": 32768, "model_type": "chat" }, + { - "llm_name": "deepseek-coder", + "llm_name": "deepseek-reasoner", "tags": "LLM,CHAT,", - "max_tokens": 16385, + "max_tokens": 65535, "model_type": "chat" } ] diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 163aeaa3..d9ac36d3 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -143,7 +143,7 @@ class QwenTTS(Base): from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat #, SpeechSynthesisResult from dashscope.audio.tts import SpeechSynthesisResult from collections import deque - # print(f"--QwenTTS--tts_stream begin-- {text}") # cyx + print(f"--QwenTTS--tts_stream begin-- {text}") # cyx class Callback(ResultCallback): def __init__(self) -> None: self.dque = deque() @@ -206,7 +206,7 @@ class QwenTTS(Base): raise RuntimeError(str(response)) def on_close(self): - print("---Qwen call back close") # cyx + # print("---Qwen call back close") # cyx pass """ canceled for test 语音大模型CosyVoice def on_event(self, result: SpeechSynthesisResult): @@ -252,7 +252,7 @@ class QwenTTS(Base): try: for data in self.callback._run(): yield data - print(f"---Qwen return data {num_tokens_from_string(text)}") + # print(f"---Qwen return data {num_tokens_from_string(text)}") yield num_tokens_from_string(text) except Exception as e: