#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""HTTP routes for LLM listing, photo text recognition, museum data and TTS.

NOTE(review): ``manager`` is not defined in this file; it is presumably the
Flask blueprint injected by the application loader -- confirm.
"""
import base64
import io
import json
import logging
import os
import queue
import re
import time
import uuid
from threading import Lock, Thread

from flask import request, Response, jsonify
from api import settings
from api.db import LLMType
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, stream_manager
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.db.services.brief_service import MesumOverviewService
from api.db.services.llm_service import LLMBundle
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
from zhipuai import ZhipuAI

# SECURITY(review): this key was hard-coded in the request handler. It can now
# be overridden through the environment; the literal fallback only keeps
# existing deployments working and should be rotated and removed.
ZHIPUAI_API_KEY = os.environ.get(
    "ZHIPUAI_API_KEY", "5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3"
)

# File extensions accepted by the photo recognition endpoint.
ALLOWED_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif'})


# Models the tenant has already added.
@manager.route('/get_llms', methods=['GET'])
@token_required
def my_llms(tenant_id):
    """Return the tenant's configured models grouped by factory.

    The optional ``type`` query parameter restricts the result to models whose
    ``model_type`` equals it.
    """
    model_type = request.args.get("type")
    try:
        res = {}
        for o in TenantLLMService.get_my_llms(tenant_id):
            if model_type is not None and o["model_type"] != model_type:
                continue
            factory = res.setdefault(
                o["llm_factory"], {"tags": o["tags"], "llm": []}
            )
            factory["llm"].append({
                "type": o["model_type"],
                "name": o["llm_name"],
                "used_token": o["used_tokens"],
            })
        return get_result(data=res)
    except Exception as e:
        return get_error_data_result(message=f"Get LLMS error {e}")


# Default comma-separated candidate artefact names used by the recognition
# prompt; replaced at runtime by ``mesum_set_antique``.
main_antiquity = "浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"


@manager.route('/photo/recongeText', methods=['POST'])
@token_required
def upload_file(tenant_id):
    """Recognise text in an uploaded image and match it to candidate artefacts.

    Expects a multipart ``file`` part and an optional ``antique`` form field
    with the candidate names (defaults to ``main_antiquity``). Returns the
    model's raw answer in the ``text`` field of the JSON response.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not allowed_file(file.filename):
        # Previously this case fell through and returned None (a 500 in Flask).
        return jsonify({'error': 'File type not allowed'}), 400

    file_size = request.content_length
    img_base = base64.b64encode(file.read()).decode('utf-8')
    req_antique = request.form.get('antique', None)
    if req_antique is None:
        req_antique = main_antiquity
    logging.info(f"recevie photo file {file.filename} {file_size} 识别中....")

    try:
        client = ZhipuAI(api_key=ZHIPUAI_API_KEY)
        response = client.chat.completions.create(
            model="glm-4v-plus",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": img_base
                            }
                        },
                        {
                            "type": "text",
                            "text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
                                     f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
                                     f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
                                     f"你的回答有2个结果,第一个结果是是从文字进行分析出匹配文物,候选文物只能如下:{req_antique},"
                                     f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
                                     f",第二个结果是原始识别的所有文字"
                                     "2个结果输出以{ }的json格式给出,匹配文物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
                                     f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")
                        }
                    ]
                }
            ]
        )
        message = response.choices[0].message
    except Exception as e:
        logging.exception("photo text recognition failed")
        return jsonify({'error': f'Recognition failed: {e}'}), 500

    logging.info(message.content)
    return jsonify({'message': 'File uploaded successfully', 'text': message.content}), 200


