#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
"""SDK HTTP routes: tenant LLM listing, photo text recognition, museum
("mesum") antique lookups and text-to-speech streaming.

NOTE(review): ``manager`` is neither defined nor imported in this file; it is
presumably injected by the application loader before the module body runs.
"""
import array
import base64
import gzip
import io
import json
import logging
import os
import queue
import re
import time
import uuid
from io import BytesIO
from threading import Lock, Thread

from flask import request, Response, jsonify, stream_with_context
from zhipuai import ZhipuAI

from api import settings
from api.db import LLMType
from api.db import StatusEnum
from api.db.services.dialog_service import DialogService, stream_manager
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import TenantLLMService
from api.db.services.user_service import TenantService
from api.db.services.brief_service import MesumOverviewService
from api.db.services.llm_service import LLMBundle
from api.db.services.antique_service import MesumAntiqueService
from api.utils import get_uuid
from api.utils.api_utils import get_error_data_result, token_required
from api.utils.api_utils import get_result
from api.utils.file_utils import get_project_base_directory


# Models the tenant has already added.  cyx 2025-01-26
@manager.route('/get_llms', methods=['GET'])
@token_required
def my_llms(tenant_id):
    """Return the tenant's configured models grouped by LLM factory.

    The optional ``type`` query argument restricts the listing to models
    whose ``model_type`` equals it.
    """
    model_type = request.args.get("type")
    try:
        res = {}
        for o in TenantLLMService.get_my_llms(tenant_id):
            # Optional filter by model type.
            if model_type is not None and o["model_type"] != model_type:
                continue
            factory = res.setdefault(
                o["llm_factory"], {"tags": o["tags"], "llm": []}
            )
            factory["llm"].append({
                "type": o["model_type"],
                "name": o["llm_name"],
                "used_token": o["used_tokens"],
            })
        return get_result(data=res)
    except Exception as e:
        return get_error_data_result(message=f"Get LLMS error {e}")


# Process-wide default antique list; replaced by /mesum/set_antique.
main_antiquity = "浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"

# File extensions accepted by the photo recognition endpoint.
ALLOWED_IMAGE_EXTENSIONS = frozenset({'png', 'jpg', 'jpeg', 'gif'})

# SECURITY: the original code embedded this key directly in the request
# handler.  It can now be overridden through the ZHIPUAI_API_KEY environment
# variable; the literal fallback is kept only for backward compatibility and
# should be rotated and removed.
_ZHIPUAI_API_KEY = os.environ.get(
    "ZHIPUAI_API_KEY", "5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3"
)


# NOTE(review): the route rule in the source was '/photo/recongeText/' with
# no URL variable, so Flask could not supply ``mesum_id``.  The variable
# segment is restored as a plain string; confirm the intended converter.
@manager.route('/photo/recongeText/<mesum_id>', methods=['POST'])
@token_required
def upload_file(tenant_id, mesum_id):
    """Recognise the text in an uploaded photo with the GLM-4V vision model.

    Expects a multipart ``file`` part.  When ``mesum_id`` is truthy the
    prompt asks the model to choose its match from the known antique labels.

    Returns a JSON body with ``message`` and ``text`` (the raw model answer)
    on success, or ``{'error': ...}`` with status 400/500 on failure.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not allowed_file(file.filename):
        # Previously this path fell off the end of the view and returned
        # None, which Flask rejects as an invalid response.
        return jsonify({'error': 'File type not allowed'}), 400

    antiques_selected = ""
    if mesum_id:
        joined_string = ','.join(
            str(label) for label in get_antique_labels(mesum_id)
        )
        antiques_selected = f"结果从:{joined_string} 中进行选择"
        logging.info(f"{mesum_id} {joined_string}")

    prompt = (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
              f"请识别这个图片中文字,重点识别出含在文字中的某一文物标题、某一个历史事件或某一历史人物,"
              f"你的回答有2个结果,第一个结果是是从文字中识别出历史文物、历史事件、历史人物,"
              f"此回答时只给出匹配的文物、事件、人物,不需要其他多余的文字,{antiques_selected}"
              f",第二个结果是原始识别的所有文字"
              "2个结果输出以{ }的json格式给出,匹配文物、事件、人物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
              f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")

    file_size = request.content_length
    img_base = base64.b64encode(file.read()).decode('utf-8')
    logging.info(f"recevie photo file {file.filename} {file_size} 识别中....")

    try:
        client = ZhipuAI(api_key=_ZHIPUAI_API_KEY)
        response = client.chat.completions.create(
            model="glm-4v-plus",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": img_base},
                        },
                        {
                            "type": "text",
                            "text": prompt,
                        },
                    ],
                }
            ],
        )
        message = response.choices[0].message
    except Exception as e:
        # The remote call can fail for reasons outside our control; report
        # it as JSON instead of an unhandled 500 page.
        logging.exception("photo text recognition failed")
        return jsonify({'error': f'Recognition failed: {e}'}), 500

    logging.info(message.content)
    return jsonify({'message': 'File uploaded successfully',
                    'text': message.content}), 200


def allowed_file(filename):
    """Return True when ``filename`` has an accepted image extension."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS


# get_all
@manager.route('/mesum/list', methods=['GET'])
@token_required
def mesum_list(tenant_id):
    """Return every museum overview record as a list of dicts."""
    try:
        res = [o.to_dict() for o in MesumOverviewService.get_all()]
        return get_result(data=res)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum list error {e}")


@manager.route('/mesum/set_antique', methods=['POST'])
@token_required
def mesum_set_antique(tenant_id):
    """Replace the process-wide default antique list (``main_antiquity``).

    Reads ``antique`` from the JSON body; a missing or empty value leaves the
    current list unchanged.
    """
    global main_antiquity
    try:
        # Parsed inside the try so a missing/non-object body yields the
        # standard error payload rather than an unhandled exception.
        req_data = request.json or {}
        req_data_antique = req_data.get('antique', None)
        if req_data_antique:
            main_antiquity = req_data_antique
        logging.info(main_antiquity)
        # NOTE(review): the payload is passed positionally, exactly as in the
        # original; confirm that get_result's first parameter is the data.
        return get_result({'statusCode': 200, 'code': 0,
                           'message': 'antique set successfully'})
    except Exception as e:
        return get_error_data_result(message=f"Set antique error {e}")


# Pending TTS requests keyed by audio_stream_id; guarded by cache_lock.
audio_text_cache = {}
cache_lock = Lock()
CACHE_EXPIRE_SECONDS = 600  # entries expire after 10 minutes

# Full-width ASCII variants (U+FF01..U+FF5E) sit exactly 0xFEE0 above their
# ASCII counterparts.
_FULLWIDTH_OFFSET = 0xFEE0
_ASCII_PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"

# NOTE(review): the dict literal in the source was corrupted (its keys had
# been normalised to ASCII, leaving entries such as ''': "'" that do not
# parse).  The table is rebuilt from code points for the punctuation that was
# listed.  Entries that appeared as identity mappings are omitted because
# they are no-ops; confirm against the upstream file.
_FULLWIDTH_TO_HALFWIDTH = str.maketrans({
    **{chr(ord(c) + _FULLWIDTH_OFFSET): c for c in _ASCII_PUNCTUATION},
    '\u30fc': '-',  # KATAKANA-HIRAGANA PROLONGED SOUND MARK
    '\u3002': '.',  # IDEOGRAPHIC FULL STOP
})


# Full-width to half-width character mapping.
def fullwidth_to_halfwidth(s: str) -> str:
    """Replace full-width punctuation in ``s`` with its ASCII equivalent.

    Only punctuation is converted; full-width digits and letters are left
    untouched, as in the original table.
    """
    return s.translate(_FULLWIDTH_TO_HALFWIDTH)


