Added a museum exhibit inventory database plus endpoints for the frontend to fetch the exhibit list and exhibit details; added support for the QWenOmni multimodal LLM (mainly for testing); added support for locally deployed LLMs (mainly for testing, on autoDL); fixed the logic and parameters for generating TTS audio and returning it to the frontend; added a check for whether the user's question retrieved any relevant chunks from the knowledge base, returning immediately with a "nothing found" notice when it did not.
@@ -25,7 +25,7 @@ from itsdangerous.url_safe import URLSafeTimedSerializer as Serializer
 from flask_login import UserMixin
 from playhouse.migrate import MySQLMigrator, PostgresqlMigrator, migrate
 from peewee import (
-    BigIntegerField, BooleanField, CharField,
+    BigIntegerField, BooleanField, CharField, AutoField,
     CompositeKey, IntegerField, TextField, FloatField, DateTimeField,
     Field, Model, Metadata
 )
@@ -1025,6 +1025,23 @@ class MesumOverview(DataBaseModel):
     class Meta:
         db_table = "mesum_overview"

+# added by cyx for mesum_antique
+class MesumAntique(DataBaseModel):
+    sn = CharField(max_length=100, null=True)
+    label = CharField(max_length=100, null=True)
+    description = TextField(null=True)
+    category = CharField(max_length=100, null=True)
+    group = CharField(max_length=100, null=True)
+    background = TextField(null=True)
+    value = TextField(null=True)
+    discovery = TextField(null=True)
+    id = AutoField(primary_key=True)
+    mesum_id = CharField(max_length=100, null=True)
+    combined = TextField(null=True)
+
+    class Meta:
+        db_table = 'mesum_antique'
+# -------------------------------------------
 def migrate_db():
     with DB.transaction():

api/db/services/antique_service.py (new file, 88 lines)
@@ -0,0 +1,88 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from datetime import datetime

import peewee
from werkzeug.security import generate_password_hash, check_password_hash

from api.db import UserTenantRole
from api.db.db_models import DB, UserTenant
from api.db.db_models import User, Tenant, MesumAntique
from api.db.services.common_service import CommonService
from api.utils import get_uuid, get_format_time, current_timestamp, datetime_format
from api.db import StatusEnum


class MesumAntiqueService(CommonService):
    model = MesumAntique

    @classmethod
    @DB.connection_context()
    def get_by_mesum_id(cls, mesum_id):
        objs = cls.query(mesum_id=mesum_id)
        return objs

    @classmethod
    @DB.connection_context()
    def get_all_categories(cls):
        # query all distinct, non-empty categories
        categories = [category.category for category in cls.model.select(cls.model.category).distinct().execute() if category.category]
        return categories

    @classmethod
    @DB.connection_context()
    def get_all_labels(cls):
        # query all distinct, non-empty labels
        labels = [label.label for label in cls.model.select(cls.model.label).distinct().execute() if label.label]
        return labels

    @classmethod
    @DB.connection_context()
    def get_labels_ext(cls, mesum_id):
        # filter by mesum_id and exclude rows with an empty category
        query = cls.model.select().where(
            (cls.model.mesum_id == mesum_id) &
            (cls.model.category != "")
        ).order_by(cls.model.category)

        # group the results by category
        grouped_data = {}
        for obj in query.dicts().execute():
            category = obj['category']
            if category not in grouped_data:
                grouped_data[category] = []
            grouped_data[category].append({
                'id': obj['id'],
                'label': obj['label']
            })

        return grouped_data

    @classmethod
    @DB.connection_context()
    def get_antique_by_id(cls, mesum_id, antique_id):
        query = cls.model.select().where(
            (cls.model.mesum_id == mesum_id) &
            (cls.model.id == antique_id)
        )

        data = []
        for obj in query.dicts().execute():
            data.append(obj)
        if len(data) > 0:
            data = data[0]
        return data
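A minimal usage sketch of the new service (not part of the commit); the museum id here is a made-up placeholder value:

    from api.db.services.antique_service import MesumAntiqueService

    # {category: [{"id": ..., "label": ...}, ...]} for one museum
    labels_by_category = MesumAntiqueService.get_labels_ext("demo-mesum-id")
    # full record of a single exhibit, or [] when nothing matches
    antique = MesumAntiqueService.get_antique_by_id("demo-mesum-id", 1)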
@@ -33,9 +33,21 @@ from rag.nlp.search import index_name
 from rag.utils import rmSpace, num_tokens_from_string, encoder
 from api.utils.file_utils import get_project_base_directory
 from peewee import fn
-import threading, queue, uuid, time
+import threading, queue, uuid, time, array
 from concurrent.futures import ThreadPoolExecutor

+def audio_fade_in(audio_data, fade_length):
+    # assume the audio data is 16-bit mono PCM
+    # convert the binary data into an array of integers
+    samples = array.array('h', audio_data)
+
+    # apply a fade-in to the first fade_length samples
+    for i in range(fade_length):
+        fade_factor = i / fade_length
+        samples[i] = int(samples[i] * fade_factor)
+
+    # convert the integer array back to binary data
+    return samples.tobytes()
+

 class StreamSessionManager:
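A quick sanity-check sketch for the fade-in helper (not part of the commit), using a synthetic constant-amplitude 16-bit signal:

    import array

    pcm = array.array('h', [1000] * 2048).tobytes()  # fake 16-bit mono PCM
    faded = audio_fade_in(pcm, 1024)
    # the first 1024 samples now ramp linearly from 0 up to the original level,
    # which suppresses the audible click at the start of the first WAV chunk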
@@ -44,17 +56,20 @@ class StreamSessionManager:
         self.lock = threading.Lock()
         self.executor = ThreadPoolExecutor(max_workers=30)  # fixed-size thread pool
         self.gc_interval = 300  # run cleanup every 5 minutes

-    def create_session(self, tts_model):
+        self.gc_tts = 3  # 3s
+    def create_session(self, tts_model, sample_rate=8000, stream_format='mp3'):
         session_id = str(uuid.uuid4())
         with self.lock:
             self.sessions[session_id] = {
                 'tts_model': tts_model,
-                'buffer': queue.Queue(maxsize=100),  # thread-safe queue
+                'buffer': queue.Queue(maxsize=300),  # thread-safe queue
                 'task_queue': queue.Queue(),
                 'active': True,
                 'last_active': time.time(),
-                'audio_chunk_count': 0
+                'audio_chunk_count': 0,
+                'finished': threading.Event(),  # completion event
+                'sample_rate': sample_rate,
+                'stream_format': stream_format
             }
             # start the task-processing thread
             threading.Thread(target=self._process_tasks, args=(session_id,), daemon=True).start()
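A rough sketch of how a caller could drive a session, based only on the methods visible in this diff (not part of the commit); tts_mdl stands in for the TTS model bundle:

    manager = StreamSessionManager()
    sid = manager.create_session(tts_mdl, sample_rate=16000, stream_format='wav')
    manager.append_text(sid, "一段待合成的文本")
    chunk = manager.sessions[sid]['buffer'].get()  # blocks until audio bytes (or an "ERROR:..." string) arrive
    manager.close_session(sid)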
@@ -98,6 +113,9 @@ class StreamSessionManager:
                 if time.time() - session['last_active'] > self.gc_interval:
                     self.close_session(session_id)
                     break
+                if time.time() - session['last_active'] > self.gc_tts:
+                    session['finished'].set()
+                    break

             except Exception as e:
                 logging.error(f"Task processing error: {str(e)}")
@@ -107,22 +125,36 @@ class StreamSessionManager:
         session = self.sessions.get(session_id)
         if not session: return
         # logging.info(f"_generate_audio:{text}")
+        first_chunk = True
         try:
-            for chunk in session['tts_model'].tts(text):
-                session['buffer'].put(chunk)
+            for chunk in session['tts_model'].tts(text, session['sample_rate'], session['stream_format']):
+                if session['stream_format'] == 'wav':
+                    if first_chunk:
+                        chunk_len = len(chunk)
+                        if chunk_len > 2048:
+                            session['buffer'].put(audio_fade_in(chunk, 1024))
+                        else:
+                            session['buffer'].put(audio_fade_in(chunk, chunk_len))
+                        first_chunk = False
+                    else:
+                        session['buffer'].put(chunk)
+                else:
+                    session['buffer'].put(chunk)
                 session['last_active'] = time.time()
                 session['audio_chunk_count'] = session['audio_chunk_count'] + 1
             logging.info(f"转换结束!!! {session['audio_chunk_count'] }")
         except Exception as e:
             session['buffer'].put(f"ERROR:{str(e)}")
             logging.info(f"--_generate_audio--error {str(e)}")


     def close_session(self, session_id):
         with self.lock:
             if session_id in self.sessions:
                 # mark the session as inactive
                 self.sessions[session_id]['active'] = False
-                # clean up resources after a 30-second delay
-                threading.Timer(10, self._clean_session, args=[session_id]).start()
+                # clean up resources after a 2-second delay
+                threading.Timer(1, self._clean_session, args=[session_id]).start()

     def _clean_session(self, session_id):
         with self.lock:
@@ -297,7 +329,6 @@ def validate_and_sanitize_tts_input(delta_ans, max_length=3000):
     # 5. check the length
     if len(delta_ans) == 0 or len(delta_ans) > max_length:
         return False, ""
-
     # if all checks pass, return the valid flag together with the sanitized text
     return True, delta_ans

@@ -339,12 +370,6 @@ def find_split_position(text):
     if date_pattern:
         return date_pattern.end()

-    # avoid splitting common phrases
-    for phrase in ["青少年", "博物馆", "参观"]:
-        idx = text.rfind(phrase)
-        if idx != -1 and idx + len(phrase) <= len(text):
-            return idx + len(phrase)
-
     return None

# Manage the text buffer: split it dynamically by semantic rules and return the semantically complete part that is ready to process
@@ -353,9 +378,7 @@ def process_buffer(chunk_buffer, force_flush=False):
     current_text = "".join(chunk_buffer)
     if not current_text:
         return "", []
-
     split_pos = find_split_position(current_text)
-
     # forced-flush logic
     if force_flush or len(current_text) >= MAX_BUFFER_LEN:
         # even when force-flushing, still try to find a proper split point
@@ -366,7 +389,7 @@ def process_buffer(chunk_buffer, force_flush=False):
     if split_pos is not None and split_pos > 0:
         return current_text[:split_pos], [current_text[split_pos:]]

-    return "", chunk_buffer
+    return None, chunk_buffer

def chat(dialog, messages, stream=True, **kwargs):
    assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
@@ -421,6 +444,8 @@ def chat(dialog, messages, stream=True, **kwargs):
     else:
         tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS, dialog.tts_id)

+    tts_sample_rate = kwargs.get("tts_sample_rate", 8000)        # defaults to 8 kHz
+    tts_stream_format = kwargs.get("tts_stream_format", "mp3")   # defaults to mp3
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
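A hedged sketch of how a caller could pass the new TTS keyword arguments into chat(); dialog and messages stand in for real objects, and the values shown are arbitrary (the defaults are 8000 Hz and "mp3"):

    answer_stream = chat(dialog, messages, stream=True,
                         tts_sample_rate=16000,     # forwarded to the TTS model
                         tts_stream_format="wav",   # "mp3" or "wav"
                         tts_disable=False)         # True suppresses audio generation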
@@ -465,20 +490,21 @@ def chat(dialog, messages, stream=True, **kwargs):
                                         doc_ids=attachments,
                                         top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl)
         knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
-        logging.debug(
-            "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+        logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
+        # print the history
+        # logging.info("{}->{}".format(" ".join(questions), "\n->".join(knowledges)))

     retrieval_tm = timer()

     if not knowledges and prompt_config.get("empty_response"):
         empty_res = prompt_config["empty_response"]
-        yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
+        yield {"answer": empty_res, "reference": kbinfos, "audio_binary":
+               tts(tts_mdl, empty_res, sample_rate=tts_sample_rate, stream_format=tts_stream_format)}
         return {"answer": prompt_config["empty_response"], "reference": kbinfos}

     kwargs["knowledge"] = "\n\n------\n\n".join(knowledges)
     gen_conf = dialog.llm_setting

     msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]

     msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
                 for m in messages if m["role"] != "system"])
     used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
@@ -502,6 +528,9 @@ def chat(dialog, messages, stream=True, **kwargs):
                 embd_mdl,
                 tkweight=1 - dialog.vector_similarity_weight,
                 vtweight=dialog.vector_similarity_weight)
+            # The step above sometimes inserts strings like ##0$$ ##1$$ into the answer; they need to be removed
+            # cyx 20250407
+            answer = re.sub(r'##\d+\$\$', '', answer).strip()  # strip ##0$$-style markers and extra whitespace
             idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
             recall_docs = [
                 d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
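A small illustration of what that substitution removes (not part of the commit):

    import re

    re.sub(r'##\d+\$\$', '', "这件展品##0$$出土于商代##12$$ ").strip()
    # -> '这件展品出土于商代'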
@@ -512,7 +541,6 @@ def chat(dialog, messages, stream=True, **kwargs):
         for c in refs["chunks"]:
             if c.get("vector"):
                 del c["vector"]
-
         if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
             answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
         done_tm = timer()
@@ -525,83 +553,92 @@ def chat(dialog, messages, stream=True, **kwargs):
         last_ans = ""
         answer = ""
         # create the TTS session (initialized up front)
-        tts_session_id = stream_manager.create_session(tts_mdl)
+        tts_session_id = stream_manager.create_session(tts_mdl, sample_rate=tts_sample_rate, stream_format=tts_stream_format)
         audio_url = f"/tts_stream/{tts_session_id}"
         first_chunk = True
         chunk_buffer = []  # text buffer
         last_flush_time = time.time()  # initialize the timestamp
-        for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
-            answer = ans
-            delta_ans = ans[len(last_ans):]
-            if num_tokens_from_string(delta_ans) < 24:
-                continue
-
-            last_ans = answer
-            # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-            # cyx 2024 12 04: fix: calling tts with an empty delta_ans raised an error
-            tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
-            # cyx 2025 01 18: if the frontend passes tts_disable, no TTS audio is generated for it (no audio_binary)
-            if kwargs.get('tts_disable'):
-                tts_input_is_valid = False
-            if tts_input_is_valid:
-                # buffer the text until punctuation is reached
-                chunk_buffer.append(sanitized_text)
-                # process the buffered content
-                while True:
-                    # decide whether a forced flush is needed
-                    force = time.time() - last_flush_time > FLUSH_TIMEOUT
-                    to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
-
-                    if not to_send:
-                        break
-
-                    # send the valid content
-                    stream_manager.append_text(tts_session_id, to_send)
-                    chunk_buffer = remaining
-                    last_flush_time = time.time()
-            """
-            if tts_input_is_valid:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
-            else:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-            """
-            # return the audio URL with the first chunk
-            if first_chunk:
-                yield {
-                    "answer": answer,
-                    "delta_ans": sanitized_text,
-                    "audio_stream_url": audio_url,
-                    "session_id": tts_session_id,
-                    "reference": {}
-                }
-                first_chunk = False
-            else:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-
-        delta_ans = answer[len(last_ans):]
-        if delta_ans:
-            # stream_manager.append_text(tts_session_id, delta_ans)
-            # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
-            # cyx 2024 12 04: fix: calling tts with an empty delta_ans raised an error
-            tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
-            if kwargs.get('tts_disable'):  # cyx 2025 01 18: if the frontend passes tts_disable, no TTS audio is generated (no audio_binary)
-                tts_input_is_valid = False
-            if tts_input_is_valid:
-                # 20250221: generate the audio data on the backend
-                chunk_buffer.append(sanitized_text)
-                stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
-            yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-            """
-            if tts_input_is_valid:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
-            else:
-                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
-            """
-        yield decorate_answer(answer)
+        # Handle the case where nothing relevant was found in the knowledge base first. cyx 20250323
+        if not kwargs["knowledge"] or kwargs["knowledge"] == "" or len(kwargs["knowledge"]) < 4:
+            stream_manager.append_text(tts_session_id, "未找到相关内容")
+            yield {
+                "answer": "未找到相关内容",
+                "delta_ans": "未找到相关内容",
+                "session_id": tts_session_id,
+                "reference": {},
+                "audio_stream_url": audio_url,
+                "sample_rate": tts_sample_rate,
+                "stream_format": tts_stream_format,
+            }
+        else:
+            for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
+                answer = ans
+                delta_ans = ans[len(last_ans):]
+                if num_tokens_from_string(delta_ans) < 24:
+                    continue
+                last_ans = answer
+                # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+                # cyx 2024 12 04: fix: calling tts with an empty delta_ans raised an error
+                tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
+                # cyx 2025 01 18: if the frontend passes tts_disable, no TTS audio is generated for it (no audio_binary)
+                if kwargs.get('tts_disable'):
+                    tts_input_is_valid = False
+                if tts_input_is_valid:
+                    # buffer the text until punctuation is reached
+                    chunk_buffer.append(sanitized_text)
+                    # process the buffered content
+                    while True:
+                        # decide whether a forced flush is needed
+                        force = time.time() - last_flush_time > FLUSH_TIMEOUT
+                        to_send, remaining = process_buffer(chunk_buffer, force_flush=force)
+                        if not to_send:
+                            break
+
+                        # send the valid content
+                        stream_manager.append_text(tts_session_id, to_send)
+                        chunk_buffer = remaining
+                        last_flush_time = time.time()
+                """
+                if tts_input_is_valid:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
+                else:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+                """
+                # return the audio URL with the first chunk
+                if first_chunk:
+                    yield {
+                        "answer": answer,
+                        "delta_ans": sanitized_text,
+                        "session_id": tts_session_id,
+                        "reference": {},
+                        "audio_stream_url": audio_url,
+                        "sample_rate": tts_sample_rate,
+                        "stream_format": tts_stream_format,
+                    }
+                    first_chunk = False
+                else:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+
+            delta_ans = answer[len(last_ans):]
+            if delta_ans:
+                # stream_manager.append_text(tts_session_id, delta_ans)
+                # yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
+                # cyx 2024 12 04: fix: calling tts with an empty delta_ans raised an error
+                tts_input_is_valid, sanitized_text = validate_and_sanitize_tts_input(delta_ans)
+                if kwargs.get('tts_disable'):  # cyx 2025 01 18: if the frontend passes tts_disable, no TTS audio is generated (no audio_binary)
+                    tts_input_is_valid = False
+                if tts_input_is_valid:
+                    # 20250221: generate the audio data on the backend
+                    chunk_buffer.append(sanitized_text)
+                    stream_manager.append_text(tts_session_id, ''.join(chunk_buffer))
+                yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+                """
+                if tts_input_is_valid:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}, "audio_binary": tts(tts_mdl, sanitized_text)}
+                else:
+                    yield {"answer": answer, "delta_ans": sanitized_text, "reference": {}}
+                """
+            yield decorate_answer(answer)

     else:
         answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
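A hedged sketch of how a consumer of this generator might pick up the streaming-audio fields; dialog and messages stand in for real objects, and only keys that actually appear in the yields above are read:

    for piece in chat(dialog, messages, stream=True, tts_stream_format="wav"):
        if "audio_stream_url" in piece:            # present on the first streamed chunk
            audio_url = piece["audio_stream_url"]  # e.g. "/tts_stream/<session_id>"
            sample_rate = piece.get("sample_rate")
            stream_format = piece.get("stream_format")
        text_so_far = piece.get("answer", "")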
@@ -611,7 +648,7 @@ def chat(dialog, messages, stream=True, **kwargs):
         if kwargs.get('tts_disable'):  # cyx 2025 01 18: if the frontend passes tts_disable, no TTS audio is generated (no audio_binary)
             tts_input_is_valid = False
         else:
-            res["audio_binary"] = tts(tts_mdl, answer)
+            res["audio_binary"] = tts(tts_mdl, answer, tts_sample_rate, tts_stream_format)
         yield res

@@ -899,10 +936,10 @@ Output: What's the weather in Rochester on {tomorrow}?
     return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]


-def tts(tts_mdl, text):
+def tts(tts_mdl, text, sample_rate=8000, stream_format="mp3"):
     if not tts_mdl or not text: return
     bin = b""
-    for chunk in tts_mdl.tts(text):
+    for chunk in tts_mdl.tts(text, sample_rate, stream_format):
         bin += chunk
     return binascii.hexlify(bin).decode("utf-8")

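A short sketch of turning the hex string returned by tts() back into raw audio bytes (not part of the commit):

    import binascii

    hex_audio = tts(tts_mdl, "欢迎参观博物馆", sample_rate=16000, stream_format="wav")
    audio_bytes = binascii.unhexlify(hex_audio) if hex_audio else b""  # tts() returns None for empty input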
@@ -248,8 +248,8 @@ class LLMBundle(object):
                 "LLMBundle.transcription can't update token usage for {}/SEQUENCE2TXT used_tokens: {}".format(self.tenant_id, used_tokens))
         return txt

-    def tts(self, text):  # tts invocation, cyx
-        for chunk in self.mdl.tts(text):
+    def tts(self, text, sample_rate=8000, stream_format='mp3'):  # tts invocation, cyx
+        for chunk in self.mdl.tts(text, sample_rate=sample_rate, stream_format=stream_format):
             if isinstance(chunk, int):
                 if not TenantLLMService.increase_usage(
                         self.tenant_id, self.llm_type, chunk, self.llm_name):