def allowed_file(filename):
    """Return True when ``filename`` has one of the accepted image extensions."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS


@manager.route('/mesum/list', methods=['GET'])
@token_required
def mesum_list(tenant_id):
    """Return every museum overview record as a list of dicts."""
    try:
        res = [o.to_dict() for o in MesumOverviewService.get_all()]
        return get_result(data=res)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum list error {e}")


@manager.route('/mesum/set_antique', methods=['POST'])
@token_required
def mesum_set_antique(tenant_id):
    """Replace the module-wide default candidate artefact list.

    Reads ``antique`` from the JSON body; an empty or missing value leaves the
    current default untouched.
    """
    global main_antiquity
    try:
        # An empty JSON body yields None; treat it like an empty object.
        req_data = request.json or {}
        req_data_antique = req_data.get('antique', None)
        if req_data_antique:
            main_antiquity = req_data_antique
        logging.info(main_antiquity)
        return get_result({'statusCode': 200, 'code': 0, 'message': 'antique set successfully'})
    except Exception as e:
        return get_error_data_result(message=f"Set antique error {e}")


# audio_stream_id -> pending TTS request info; guarded by ``cache_lock``.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes


def _build_fullwidth_table():
    """Build the code point translation table for ``fullwidth_to_halfwidth``."""
    # Fullwidth ASCII punctuation sits exactly 0xFEE0 above its ASCII twin.
    table = {
        ord(ch) + 0xFEE0: ch
        for ch in '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'
    }
    table.update({
        0xFF5F: '\u2985',  # fullwidth left white parenthesis
        0xFF60: '\u2986',  # fullwidth right white parenthesis
        0xFF61: '.',       # halfwidth ideographic full stop
        0xFF62: '\u300C',  # halfwidth left corner bracket
        0xFF63: '\u300D',  # halfwidth right corner bracket
        0xFF64: ',',       # halfwidth ideographic comma
        0xFF65: '.',       # halfwidth katakana middle dot
        0xFF70: '-',       # halfwidth prolonged sound mark
        # NOTE(review): the table in the source had been character-normalised
        # so that keys equalled values; the two entries below preserve the
        # mappings it still showed literally -- confirm they are wanted.
        0x3002: '.',       # ideographic full stop
        0x30FC: '-',       # katakana prolonged sound mark
    })
    return table


_FULLWIDTH_TABLE = _build_fullwidth_table()

# Separators used to tokenise text before chunking.
_PUNCTUATION_SPLIT_RE = re.compile(r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+')


def fullwidth_to_halfwidth(s):
    """Map fullwidth/halfwidth-form punctuation in ``s`` to plain equivalents.

    Characters without an entry in the table are returned unchanged.
    """
    return s.translate(_FULLWIDTH_TABLE)


def split_text_at_punctuation(text, chunk_size=100):
    """Split ``text`` on punctuation/whitespace and regroup it into chunks.

    Tokens are joined with single spaces; a new chunk is started when adding
    the next token would push the running chunk (including its trailing
    space) past ``chunk_size``. A single token longer than ``chunk_size``
    becomes a chunk of its own. Returns a list of non-empty strings.
    """
    tokens = [token for token in _PUNCTUATION_SPLIT_RE.split(text) if token]

    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += token + ' '
        else:
            # Guard against emitting an empty chunk when the very first token
            # already exceeds chunk_size.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks


def extract_text_from_markdown(markdown_text):
    """Strip Markdown constructs from ``markdown_text`` and return plain text.

    Removes fenced and inline code, heading lines, emphasised spans, images,
    links and HTML tags, normalises punctuation through
    ``fullwidth_to_halfwidth`` and collapses whitespace runs to single spaces.
    """
    # Fenced blocks must go before inline code, otherwise the inline pattern
    # eats the fence delimiters and the fenced pattern no longer matches.
    text = re.sub(r'```[\s\S]*?```', '', markdown_text)
    text = re.sub(r'`[^`]+`', '', text)
    # Only the heading line itself is removed; the previous pattern ran across
    # newlines and deleted everything up to the next '#'.
    text = re.sub(r'^[ \t]{0,3}#{1,6}[^\n]*$', '', text, flags=re.MULTILINE)
    # Emphasised spans are dropped together with their content.
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Images before links, so no stray '!' is left behind.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'<[^>]+>', '', text)
    text = fullwidth_to_halfwidth(text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def clean_audio_cache():
    """Drop cache entries older than ``CACHE_EXPIRE_SECONDS`` and close their streams."""
    with cache_lock:
        now = time.time()
        expired_keys = [
            k for k, v in audio_text_cache.items()
            if now - v['created_at'] > CACHE_EXPIRE_SECONDS
        ]
        for k in expired_keys:
            entry = audio_text_cache.pop(k, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()


def start_background_cleaner():
    """Start a daemon thread that purges the audio cache every three minutes."""
    def cleaner_loop():
        while True:
            time.sleep(180)
            clean_audio_cache()

    cleaner_thread = Thread(target=cleaner_loop, daemon=True)
    cleaner_thread.start()


# Started at import time so the cache is always being purged.
start_background_cleaner()


def _audio_response(body):
    """Wrap ``body`` in an ``audio/mpeg`` response with streaming-friendly headers."""
    resp = Response(body, mimetype="audio/mpeg")
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp


@manager.route('/tts_stream/<session_id>')
def tts_stream(session_id):
    """Stream audio chunks buffered for ``session_id`` as ``audio/mpeg``.

    Streaming stops when the session is missing or inactive, when an
    ``ERROR``-prefixed string is received, or after one 5-second wait on the
    buffer yields nothing. The session is closed once the stream ends.
    """
    def generate():
        retry_count = 0
        session = None
        try:
            while retry_count < 1:
                session = stream_manager.sessions.get(session_id)
                if not session or not session['active']:
                    break
                try:
                    chunk = session['buffer'].get(timeout=5)
                except queue.Empty:
                    retry_count += 1
                    yield b''  # keep the connection alive
                    continue
                if isinstance(chunk, str) and chunk.startswith("ERROR"):
                    logging.info("---tts stream error!!!!")
                    yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
                    break
                yield chunk
                retry_count = 0  # data arrived: reset the retry counter
        finally:
            if session:
                # Delay the close so that all buffered data has been sent.
                time.sleep(5)
                stream_manager.close_session(session_id)
                logging.info(f"Session {session_id} closed.")

    return _audio_response(generate())


@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
    """Serve the audio registered under ``audio_stream_id``.

    The cache entry is consumed on first access. If audio was pre-generated
    it is streamed from memory; otherwise the cached text is synthesised on
    the fly. The chat and tenant stored in the cache entry are used, not the
    ``chat_id`` from the URL.
    """
    with cache_lock:
        tts_info = audio_text_cache.pop(audio_stream_id, None)  # one-shot read
    if not tts_info:
        return get_error_data_result(message="Audio stream not found or expired.")

    audio_stream = tts_info.get('audio_stream')
    # True once the response generator has taken over closing the buffer.
    handed_off = False
    try:
        tenant_id = tts_info.get('tenant_id')
        chat_id = tts_info.get('chat_id')
        text = tts_info.get('text', "..")
        model_name = tts_info.get('model_name')

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name if model_name else dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

        def stream_audio():
            try:
                for chunk in tts_mdl.tts(text):
                    yield chunk
            except Exception as e:
                yield ("data:" + json.dumps({"code": 500, "message": str(e),
                                             "data": {"answer": "**ERROR**: " + str(e)}},
                                            ensure_ascii=False)).encode('utf-8')

        def generate():
            # The response body is produced lazily, after this view returns,
            # so the buffer has to be closed here and not by the view.
            try:
                for data in iter(lambda: audio_stream.read(1024), b''):
                    yield data
            finally:
                audio_stream.close()

        if audio_stream is not None:
            audio_stream.seek(0)  # make sure reading starts at the beginning
            resp = _audio_response(generate())
            handed_off = True
        else:
            resp = _audio_response(stream_audio())
        return resp
    except Exception as e:
        logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
        return get_error_data_result(message="音频流传输失败")
    finally:
        # Release the buffer on every path that did not hand it to generate().
        if not handed_off and audio_stream is not None and not audio_stream.closed:
            audio_stream.close()


@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
    """Register a TTS request and return the URL from which to fetch the audio.

    JSON body fields: ``text`` (required), ``model_name`` (optional override
    of the chat's TTS model), ``audio_stream_id`` (optional, generated when
    absent) and ``delay_gen_audio`` (when truthy, synthesis is deferred to
    the GET request instead of being done here).
    """
    try:
        req = request.json or {}
        if not req.get("text"):
            return get_error_data_result(message="Please input your question.")
        delay_gen_audio = req.get('delay_gen_audio', False)
        text = req.get('text')
        model_name = req.get('model_name')
        audio_stream_id = req.get('audio_stream_id', None)
        if audio_stream_id is None:
            audio_stream_id = str(uuid.uuid4())

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name if model_name else dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

        audio_stream = None if delay_gen_audio else io.BytesIO()

        tts_info = {
            'text': text,
            'tenant_id': tenant_id,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': audio_stream,
            'model_name': model_name,
            'delay_gen_audio': delay_gen_audio,
            'audio_stream_id': audio_stream_id,
        }
        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info

        if audio_stream is not None:
            try:
                if text is None or text.strip() == "":
                    audio_stream.write(b'\x00' * 100)
                else:
                    # Make sure writes are appended at the end of the buffer.
                    audio_stream.seek(0, io.SEEK_END)
                    for chunk in tts_mdl.tts(text):
                        audio_stream.write(chunk)
            except Exception as e:
                logging.info(f"--error:{e}")
                with cache_lock:
                    audio_text_cache.pop(audio_stream_id, None)
                audio_stream.close()
                return get_error_data_result(message="get tts audio stream error.")

        audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
        logging.info(f"--return request tts audio url  {audio_stream_id} {audio_stream_url}")
        return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}", exc_info=True)
        return get_error_data_result(message="服务器内部错误")