# Runs of whitespace / punctuation that separate tokens.
_PUNCTUATION_SPLIT_RE = re.compile(r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+')


def split_text_at_punctuation(text, chunk_size=100):
    """Split ``text`` into space-joined chunks of roughly ``chunk_size`` chars.

    The text is tokenised on whitespace and punctuation, and tokens are
    packed greedily into chunks.  A single token longer than ``chunk_size``
    becomes its own (oversized) chunk.
    """
    tokens = [token for token in _PUNCTUATION_SPLIT_RE.split(text) if token]

    chunks = []
    current_chunk = ''
    for token in tokens:
        if len(current_chunk) + len(token) <= chunk_size:
            current_chunk += token + ' '
        else:
            # Skip the flush while the chunk is still empty; otherwise an
            # oversized first token produced a leading '' chunk.
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = token + ' '

    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks


def extract_text_from_markdown(markdown_text):
    """Strip Markdown markup from ``markdown_text`` and return plain text.

    Heading lines, code (fenced and inline), emphasised spans, images, links
    and HTML tags are removed; punctuation is converted to half-width and
    whitespace is collapsed.
    """
    # Remove heading lines only.  The previous pattern ('#\s*[^#]+') consumed
    # everything up to the next '#', including the body text after a heading.
    text = re.sub(r'^[ \t]*#{1,6}[^\n]*', '', markdown_text,
                  flags=re.MULTILINE)
    # Fenced code blocks must go before inline code, or the inline pattern
    # eats part of the fences and leaves stray backticks behind.
    text = re.sub(r'```[\s\S]*?```', '', text)
    text = re.sub(r'`[^`]+`', '', text)
    # Bold / italic spans (removed together with their content, as before).
    text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
    # Images before links, otherwise '![alt](url)' leaves a stray '!'.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    text = re.sub(r'\[.*?\]\(.*?\)', '', text)
    # HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Normalise punctuation.
    text = fullwidth_to_halfwidth(text)
    # Collapse runs of whitespace.
    text = re.sub(r'\s+', ' ', text).strip()
    return text


def encode_gzip_base64(original_data: bytes) -> str:
    """Gzip-compress ``original_data`` and return it Base64-encoded.

    The Base64 output contains no line breaks.
    """
    with BytesIO() as buf:
        # The GzipFile must be closed before reading the buffer so that the
        # gzip trailer has been written.
        with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file:
            gz_file.write(original_data)
        compressed_bytes = buf.getvalue()
    return base64.b64encode(compressed_bytes).decode('utf-8')


def clean_audio_cache():
    """Drop cache entries older than CACHE_EXPIRE_SECONDS, closing streams."""
    with cache_lock:
        now = time.time()
        expired_keys = [
            k for k, v in audio_text_cache.items()
            if now - v['created_at'] > CACHE_EXPIRE_SECONDS
        ]
        for k in expired_keys:
            entry = audio_text_cache.pop(k, None)
            if entry and entry.get('audio_stream'):
                entry['audio_stream'].close()


def start_background_cleaner():
    """Start a daemon thread that purges the audio cache every 3 minutes."""
    def cleaner_loop():
        while True:
            time.sleep(180)
            clean_audio_cache()

    cleaner_thread = Thread(target=cleaner_loop, daemon=True)
    cleaner_thread.start()


# Start the cleaner thread when the module is loaded.
start_background_cleaner()

# How long tts_stream blocks on an empty buffer before re-checking state.
_TTS_POLL_INTERVAL_SECONDS = 0.05


# NOTE(review): the rule in the source was '/tts_stream/' with no URL
# variable; the segment feeding ``session_id`` is restored.
@manager.route('/tts_stream/<session_id>', methods=['GET'])
def tts_stream(session_id):
    """Stream the audio chunks queued for an active TTS session.

    For ``wav`` sessions each chunk is sent as a gzip+Base64 text line ending
    in CRLF; otherwise the raw chunk is sent.  The session is closed once the
    stream ends.
    """
    session = stream_manager.sessions.get(session_id)
    if not session:
        # Previously a missing session raised TypeError on session['finished'].
        return get_error_data_result(message="TTS session not found.")

    stream_format = session['stream_format']

    def generate():
        finished_event = session['finished']
        try:
            # NOTE(review): chunks still queued when 'finished' is set are
            # not sent; confirm the producer drains before signalling.
            while not finished_event.is_set():
                if not session['active']:
                    break
                try:
                    # A short blocking wait replaces get_nowait(), which made
                    # this loop spin at full CPU while the buffer was empty.
                    chunk = session['buffer'].get(
                        timeout=_TTS_POLL_INTERVAL_SECONDS
                    )
                except queue.Empty:
                    if stream_format != "wav":
                        yield b''  # keep the connection alive
                    continue
                if isinstance(chunk, str) and chunk.startswith("ERROR"):
                    logging.info(f"---tts stream error!!!! {chunk}")
                    yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
                    break
                if stream_format == "wav":
                    yield encode_gzip_base64(chunk) + "\r\n"
                else:
                    yield chunk
        except Exception as e:
            logging.info(f"tts streag get error2 {e} ")
        finally:
            # Always release the session once streaming stops.
            stream_manager.close_session(session_id)
            logging.info(f"Session {session_id} closed.")

    # The two mimetypes were swapped in the original; they now agree with
    # dialog_tts_get (wav -> audio/wav, everything else -> audio/mpeg).
    mimetype = "audio/wav" if stream_format == "wav" else "audio/mpeg"
    resp = Response(stream_with_context(generate()), mimetype=mimetype)
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp


def _close_stream(stream):
    """Close ``stream`` if it exists and is still open."""
    if stream is not None and not stream.closed:
        stream.close()


# NOTE(review): the rule in the source was '/chats//tts/'; the two variable
# segments matching the view's parameters are restored.
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
def dialog_tts_get(chat_id, audio_stream_id):
    """Serve the audio registered under ``audio_stream_id``.

    The cache entry is consumed on first use.  When it carries pre-generated
    audio that buffer is streamed; otherwise speech is synthesised on the
    fly.  Note that the chat id stored in the cache entry, not the one in the
    URL, selects the dialog.
    """
    with cache_lock:
        tts_info = audio_text_cache.pop(audio_stream_id, None)  # single use
    if not tts_info:
        return get_error_data_result(message="Audio stream not found or expired.")

    audio_stream = tts_info.get('audio_stream')
    # Once the response generator owns the buffer it must not be closed here.
    handed_off = False
    try:
        tenant_id = tts_info.get('tenant_id')
        chat_id = tts_info.get('chat_id')
        text = tts_info.get('text', "..")
        model_name = tts_info.get('model_name')
        sample_rate = tts_info.get('tts_sample_rate', 8000)  # default 8 kHz
        stream_format = tts_info.get('tts_stream_format', 'mp3')

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id,
                                status=StatusEnum.VALID.value)
        if not dia:
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name if model_name else dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)
        logging.info(f"dialog_tts_get {sample_rate} {stream_format}")

        def stream_audio():
            try:
                for chunk in tts_mdl.tts(text, sample_rate=sample_rate,
                                         stream_format=stream_format):
                    if stream_format == 'wav':
                        yield encode_gzip_base64(chunk) + "\r\n"
                    else:
                        yield chunk
            except Exception as e:
                yield ("data:" + json.dumps(
                    {"code": 500, "message": str(e),
                     "data": {"answer": "**ERROR**: " + str(e)}},
                    ensure_ascii=False)).encode('utf-8')

        def generate():
            # The buffer is closed here, after the body has been sent.  The
            # original closed it in the view's ``finally``, i.e. before the
            # server consumed this generator, so reads hit a closed file.
            try:
                data = audio_stream.read(1024)
                while data:
                    yield data
                    data = audio_stream.read(1024)
            finally:
                audio_stream.close()

        if audio_stream:
            audio_stream.seek(0)  # rewind before streaming
            resp = Response(generate(), mimetype="audio/mpeg")
            handed_off = True
        elif stream_format == 'wav':
            resp = Response(stream_audio(), mimetype="audio/wav")
        else:
            resp = Response(stream_audio(), mimetype="audio/mpeg")

        resp.headers.add_header("Cache-Control", "no-cache")
        resp.headers.add_header("Connection", "keep-alive")
        resp.headers.add_header("X-Accel-Buffering", "no")
        return resp
    except Exception as e:
        handed_off = False
        logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
        return get_error_data_result(message="音频流传输失败")
    finally:
        if not handed_off:
            _close_stream(audio_stream)


