diff --git a/api/apps/sdk/dale_extra.py b/api/apps/sdk/dale_extra.py
index b05119ec..88663308 100644
--- a/api/apps/sdk/dale_extra.py
+++ b/api/apps/sdk/dale_extra.py
@@ -13,7 +13,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
-from flask import request, Response, jsonify
+from flask import request, Response, jsonify, stream_with_context
 from api import settings
 from api.db import LLMType
 from api.db import StatusEnum
@@ -23,12 +23,15 @@ from api.db.services.llm_service import TenantLLMService
 from api.db.services.user_service import TenantService
 from api.db.services.brief_service import MesumOverviewService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.antique_service import MesumAntiqueService
 from api.utils import get_uuid
 from api.utils.api_utils import get_error_data_result, token_required
 from api.utils.api_utils import get_result
+from api.utils.file_utils import get_project_base_directory
 import logging
-import base64
-import queue, time, uuid
+import base64, gzip
+from io import BytesIO
+import queue, time, uuid, os, array
 from threading import Lock, Thread
 from zhipuai import ZhipuAI
@@ -59,12 +62,37 @@ def my_llms(tenant_id):
 main_antiquity = "浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"
 
-@manager.route('/photo/recongeText', methods=['POST'])
+@manager.route('/photo/recongeText/<mesum_id>', methods=['POST'])
 @token_required
-def upload_file(tenant_id):
+def upload_file(tenant_id, mesum_id):
     if 'file' not in request.files:
         return jsonify({'error': 'No file part'}), 400
+    antiques_selected = ""
+    if mesum_id:
+        """
+        e, mesum_breif = MesumOverviewService.get_by_id(mesum_id)
+        if not e:
+            logging.info(f"没有找到匹配的博物馆信息,mesum_id={mesum_id}")
+        else:
+            antiques_selected = f"结果从:{mesum_breif.antique} 中进行选择"
+        """
+        mesum_id_str = str(mesum_id)
+        antique_labels = get_antique_labels(mesum_id)
+        # convert every label to a string
+        string_elements = [str(element) for element in antique_labels]
+        # join them with commas
+        joined_string = ','.join(string_elements)
+        antiques_selected = f"结果从:{joined_string} 中进行选择"
+        logging.info(f"{mesum_id} {joined_string}")
+    prompt = (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
+              f"请识别这个图片中文字,重点识别出含在文字中的某一文物标题、某一个历史事件或某一历史人物,"
+              f"你的回答有2个结果,第一个结果是从文字中识别出历史文物、历史事件、历史人物,"
+              f"此回答时只给出匹配的文物、事件、人物,不需要其他多余的文字,{antiques_selected}"
+              f",第二个结果是原始识别的所有文字"
+              "2个结果输出以{ }的json格式给出,匹配文物、事件、人物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
+              f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")
 
     file = request.files['file']
     if file.filename == '':
@@ -92,14 +120,7 @@ def upload_file(tenant_id):
                     },
                     {
                         "type": "text",
-                        "text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
-                                 f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
-                                 f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
-                                 f"你的回答有2个结果,第一个结果是是从文字进行分析出匹配文物,候选文物只能如下:{req_antique},"
-                                 f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
-                                 f",第二个结果是原始识别的所有文字"
-                                 "2个结果输出以{ }的json格式给出,匹配文物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
-                                 f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")
+                        "text": prompt
                     }
                 ]
             }
@@ -213,6 +234,17 @@ def extract_text_from_markdown(markdown_text):
     return text
 
 
+def encode_gzip_base64(original_data: bytes) -> str:
+    """Core encoding pipeline: raw bytes -> gzip compression -> Base64."""
+    # Step 1: gzip compression
+    with BytesIO() as buf:
+        with gzip.GzipFile(fileobj=buf, mode='wb') as gz_file:
+            gz_file.write(original_data)
+        compressed_bytes = buf.getvalue()
+
+    # Step 2: Base64 encoding, configured to match the Android client
+    return base64.b64encode(compressed_bytes).decode('utf-8')  # no line breaks, equivalent to Android's Base64.NO_WRAP
+
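
The Android client is expected to invert this pipeline: Base64-decode first, then gunzip. A minimal round-trip sketch, assuming the patched module above is importable; `decode_gzip_base64` is a hypothetical helper, not part of this change:

```python
import base64
import gzip

def decode_gzip_base64(encoded: str) -> bytes:
    """Hypothetical inverse of encode_gzip_base64: Base64-decode, then gunzip."""
    compressed = base64.b64decode(encoded)  # no line wrapping expected (Base64.NO_WRAP)
    return gzip.decompress(compressed)

# round-trip sanity check against the server-side encoder
payload = b"\x00\x01" * 512  # stand-in for 16-bit PCM samples
assert decode_gzip_base64(encode_gzip_base64(payload)) == payload
```
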
"""定时清理过期缓存""" with cache_lock: @@ -241,38 +273,55 @@ def start_background_cleaner(): # 应用启动时启动清理线程 start_background_cleaner() -@manager.route('/tts_stream/') +@manager.route('/tts_stream/',methods=['GET']) def tts_stream(session_id): + session = stream_manager.sessions.get(session_id) def generate(): - retry_count = 0 - session = None count = 0; + path = os.path.join(get_project_base_directory(), "api", "apps/sdk/test.mp3") + fmp3 =open(path, 'rb') + finished_event = session['finished'] try: - while retry_count < 1: - session = stream_manager.sessions.get(session_id) + while not finished_event.is_set() : if not session or not session['active']: break try: - chunk = session['buffer'].get(timeout=5) # 30秒超时 + chunk = session['buffer'].get_nowait() # count = count + 1 if isinstance(chunk, str) and chunk.startswith("ERROR"): - logging.info("---tts stream error!!!!") + logging.info(f"---tts stream error!!!! {chunk}") yield f"data:{{'error':'{chunk[6:]}'}}\n\n" break - yield chunk + if session['stream_format'] == "wav": + gzip_base64_data = encode_gzip_base64(chunk) + "\r\n" + yield gzip_base64_data + else: + yield chunk retry_count = 0 # 成功收到数据重置重试计数器 except queue.Empty: - retry_count += 1 - yield b'' # 保持连接 + if session['stream_format'] == "wav": + # yield encode_gzip_base64(b'\x03\x04' * 1) + "\r\n" + pass + else: + yield b'' # 保持连接 + #data = fmp3.read(1024) + #yield data + except Exception as e: + logging.info(f"tts streag get error2 {e} ") + + finally: # 确保流结束后关闭会话 if session: # 延迟关闭会话,确保所有数据已发送 - time.sleep(5) # 等待5秒确保流结束 stream_manager.close_session(session_id) logging.info(f"Session {session_id} closed.") + # 关键响应头设置 - resp = Response(generate(), mimetype="audio/mpeg") + if session['stream_format'] == "wav": + resp = Response(stream_with_context(generate()), mimetype="audio/mpeg") + else: + resp = Response(stream_with_context(generate()), mimetype="audio/wav") resp.headers.add_header("Cache-Control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") @@ -291,17 +340,23 @@ def dialog_tts_get(chat_id, audio_stream_id): chat_id = req.get('chat_id') text = req.get('text', "..") model_name = req.get('model_name') + sample_rate = req.get('tts_sample_rate',8000) # 默认8K + stream_format = req.get('tts_stream_format','mp3') dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value) if not dia: return get_error_data_result(message="You do not own the chat") tts_model_name = dia.tts_id if model_name: tts_model_name = model_name tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id) + logging.info(f"dialog_tts_get {sample_rate} {stream_format}") def stream_audio(): try: - for chunk in tts_mdl.tts(text): - yield chunk + for chunk in tts_mdl.tts(text,sample_rate=sample_rate,stream_format=stream_format): + if stream_format =='wav': + yield encode_gzip_base64(chunk) + "\r\n" + else: + yield chunk except Exception as e: yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, @@ -318,7 +373,10 @@ def dialog_tts_get(chat_id, audio_stream_id): audio_stream.seek(0) resp = Response(generate(), mimetype="audio/mpeg") else: - resp = Response(stream_audio(), mimetype="audio/mpeg") + if stream_format == 'wav': + resp = Response(stream_audio(), mimetype="audio/wav") + else: + resp = Response(stream_audio(), mimetype="audio/mpeg") resp.headers.add_header("Cache-Control", "no-cache") resp.headers.add_header("Connection", "keep-alive") 
resp.headers.add_header("X-Accel-Buffering", "no") @@ -328,20 +386,19 @@ def dialog_tts_get(chat_id, audio_stream_id): return get_error_data_result(message="音频流传输失败") finally: # 确保资源释放 - if tts_info.get('audio_stream') and not tts_info['audio_stream'].closed: + if tts_info and tts_info.get('audio_stream') and not tts_info['audio_stream'].closed: tts_info['audio_stream'].close() @manager.route('/chats//tts', methods=['POST']) @token_required def dialog_tts_post(tenant_id, chat_id): + req = request.json try: - req = request.json if not req.get("text"): return get_error_data_result(message="Please input your question.") - delay_gen_audio = req.get('delay_gen_audio', False) - # text = extract_text_from_markdown(req.get('text')) text = req.get('text') + delay_gen_audio = req.get('delay_gen_audio', False) model_name = req.get('model_name') audio_stream_id = req.get('audio_stream_id', None) if audio_stream_id is None: @@ -355,6 +412,10 @@ def dialog_tts_post(tenant_id, chat_id): audio_stream = None else: audio_stream = io.BytesIO() + + tts_stream_format = req.get('tts_stream_format', "mp3") + tts_sample_rate = req.get('tts_sample_rate', 8000) + logging.info(f"tts post {tts_sample_rate} {tts_stream_format}") # 结构化缓存数据 tts_info = { 'text': text, @@ -364,30 +425,21 @@ def dialog_tts_post(tenant_id, chat_id): 'audio_stream': audio_stream, # 维持原有逻辑 'model_name': req.get('model_name'), 'delay_gen_audio': delay_gen_audio, # 明确存储状态 - audio_stream_id: audio_stream_id + 'audio_stream_id': audio_stream_id, + 'tts_sample_rate':tts_sample_rate, + 'tts_stream_format':tts_stream_format } with cache_lock: audio_text_cache[audio_stream_id] = tts_info - if delay_gen_audio is False: try: - """ - for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text): - try: - if txt is None or txt.strip() == "": - continue - for chunk in tts_mdl.tts(txt): - audio_stream.write(chunk) - except Exception as e: - continue - """ + audio_stream.seek(0, io.SEEK_END) if text is None or text.strip() == "": audio_stream.write(b'\x00' * 100) else: # 确保在流的末尾写入 - audio_stream.seek(0, io.SEEK_END) - for chunk in tts_mdl.tts(text): + for chunk in tts_mdl.tts(text,sample_rate=tts_sample_rate,stream_formate=tts_stream_format): audio_stream.write(chunk) except Exception as e: logging.info(f"--error:{e}") @@ -397,10 +449,79 @@ def dialog_tts_post(tenant_id, chat_id): # 构建音频流URL audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}" - logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}") + logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url} " + f"{tts_sample_rate} {tts_stream_format}") # 返回音频流URL - return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id}) + return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id, + "sample_rate":tts_sample_rate, "stream_format":tts_stream_format,}) except Exception as e: logging.error(f"请求处理失败: {str(e)}", exc_info=True) return get_error_data_result(message="服务器内部错误") + +def get_antique_categories(mesum_id): + res = MesumAntiqueService.get_all_categories() + return res + +def get_labels_ext(mesum_id): + res = MesumAntiqueService.get_labels_ext(mesum_id) + return res + +def get_antique_labels(mesum_id): + res = MesumAntiqueService.get_all_labels() + return res + +def get_all_antiques(mesum_id): + res =[] + antiques=MesumAntiqueService.get_by_mesum_id(mesum_id) + for o in antiques: + res.append(o.to_dict()) + return res + + +@manager.route('/mesum/antique/', methods=['GET']) +def mesum_antique_get(mesum_id): + 
try: + data = { + "anqituqes":get_all_antiques(mesum_id), + "categories":get_antique_categories(mesum_id), + "labels":get_antique_labels(mesum_id) + } + return get_result(data=data) + except Exception as e: + return get_error_data_result(message=f"Get mesum antique error {e}") + +# 按照mesum_id 获得此博物馆的展品清单 +@manager.route('/mesum/antique_brief/', methods=['GET']) +@token_required +def mesum_antique_get_brief(tenant_id,mesum_id): + try: + data = { + "categories":get_antique_categories(mesum_id), + "labels":get_labels_ext(mesum_id) + } + return get_result(data=data) + except Exception as e: + return get_error_data_result(message=f"Get mesum antique error {e}") + +@manager.route('/mesum/antique_detail//', methods=['GET']) +@token_required +def mesum_antique_get_full(tenant_id,mesum_id,antique_id): + try: + logging.info(f"mesum_antique_get_full {mesum_id} {antique_id}") + return get_result(data=MesumAntiqueService.get_antique_by_id(mesum_id,antique_id)) + except Exception as e: + return get_error_data_result(message=f"Get mesum antique error {e}") + +def audio_fade_in(audio_data, fade_length): + # 假设音频数据是16位单声道PCM + # 将二进制数据转换为整数数组 + samples = array.array('h', audio_data) + + # 对前fade_length个样本进行淡入处理 + for i in range(fade_length): + fade_factor = i / fade_length + samples[i] = int(samples[i] * fade_factor) + + # 将整数数组转换回二进制数据 + return samples.tobytes() \ No newline at end of file diff --git a/api/apps/sdk/test.mp3 b/api/apps/sdk/test.mp3 new file mode 100644 index 00000000..e0d7c506 Binary files /dev/null and b/api/apps/sdk/test.mp3 differ diff --git a/api/db/db_models.py b/api/db/db_models.py index 26d59279..256c9eef 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -25,7 +25,7 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer from flask_login import UserMixin from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate from peewee import ( - BigIntegerField, BooleanField, CharField, + BigIntegerField, BooleanField, CharField,AutoField, CompositeKey, IntegerField, TextField, FloatField, DateTimeField, Field, Model, Metadata ) @@ -1025,6 +1025,23 @@ class MesumOverview(DataBaseModel): class Meta: db_table = "mesum_overview" +# added by cyx for mesum_antique +class MesumAntique(DataBaseModel): + sn = CharField(max_length=100, null=True) + label = CharField(max_length=100, null=True) + description = TextField(null=True) + category = CharField(max_length=100, null=True) + group = CharField(max_length=100, null=True) + background = TextField(null=True) + value = TextField(null=True) + discovery = TextField(null=True) + id = AutoField(primary_key=True) + mesum_id = CharField(max_length=100, null=True) + combined = TextField(null=True) + + + class Meta: + db_table = 'mesum_antique' #------------------------------------------- def migrate_db(): with DB.transaction(): diff --git a/api/db/services/antique_service.py b/api/db/services/antique_service.py new file mode 100644 index 00000000..fdcbdbe6 --- /dev/null +++ b/api/db/services/antique_service.py @@ -0,0 +1,88 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
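
For orientation, a hedged sketch of using the new model through peewee; table creation is normally handled by the project's migration path, and the field values here are invented:

```python
from api.db.db_models import DB, MesumAntique

# one-off DDL for local experiments; real deployments rely on migrate_db()
DB.create_tables([MesumAntique])

MesumAntique.create(
    sn="0001",
    label="绿釉刻花瓷枕函",   # one of the exhibits listed in main_antiquity
    category="瓷器",
    mesum_id="demo-mesum",
    description="Green-glazed carved porcelain pillow.",
)
```
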
diff --git a/api/db/services/antique_service.py b/api/db/services/antique_service.py
new file mode 100644
index 00000000..fdcbdbe6
--- /dev/null
+++ b/api/db/services/antique_service.py
@@ -0,0 +1,88 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from datetime import datetime
+
+import peewee
+from werkzeug.security import generate_password_hash, check_password_hash
+
+from api.db import UserTenantRole
+from api.db.db_models import DB, UserTenant
+from api.db.db_models import User, Tenant, MesumAntique
+from api.db.services.common_service import CommonService
+from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
+from api.db import StatusEnum
+
+
+class MesumAntiqueService(CommonService):
+    model = MesumAntique
+
+    @classmethod
+    @DB.connection_context()
+    def get_by_mesum_id(cls, mesum_id):
+        objs = cls.query(mesum_id=mesum_id)
+        return objs
+
+    @classmethod
+    @DB.connection_context()
+    def get_all_categories(cls):
+        # all distinct, non-empty categories
+        categories = [category.category for category in cls.model.select(cls.model.category).distinct().execute() if category.category]
+        return categories
+
+    @classmethod
+    @DB.connection_context()
+    def get_all_labels(cls):
+        # all distinct, non-empty labels
+        labels = [label.label for label in cls.model.select(cls.model.label).distinct().execute() if label.label]
+        return labels
+
+    @classmethod
+    @DB.connection_context()
+    def get_labels_ext(cls, mesum_id):
+        # filter by mesum_id and skip rows with an empty category
+        query = cls.model.select().where(
+            (cls.model.mesum_id == mesum_id) &
+            (cls.model.category != "")
+        ).order_by(cls.model.category)
+
+        # group the rows by category
+        grouped_data = {}
+        for obj in query.dicts().execute():
+            category = obj['category']
+            if category not in grouped_data:
+                grouped_data[category] = []
+            grouped_data[category].append({
+                'id': obj['id'],
+                'label': obj['label']
+            })
+
+        return grouped_data
+
+    @classmethod
+    @DB.connection_context()
+    def get_antique_by_id(cls, mesum_id, antique_id):
+        query = cls.model.select().where(
+            (cls.model.mesum_id == mesum_id) &
+            (cls.model.id == antique_id)
+        )
+
+        data = []
+        for obj in query.dicts().execute():
+            data.append(obj)
+        if len(data) > 0:
+            data = data[0]
+        return data
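
`get_labels_ext` groups a museum's exhibits by non-empty category, so `/mesum/antique_brief/<mesum_id>` returns a shape like this (illustrative ids and labels only):

```python
{
    "瓷器": [
        {"id": 1, "label": "绿釉刻花瓷枕函"},
        {"id": 7, "label": "青花瓷瓶"},
    ],
    "杂项": [
        {"id": 3, "label": "走马灯"},
    ],
}
```
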
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index e9f7e974..a6481951 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -33,9 +33,21 @@ from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
 from peewee import fn
-import threading, queue, uuid, time
+import threading, queue, uuid, time, array
 from concurrent.futures import ThreadPoolExecutor
 
+
+def audio_fade_in(audio_data, fade_length):
+    # assumes 16-bit mono PCM
+    samples = array.array('h', audio_data)
+
+    # fade in the first fade_length samples
+    for i in range(fade_length):
+        fade_factor = i / fade_length
+        samples[i] = int(samples[i] * fade_factor)
+
+    # convert the sample array back to bytes
+    return samples.tobytes()
 
 
 class StreamSessionManager:
 
@@ -44,17 +56,20 @@ class StreamSessionManager:
         self.lock = threading.Lock()
         self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
         self.gc_interval = 300  # clean up every 5 minutes
-
-    def create_session(self, tts_model):
+        self.gc_tts = 3  # 3s
+
+    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
         session_id = str(uuid.uuid4())
         with self.lock:
             self.sessions[session_id] = {
                 'tts_model': tts_model,
-                'buffer': queue.Queue(maxsize=100),  # thread-safe queue
+                'buffer': queue.Queue(maxsize=300),  # thread-safe queue
                 'task_queue': queue.Queue(),
                 'active': True,
                 'last_active': time.time(),
-                'audio_chunk_count': 0
+                'audio_chunk_count': 0,
+                'finished': threading.Event(),  # signals that synthesis has ended
+                'sample_rate': sample_rate,
+                'stream_format': stream_format
             }
         # start the per-session task thread
         threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
@@ -98,6 +113,9 @@ class StreamSessionManager:
                     if time.time() - session['last_active'] > self.gc_interval:
                         self.close_session(session_id)
                         break
+                    if time.time() - session['last_active'] > self.gc_tts:
+                        session['finished'].set()
+                        break
 
             except Exception as e:
                 logging.error(f"Task processing error: {str(e)}")
@@ -107,22 +125,36 @@ class StreamSessionManager:
         session = self.sessions.get(session_id)
         if not session:
             return
         # logging.info(f"_generate_audio:{text}")
+        first_chunk = True
         try:
-            for chunk in session['tts_model'].tts(text):
-                session['buffer'].put(chunk)
+            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
+                if session['stream_format'] == 'wav':
+                    if first_chunk:
+                        chunk_len = len(chunk)
+                        if chunk_len > 2048:
+                            session['buffer'].put(audio_fade_in(chunk, 1024))
+                        else:
+                            # fade_length counts 16-bit samples, two bytes each
+                            session['buffer'].put(audio_fade_in(chunk, chunk_len // 2))
+                        first_chunk = False
+                    else:
+                        session['buffer'].put(chunk)
+                else:
+                    session['buffer'].put(chunk)
                 session['last_active'] = time.time()
                 session['audio_chunk_count'] = session['audio_chunk_count'] + 1
             logging.info(f"转换结束!!! {session['audio_chunk_count']}")
         except Exception as e:
             session['buffer'].put(f"ERROR:{str(e)}")
+            logging.info(f"--_generate_audio--error {str(e)}")
 
     def close_session(self, session_id):
         with self.lock:
             if session_id in self.sessions:
                 # mark the session inactive
                 self.sessions[session_id]['active'] = False
-        # clean up resources after a 10-second delay
-        threading.Timer(10, self._clean_session, args=[session_id]).start()
+        # clean up resources after a 1-second delay
+        threading.Timer(1, self._clean_session, args=[session_id]).start()
 
     def _clean_session(self, session_id):
         with self.lock:
@@ -297,7 +329,6 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
     # 5. length check
     if len(delta_ans) == 0 or len(delta_ans) > max_length:
         return False, ""
-
     # all checks passed: return the valid flag and the sanitized text
     return True, delta_ans
 
@@ -339,12 +370,6 @@ def find_split_position(text):
     if date_pattern:
         return date_pattern.end()
 
-    # avoid splitting inside common phrases
-    for phrase in ["青少年", "博物馆", "参观"]:
-        idx = text.rfind(phrase)
-        if idx != -1 and idx + len(phrase) <= len(text):
-            return idx + len(phrase)
-
     return None
 
 
 # manage the text buffer: split off semantically complete fragments and return what is ready for TTS
@@ -353,9 +378,7 @@ def process_buffer(chunk_buffer, force_flush=False):
     current_text = "".join(chunk_buffer)
     if not current_text:
         return "", []
-
     split_pos = find_split_position(current_text)
-
     # forced-flush logic
     if force_flush or len(current_text) >= MAX_BUFFER_LEN:
         # even on a forced flush, try to find a good split point
@@ -366,7 +389,7 @@ def process_buffer(chunk_buffer, force_flush=False):
     if split_pos is not None and split_pos > 0:
         return current_text[:split_pos], [current_text[split_pos:]]
 
-    return "", chunk_buffer
+    return None, chunk_buffer
 
 
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
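
`audio_fade_in` ramps the first `fade_length` samples linearly up from silence, masking the click a raw PCM stream otherwise starts with. A self-contained check of that behaviour, pure stdlib; note the helper assumes an even byte count, i.e. whole 16-bit samples:

```python
import array
import math

def audio_fade_in(audio_data, fade_length):
    samples = array.array('h', audio_data)  # 16-bit signed mono PCM
    for i in range(fade_length):
        samples[i] = int(samples[i] * (i / fade_length))
    return samples.tobytes()

# 100 ms of a 440 Hz tone at 8 kHz
tone = array.array('h', (int(20000 * math.sin(2 * math.pi * 440 * n / 8000))
                         for n in range(800)))
faded = array.array('h', audio_fade_in(tone.tobytes(), 400))
assert faded[0] == 0                      # starts from silence
assert abs(faded[399]) <= abs(tone[399])  # the ramp never amplifies
```
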
@@ -421,6 +444,8 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id)
 
+    tts_sample_rate = kwargs.get("tts_sample_rate", 8000)  # default: 8 kHz
+    tts_stream_format = kwargs.get("tts_stream_format", "mp3")  # default: mp3
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
@@ -465,20 +490,21 @@ def chat(dialog, messages, stream=True, **kwargs):
                         doc_ids=attachments,
                         top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-        logging.debug(
-            "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
-
+        logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+        # dump the retrieval history for debugging
+        # logging.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
     retrieval_tm = timer()
 
     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
-        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
+        yield {"answer": empty_res, "reference": kbinfos, "audio_binary":
+               tts(tts_mdl, empty_res, sample_rate=tts_sample_rate, stream_format=tts_stream_format)}
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
     kwargs["knowledge"] = "\n\n------\n\n".join(knowledges)
     gen_conf = dialog.llm_setting
-
     msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
+
     msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                 for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
@@ -502,6 +528,9 @@ def chat(dialog, messages, stream=True, **kwargs):
                 embd_mdl,
                 tkweight=1 - dialog.vector_similarity_weight,
                 vtweight=dialog.vector_similarity_weight)
+        # the step above sometimes injects markers like ##0$$ ##1$$ into the answer; strip them
+        # cyx 20250407
+        answer = re.sub(r'##\d+\$\$', '', answer).strip()  # drop ##0$$-style markers and stray whitespace
         idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
         recall_docs = [
             d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
@@ -512,7 +541,6 @@ def chat(dialog, messages, stream=True, **kwargs):
         for c in refs["chunks"]:
             if c.get("vector"):
                 del c["vector"]
-
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
         done_tm = timer()
@@ -525,83 +553,92 @@ def chat(dialog, messages, stream=True, **kwargs):
         last_ans = ""
         answer = ""
         # create the TTS session up front
-        tts_session_id = stream_manager.create_session(tts_mdl)
+        tts_session_id = stream_manager.create_session(tts_mdl, sample_rate=tts_sample_rate, stream_format=tts_stream_format)
         audio_url = f"/tts_stream/{tts_session_id}"
         first_chunk = True
         chunk_buffer = []  # text buffer
         last_flush_time = time.time()  # timestamp of the last flush
+        # handle the case where the knowledge base returned nothing first, cyx 20250323
+        if not kwargs["knowledge"] or kwargs["knowledge"] == "" or len(kwargs["knowledge"]) < 4:
+            stream_manager.append_text(tts_session_id, "未找到相关内容")
+            yield {
+                "answer": "未找到相关内容",
+                "delta_ans": "未找到相关内容",
+                "session_id": tts_session_id,
+                "reference": {},
+                "audio_stream_url": audio_url,
+                "sample_rate": tts_sample_rate,
+                "stream_format": tts_stream_format,
+            }
+        else:
+            for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+                answer = ans
+                delta_ans = ans[len(last_ans):]
+                if num_tokens_from_string(delta_ans) < 24:
+                    continue
+                last_ans = answer
+                # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+                # cyx 2024 12 04 fix: tts crashed when delta_ans was empty
+                tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
+                # cyx 2025 01 18: when the frontend passes tts_disable, skip tts audio, i.e. no audio_binary
+                if kwargs.get('tts_disable'):
+                    tts_input_is_valid = False
+                if tts_input_is_valid:
+                    # buffer text until punctuation arrives
+                    chunk_buffer.append(sanitized_text)
+                # drain the buffer
+                while True:
+                    # decide whether a forced flush is due
+                    force = time.time() - last_flush_time > FLUSH_TIMEOUT
+                    to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
+                    if not to_send:
+                        break
+                    # hand the complete fragment to the TTS session
+                    stream_manager.append_text(tts_session_id, to_send)
+                    chunk_buffer = remaining
+                    last_flush_time = time.time()
+                """
+                if tts_input_is_valid:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
+                else:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+                """
+                # the first chunk carries the audio URL
+                if first_chunk:
+                    yield {
+                        "answer": answer,
+                        "delta_ans": sanitized_text,
+                        "session_id": tts_session_id,
+                        "reference": {},
+                        "audio_stream_url": audio_url,
+                        "sample_rate": tts_sample_rate,
+                        "stream_format": tts_stream_format,
+                    }
+                    first_chunk = False
+                else:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
 
-        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 24:
-                continue
-
-            last_ans = answer
-            # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-            # cyx 2024 12 04 fix: tts crashed when delta_ans was empty
-            tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
-            # cyx 2025 01 18: when the frontend passes tts_disable, skip tts audio, i.e. no audio_binary
-            if kwargs.get('tts_disable'):
-                tts_input_is_valid = False
-            if tts_input_is_valid:
-                # buffer text until punctuation arrives
-                chunk_buffer.append(sanitized_text)
-            # drain the buffer
-            while True:
-                # decide whether a forced flush is due
-                force = time.time() - last_flush_time > FLUSH_TIMEOUT
-                to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
-
-                if not to_send:
-                    break
-
-                # hand the complete fragment to the TTS session
-                stream_manager.append_text(tts_session_id, to_send)
-                chunk_buffer = remaining
-                last_flush_time = time.time()
-            """
-            if tts_input_is_valid:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
-            else:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-            """
-
-            # the first chunk carries the audio URL
-            if first_chunk:
-                yield {
-                    "answer": answer,
-                    "delta_ans": sanitized_text,
-                    "audio_stream_url": audio_url,
-                    "session_id": tts_session_id,
-                    "reference": {}
-                }
-                first_chunk = False
-            else:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-
-        delta_ans = answer[len(last_ans):]
-        if delta_ans:
-            # stream_manager.append_text(tts_session_id, delta_ans)
-            # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-            # cyx 2024 12 04 fix: tts crashed when delta_ans was empty
-            tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
-            if kwargs.get('tts_disable'):  # cyx 2025 01 18: frontend passed tts_disable, so no audio_binary
-                tts_input_is_valid = False
-            if tts_input_is_valid:
-                # 20250221: generate the audio on the backend
-                chunk_buffer.append(sanitized_text)
-                stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
-            yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-            """
-            if tts_input_is_valid:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
-            else:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-            """
-
-        yield decorate_answer(answer)
+            delta_ans = answer[len(last_ans):]
+            if delta_ans:
+                # stream_manager.append_text(tts_session_id, delta_ans)
+                # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+                # cyx 2024 12 04 fix: tts crashed when delta_ans was empty
+                tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
+                if kwargs.get('tts_disable'):  # cyx 2025 01 18: frontend passed tts_disable, so no audio_binary
+                    tts_input_is_valid = False
+                if tts_input_is_valid:
+                    # 20250221: generate the audio on the backend
+                    chunk_buffer.append(sanitized_text)
+                    stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
+                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+                """
+                if tts_input_is_valid:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
+                else:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+                """
+            yield decorate_answer(answer)
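
With this branch in place, a streaming client sees the audio endpoint only once: the first chunk advertises `audio_stream_url` plus the negotiated format, and later chunks carry text deltas only. Illustrative payloads (all values are placeholders):

```python
# first chunk: tells the client where to fetch the audio stream
{"answer": "这件", "delta_ans": "这件", "session_id": "c0ffee-...", "reference": {},
 "audio_stream_url": "/tts_stream/c0ffee-...", "sample_rate": 8000, "stream_format": "wav"}

# every later chunk: text delta only
{"answer": "这件瓷枕...", "delta_ans": "瓷枕...", "reference": {}}
```
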
 
     else:
         answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
@@ -611,7 +648,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         if kwargs.get('tts_disable'):  # cyx 2025 01 18: frontend passed tts_disable, so no audio_binary
             tts_input_is_valid = False
         else:
-            res["audio_binary"] = tts(tts_mdl, answer)
+            res["audio_binary"] = tts(tts_mdl, answer, tts_sample_rate, tts_stream_format)
         yield res
 
@@ -899,10 +936,10 @@ Output: What's the weather in Rochester on {tomorrow}?
     return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
 
 
-def tts(tts_mdl, text):
+def tts(tts_mdl, text, sample_rate=8000, stream_format="mp3"):
     if not tts_mdl or not text:
         return
     bin = b""
-    for chunk in tts_mdl.tts(text):
+    for chunk in tts_mdl.tts(text, sample_rate, stream_format):
         bin += chunk
     return binascii.hexlify(bin).decode("utf-8")
diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py
index 06ca6def..617133aa 100644
--- a/api/db/services/llm_service.py
+++ b/api/db/services/llm_service.py
@@ -248,8 +248,8 @@ class LLMBundle(object):
                 "LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt
 
-    def tts(self, text):  # tts entry point, cyx
-        for chunk in self.mdl.tts(text):
+    def tts(self, text, sample_rate=8000, stream_format='mp3'):  # tts entry point, cyx
+        for chunk in self.mdl.tts(text, sample_rate=sample_rate, stream_format=stream_format):
             if isinstance(chunk, int):
                 if not TenantLLMService.increase_usage(
                         self.tenant_id, self.llm_type, chunk, self.llm_name):
diff --git a/conf/llm_factories.json b/conf/llm_factories.json
index c98c1388..f9ca2847 100644
--- a/conf/llm_factories.json
+++ b/conf/llm_factories.json
@@ -167,6 +167,40 @@
                 }
             ]
         },
+        {
+            "name": "Qianwen-Omni",
+            "logo": "",
+            "tags": "LLM,IMAGE2TEXT,MODERATION",
+            "status": "1",
+            "llm": [
+                {
+                    "llm_name": "qwen-omni-turbo",
+                    "tags": "LLM,CHAT,32K",
+                    "max_tokens": 30768,
+                    "model_type": "chat"
+                },
+                {
+                    "llm_name": "qwen-omni-turbo-latest",
+                    "tags": "LLM,CHAT,IMAGE2TEXT",
+                    "max_tokens": 30768,
+                    "model_type": "image2text"
+                }
+            ]
+        },
+        {
+            "name": "LOCAL-LLM",
+            "logo": "",
+            "tags": "chat",
+            "status": "1",
+            "llm": [
+                {
+                    "llm_name": "chat",
+                    "tags": "LLM,CHAT,32K",
+                    "max_tokens": 12768,
+                    "model_type": "chat"
+                }
+            ]
+        },
         {
             "name": "ZHIPU-AI",
             "logo": "",
"Qianwen-Omni": QWenOmniChat, + "LOCAL-LLM":LocalLLMChat, "Ollama": OllamaChat, "LocalAI": LocalAIChat, "Xinference": XinferenceChat, diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 012c625b..05f4aa50 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -127,6 +127,98 @@ class DeepSeekChat(Base): super().__init__(key, model_name, base_url) +class LocalLLMChat(Base): + def __init__(self, key, model_name="Qwen2.5-7B", base_url="http://106.52.71.204:9483/v1"): + if not base_url: base_url = "http://106.52.71.204:9483/v1" + super().__init__(key, model_name, base_url) + +class QWenOmniChat(Base): + def __init__(self, key, model_name="qwen-omni-turbo", base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"): + if not base_url: base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + super().__init__(key, model_name, base_url) + def chat(self, system, history, gen_conf): + if system: + history.insert(0, {"role": "system", "content": system}) + ans = "" + total_tokens = 0 + try: + response = self.client.chat.completions.create( + model = self.model_name, + messages=history, + stream=True, + # 设置输出数据的模态,当前支持两种:["text","audio"]、["text"] + # modalities=["text", "audio"], + # audio={"voice": "Cherry", "format": "wav"}, + # stream 必须设置为 True,否则会报错 + # stream_options={"include_usage": True}, + **gen_conf + ) + for resp in response: + if not resp.choices: continue + if not resp.choices[0].delta.content: + resp.choices[0].delta.content = "" + ans += resp.choices[0].delta.content + + if not hasattr(resp, "usage") or not resp.usage: + total_tokens = ( + total_tokens + + num_tokens_from_string(resp.choices[0].delta.content) + ) + elif isinstance(resp.usage, dict): + total_tokens = resp.usage.get("total_tokens", total_tokens) + else: total_tokens = resp.usage.total_tokens + + if resp.choices[0].finish_reason == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" + break # 如果达到长度限制,可以跳出循环 + except openai.APIError as e: + ans= ans + "\n**ERROR**: " + str(e) + + return ans, total_tokens + + + def chat_streamly(self, system, history, gen_conf): + # logging.info(f"chat_streamly :{gen_conf}") + if system : + history.insert(0, {"role": "system", "content": + [{"type":"text","text": system}]}) + ans = "" + total_tokens = 0 + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=history, + stream=True, + #**gen_conf + ) + for resp in response: + if not resp.choices: continue + if not resp.choices[0].delta.content: + resp.choices[0].delta.content = "" + ans += resp.choices[0].delta.content + + if not hasattr(resp, "usage") or not resp.usage: + total_tokens = ( + total_tokens + + num_tokens_from_string(resp.choices[0].delta.content) + ) + elif isinstance(resp.usage, dict): + total_tokens = resp.usage.get("total_tokens", total_tokens) + else: total_tokens = resp.usage.total_tokens + + if resp.choices[0].finish_reason == "length": + ans += "...\nFor the content length reason, it stopped, continue?" if is_english( + [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?" 
diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index b6e2b887..71c2ef42 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -203,11 +203,14 @@ class ZhipuEmbed(Base):
     def encode(self, texts: list, batch_size=16):
         arr = []
         tks_num = 0
-        for txt in texts:
-            res = self.client.embeddings.create(input=txt,
-                                                model=self.model_name)
-            arr.append(res.data[0].embedding)
-            tks_num += res.usage.total_tokens
+        try:
+            for txt in texts:
+                res = self.client.embeddings.create(input=txt,
+                                                    model=self.model_name)
+                arr.append(res.data[0].embedding)
+                tks_num += res.usage.total_tokens
+        except Exception as error:
+            logging.info(f"!!!ZhipuEmbed embedding error {error}")
         return np.array(arr), tks_num
 
     def encode_queries(self, text):
diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py
index d9ac36d3..50c017ac 100644
--- a/rag/llm/tts_model.py
+++ b/rag/llm/tts_model.py
@@ -133,7 +133,8 @@ class QwenTTS(Base):
         if parts[0] == 'cosyvoice-v1':
             self.is_cosyvoice = True
             self.voice = parts[1]
-    def tts(self, text):
+    # stream_format selects the produced tts audio format: mp3, wav or pcm
+    def tts(self, text, sample_rate=8000, stream_format="mp3"):
         from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
         if self.is_cosyvoice is False:
             from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
@@ -143,7 +144,7 @@ class QwenTTS(Base):
             from dashscope.audio.tts_v2 import ResultCallback, SpeechSynthesizer, AudioFormat  # , SpeechSynthesisResult
             from dashscope.audio.tts import SpeechSynthesisResult
             from collections import deque
-            print(f"--QwenTTS--tts_stream begin-- {text}")  # cyx
+            # print(f"--QwenTTS--tts_stream begin-- {text}")  # cyx
             class Callback(ResultCallback):
                 def __init__(self) -> None:
                     self.dque = deque()
@@ -237,12 +238,52 @@ class QwenTTS(Base):
                     format="mp3")
             else:
                 self.callback = Callback_v2()
+                print(f"--tts {sample_rate} {stream_format}")
+                if sample_rate == 8000:
+                    if stream_format == 'mp3':
+                        format = AudioFormat.MP3_8000HZ_MONO_128KBPS
+                    elif stream_format == 'pcm':
+                        format = AudioFormat.PCM_8000HZ_MONO_16BIT
+                    elif stream_format == 'wav':
+                        format = AudioFormat.WAV_8000HZ_MONO_16BIT
+                    else:
+                        format = AudioFormat.MP3_8000HZ_MONO_128KBPS
+                elif sample_rate == 16000:
+                    if stream_format == 'mp3':
+                        format = AudioFormat.MP3_16000HZ_MONO_128KBPS
+                    elif stream_format == 'pcm':
+                        format = AudioFormat.PCM_16000HZ_MONO_16BIT
+                    elif stream_format == 'wav':
+                        format = AudioFormat.WAV_16000HZ_MONO_16BIT
+                    else:
+                        format = AudioFormat.MP3_16000HZ_MONO_128KBPS
+                elif sample_rate == 22050:
+                    if stream_format == 'mp3':
+                        format = AudioFormat.MP3_22050HZ_MONO_256KBPS
+                    elif stream_format == 'pcm':
+                        format = AudioFormat.PCM_22050HZ_MONO_16BIT
+                    elif stream_format == 'wav':
+                        format = AudioFormat.WAV_22050HZ_MONO_16BIT
+                    else:
+                        format = AudioFormat.MP3_22050HZ_MONO_256KBPS
+                elif sample_rate == 44100:
+                    if stream_format == 'mp3':
+                        format = AudioFormat.MP3_44100HZ_MONO_256KBPS
+                    elif stream_format == 'pcm':
+                        format = AudioFormat.PCM_44100HZ_MONO_16BIT
+                    elif stream_format == 'wav':
+                        format = AudioFormat.WAV_44100HZ_MONO_16BIT
+                    else:
+                        format = AudioFormat.MP3_44100HZ_MONO_256KBPS
+                        # format = AudioFormat.MP3_44100HZ_MONO_256KBPS
+                else:
+                    format = AudioFormat.MP3_44100HZ_MONO_256KBPS
                 self.synthesizer = SpeechSynthesizer(
                     model='cosyvoice-v1',
                     # voice="longyuan",  # "longfei",
                     voice=self.voice,
                     callback=self.callback,
-                    format=AudioFormat.MP3_44100HZ_MONO_256KBPS
+                    format=format
                 )
                 self.synthesizer.call(text)
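
The sample-rate ladder above could equally be table-driven; a sketch using the same `AudioFormat` members the patch already references, with the patch's 44.1 kHz MP3 fallback:

```python
from dashscope.audio.tts_v2 import AudioFormat

AUDIO_FORMATS = {
    (8000, 'mp3'): AudioFormat.MP3_8000HZ_MONO_128KBPS,
    (8000, 'pcm'): AudioFormat.PCM_8000HZ_MONO_16BIT,
    (8000, 'wav'): AudioFormat.WAV_8000HZ_MONO_16BIT,
    (16000, 'mp3'): AudioFormat.MP3_16000HZ_MONO_128KBPS,
    (16000, 'pcm'): AudioFormat.PCM_16000HZ_MONO_16BIT,
    (16000, 'wav'): AudioFormat.WAV_16000HZ_MONO_16BIT,
    (22050, 'mp3'): AudioFormat.MP3_22050HZ_MONO_256KBPS,
    (22050, 'pcm'): AudioFormat.PCM_22050HZ_MONO_16BIT,
    (22050, 'wav'): AudioFormat.WAV_22050HZ_MONO_16BIT,
    (44100, 'mp3'): AudioFormat.MP3_44100HZ_MONO_256KBPS,
    (44100, 'pcm'): AudioFormat.PCM_44100HZ_MONO_16BIT,
    (44100, 'wav'): AudioFormat.WAV_44100HZ_MONO_16BIT,
}

def pick_format(sample_rate, stream_format):
    # fall back exactly like the if/elif ladder does
    return AUDIO_FORMATS.get((sample_rate, stream_format),
                             AudioFormat.MP3_44100HZ_MONO_256KBPS)
```
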
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index d7be6fa8..22a06bfb 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -317,7 +317,7 @@ def embedding(docs, mdl, parser_config=None, callback=None):
 
     vects = (title_w * tts + (1 - title_w) * cnts) if len(tts) == len(cnts) else cnts
 
-    assert len(vects) == len(docs)
+    # assert len(vects) == len(docs)
     vector_size = 0
     for i, d in enumerate(docs):
         v = vects[i].tolist()