# NOTE(review): the rule in the source was '/chats//tts'; the variable
# segment feeding ``chat_id`` is restored.
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
@token_required
def dialog_tts_post(tenant_id, chat_id):
    """Register a TTS request and return the URL that serves its audio.

    JSON body: ``text`` (required), ``delay_gen_audio``, ``model_name``,
    ``audio_stream_id``, ``tts_stream_format`` (default ``mp3``) and
    ``tts_sample_rate`` (default 8000).  Unless ``delay_gen_audio`` is set,
    the audio is synthesised immediately into an in-memory buffer.
    """
    req = request.json
    try:
        if not req.get("text"):
            return get_error_data_result(message="Please input your question.")
        text = req.get('text')
        delay_gen_audio = req.get('delay_gen_audio', False)
        model_name = req.get('model_name')
        audio_stream_id = req.get('audio_stream_id', None)
        if audio_stream_id is None:
            audio_stream_id = str(uuid.uuid4())

        dia = DialogService.get(id=chat_id, tenant_id=tenant_id,
                                status=StatusEnum.VALID.value)
        if not dia:
            # Same check and message as dialog_tts_get; previously this
            # surfaced as a generic internal error via AttributeError.
            return get_error_data_result(message="You do not own the chat")
        tts_model_name = model_name if model_name else dia.tts_id
        tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name)

        tts_stream_format = req.get('tts_stream_format', "mp3")
        tts_sample_rate = req.get('tts_sample_rate', 8000)
        logging.info(f"tts post {tts_sample_rate} {tts_stream_format}")

        audio_stream = None if delay_gen_audio else io.BytesIO()
        if delay_gen_audio is False:
            # Generate before publishing the cache entry so a concurrent GET
            # can never pick up a half-written buffer.
            try:
                audio_stream.seek(0, io.SEEK_END)
                if text is None or text.strip() == "":
                    audio_stream.write(b'\x00' * 100)
                else:
                    # Keyword was misspelled 'stream_formate' here while
                    # dialog_tts_get passes 'stream_format'.
                    for chunk in tts_mdl.tts(text, sample_rate=tts_sample_rate,
                                             stream_format=tts_stream_format):
                        audio_stream.write(chunk)
            except Exception as e:
                logging.info(f"--error:{e}")
                audio_stream.close()
                return get_error_data_result(message="get tts audio stream error.")

        tts_info = {
            'text': text,
            'tenant_id': tenant_id,
            'chat_id': chat_id,
            'created_at': time.time(),
            'audio_stream': audio_stream,
            'model_name': req.get('model_name'),
            'delay_gen_audio': delay_gen_audio,
            'audio_stream_id': audio_stream_id,
            'tts_sample_rate': tts_sample_rate,
            'tts_stream_format': tts_stream_format,
        }
        with cache_lock:
            audio_text_cache[audio_stream_id] = tts_info

        audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
        logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url} "
                     f"{tts_sample_rate} {tts_stream_format}")
        return jsonify({"tts_url": audio_stream_url,
                        "audio_stream_id": audio_stream_id,
                        "sample_rate": tts_sample_rate,
                        "stream_format": tts_stream_format})
    except Exception as e:
        logging.error(f"请求处理失败: {str(e)}", exc_info=True)
        return get_error_data_result(message="服务器内部错误")


def get_antique_categories(mesum_id):
    """Return all antique categories.

    NOTE(review): ``mesum_id`` is accepted but not passed to the service.
    """
    return MesumAntiqueService.get_all_categories()


def get_labels_ext(mesum_id):
    """Return the extended label data the service holds for ``mesum_id``."""
    return MesumAntiqueService.get_labels_ext(mesum_id)


def get_antique_labels(mesum_id):
    """Return all antique labels.

    NOTE(review): ``mesum_id`` is accepted but not passed to the service.
    """
    return MesumAntiqueService.get_all_labels()


def get_all_antiques(mesum_id):
    """Return every antique of the museum ``mesum_id`` as a list of dicts."""
    return [o.to_dict() for o in MesumAntiqueService.get_by_mesum_id(mesum_id)]


# NOTE(review): the rule in the source was '/mesum/antique/'; the variable
# segment feeding ``mesum_id`` is restored.
@manager.route('/mesum/antique/<mesum_id>', methods=['GET'])
def mesum_antique_get(mesum_id):
    """Return the antiques, categories and labels for a museum."""
    try:
        data = {
            # The key spelling is part of the response contract; left as is.
            "anqituqes": get_all_antiques(mesum_id),
            "categories": get_antique_categories(mesum_id),
            "labels": get_antique_labels(mesum_id),
        }
        return get_result(data=data)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")


# Exhibit list for the museum identified by mesum_id.
# NOTE(review): the rule in the source was '/mesum/antique_brief/'; the
# variable segment is restored.
@manager.route('/mesum/antique_brief/<mesum_id>', methods=['GET'])
@token_required
def mesum_antique_get_brief(tenant_id, mesum_id):
    """Return the categories and extended labels for a museum."""
    try:
        data = {
            "categories": get_antique_categories(mesum_id),
            "labels": get_labels_ext(mesum_id),
        }
        return get_result(data=data)
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")


# NOTE(review): the rule in the source was '/mesum/antique_detail//'; both
# variable segments are restored.
@manager.route('/mesum/antique_detail/<mesum_id>/<antique_id>', methods=['GET'])
@token_required
def mesum_antique_get_full(tenant_id, mesum_id, antique_id):
    """Return the full record of one antique of a museum."""
    try:
        logging.info(f"mesum_antique_get_full {mesum_id} {antique_id}")
        return get_result(
            data=MesumAntiqueService.get_antique_by_id(mesum_id, antique_id)
        )
    except Exception as e:
        return get_error_data_result(message=f"Get mesum antique error {e}")


def audio_fade_in(audio_data, fade_length):
    """Apply a linear fade-in to the first ``fade_length`` samples.

    ``audio_data`` is treated as signed 16-bit samples in native byte order.
    The ramp is clamped to the number of samples available, and a trailing
    byte that does not form a whole sample is passed through unchanged.
    Returns the processed audio as bytes.
    """
    sample_size = array.array('h').itemsize
    usable = len(audio_data) - (len(audio_data) % sample_size)
    samples = array.array('h', audio_data[:usable])

    # Clamp so a ramp longer than the clip no longer raises IndexError.
    ramp = min(fade_length, len(samples))
    for i in range(ramp):
        samples[i] = int(samples[i] * (i / fade_length))

    return samples.tobytes() + bytes(audio_data[usable:])