在生成对话文字时,同时在后台生成tts音频,增加朗读音色选择,增加博物馆的概况接口
This commit is contained in:
@@ -13,17 +13,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from flask import request
|
||||
from flask import request , Response, jsonify
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
from api.db import StatusEnum
|
||||
from api.db.services.dialog_service import DialogService
|
||||
from api.db.services.dialog_service import DialogService,stream_manager
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.brief_service import MesumOverviewService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import queue,time,uuid
|
||||
from threading import Lock,Thread
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
# 用户已经添加的模型 cyx 2025-01-26
|
||||
@manager.route('/get_llms', methods=['GET'])
|
||||
@@ -48,4 +55,352 @@ def my_llms(tenant_id):
|
||||
})
|
||||
return get_result(data=res)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get LLMS error {e}")
|
||||
return get_error_data_result(message=f"Get LLMS error {e}")
|
||||
|
||||
|
||||
main_antiquity="浮雕故事,绿釉刻花瓷枕函,走马灯,水晶项链"
|
||||
@manager.route('/photo/recongeText', methods=['POST'])
|
||||
@token_required
|
||||
def upload_file(tenant_id):
|
||||
if 'file' not in request.files:
|
||||
return jsonify({'error': 'No file part'}), 400
|
||||
|
||||
file = request.files['file']
|
||||
|
||||
if file.filename == '':
|
||||
return jsonify({'error': 'No selected file'}), 400
|
||||
|
||||
if file and allowed_file(file.filename):
|
||||
file_size = request.content_length
|
||||
img_base = base64.b64encode(file.read()).decode('utf-8')
|
||||
req_antique = request.form.get('antique',None)
|
||||
if req_antique is None:
|
||||
req_antique = main_antiquity
|
||||
logging.info(f"recevie photo file {file.filename} {file_size} 识别中....")
|
||||
client = ZhipuAI(api_key="5685053e23939bf82e515f9b0a3b59be.C203PF4ExLDUJUZ3") # 填写您自己的APIKey
|
||||
response = client.chat.completions.create(
|
||||
model="glm-4v-plus", # 填写需要调用的模型名称
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": img_base
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": (f"你是一名资深的博物馆知识和文物讲解专家,同时也是一名历史学家,"
|
||||
f"请识别这个图片中文字,如果字数较少,优先匹配候选中的某一文物名称,"
|
||||
f"如果字符较多,在匹配文物名称同时分析识别出的文字是不是候选中某一文物的简单介绍"
|
||||
f"你的回答有2个结果,第一个结果是是从文字进行分析出匹配文物,候选文物只能如下:{req_antique},"
|
||||
f"回答时只给出匹配的文物,不需要其他多余的文字,如果没有匹配,则不输出,"
|
||||
f",第二个结果是原始识别的所有文字"
|
||||
"2个结果输出以{ }的json格式给出,匹配文物的键值为antique,如果有多个请加序号,如:antique1,antique2,"
|
||||
f"原始数据的键值为text,输出是1个完整的JSON数据,不要有多余的前置和后置内容,确保前端能正确解析出JSON数据")
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
message = response.choices[0].message
|
||||
logging.info(message.content)
|
||||
return jsonify({'message': 'File uploaded successfully','text':message.content }), 200
|
||||
|
||||
def allowed_file(filename):
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in {'png', 'jpg', 'jpeg', 'gif'}
|
||||
|
||||
|
||||
#get_all
|
||||
|
||||
@manager.route('/mesum/list', methods=['GET'])
|
||||
@token_required
|
||||
def mesum_list(tenant_id):
|
||||
# request.args.get("id") 通过request.args.get 获取GET 方法传入的参数
|
||||
# model_type = request.args.get("type")
|
||||
try:
|
||||
res = []
|
||||
overviews=MesumOverviewService.get_all()
|
||||
for o in overviews:
|
||||
res.append(o.to_dict())
|
||||
return get_result(data=res)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get LLMS error {e}")
|
||||
|
||||
@manager.route('/mesum/set_antique', methods=['POST'])
|
||||
@token_required
|
||||
def mesum_set_antique(tenant_id):
|
||||
global main_antiquity
|
||||
# request.args.get("id") 通过request.args.get 获取GET 方法传入的参数
|
||||
req_data = request.json
|
||||
req_data_antique=req_data.get('antique',None)
|
||||
try:
|
||||
if req_data_antique:
|
||||
main_antiquity = req_data_antique
|
||||
print(main_antiquity)
|
||||
return get_result({'statusCode':200,'code':0,'message': 'antique set successfully'})
|
||||
except Exception as e:
|
||||
return get_error_data_result(message=f"Get LLMS error {e}")
|
||||
|
||||
audio_text_cache = {}
|
||||
cache_lock = Lock()
|
||||
CACHE_EXPIRE_SECONDS = 600 # 10分钟过期
|
||||
# 全角字符到半角字符的映射
|
||||
def fullwidth_to_halfwidth(s):
|
||||
full_to_half_map = {
|
||||
'!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
|
||||
'(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
|
||||
'/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
|
||||
'@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
|
||||
'{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
|
||||
'」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
|
||||
'、': '、', '・': '・', ':': ':'
|
||||
}
|
||||
return ''.join(full_to_half_map.get(char, char) for char in s)
|
||||
|
||||
def split_text_at_punctuation(text, chunk_size=100):
|
||||
# 使用正则表达式找到所有的标点符号和特殊字符
|
||||
punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
|
||||
tokens = re.split(punctuation_pattern, text)
|
||||
|
||||
# 移除空字符串
|
||||
tokens = [token for token in tokens if token]
|
||||
|
||||
# 存储最终的文本块
|
||||
chunks = []
|
||||
current_chunk = ''
|
||||
|
||||
for token in tokens:
|
||||
if len(current_chunk) + len(token) <= chunk_size:
|
||||
# 如果添加当前token后长度不超过chunk_size,则添加到当前块
|
||||
current_chunk += (token + ' ')
|
||||
else:
|
||||
# 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = token + ' '
|
||||
|
||||
# 添加最后一个块(如果有剩余)
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def extract_text_from_markdown(markdown_text):
|
||||
# 移除Markdown标题
|
||||
text = re.sub(r'#\s*[^#]+', '', markdown_text)
|
||||
# 移除内联代码块
|
||||
text = re.sub(r'`[^`]+`', '', text)
|
||||
# 移除代码块
|
||||
text = re.sub(r'```[\s\S]*?```', '', text)
|
||||
# 移除加粗和斜体
|
||||
text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
|
||||
# 移除链接
|
||||
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
||||
# 移除图片
|
||||
text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
|
||||
# 移除HTML标签
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
# 转换标点符号
|
||||
# text = re.sub(r'[^\w\s]', '', text)
|
||||
text = fullwidth_to_halfwidth(text)
|
||||
# 移除多余的空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
def clean_audio_cache():
|
||||
"""定时清理过期缓存"""
|
||||
with cache_lock:
|
||||
now = time.time()
|
||||
expired_keys = [
|
||||
k for k, v in audio_text_cache.items()
|
||||
if now - v['created_at'] > CACHE_EXPIRE_SECONDS
|
||||
]
|
||||
for k in expired_keys:
|
||||
entry = audio_text_cache.pop(k, None)
|
||||
if entry and entry.get('audio_stream'):
|
||||
entry['audio_stream'].close()
|
||||
|
||||
|
||||
def start_background_cleaner():
|
||||
"""启动后台清理线程"""
|
||||
|
||||
def cleaner_loop():
|
||||
while True:
|
||||
time.sleep(180) # 每3分钟清理一次
|
||||
clean_audio_cache()
|
||||
|
||||
cleaner_thread = Thread(target=cleaner_loop, daemon=True)
|
||||
cleaner_thread.start()
|
||||
|
||||
# 应用启动时启动清理线程
|
||||
start_background_cleaner()
|
||||
|
||||
@manager.route('/tts_stream/<session_id>')
|
||||
def tts_stream(session_id):
|
||||
def generate():
|
||||
retry_count = 0
|
||||
session = None
|
||||
count = 0;
|
||||
try:
|
||||
while retry_count < 1:
|
||||
session = stream_manager.sessions.get(session_id)
|
||||
if not session or not session['active']:
|
||||
break
|
||||
try:
|
||||
chunk = session['buffer'].get(timeout=5) # 30秒超时
|
||||
count = count + 1
|
||||
if isinstance(chunk, str) and chunk.startswith("ERROR"):
|
||||
logging.info("---tts stream error!!!!")
|
||||
yield f"data:{{'error':'{chunk[6:]}'}}\n\n"
|
||||
break
|
||||
yield chunk
|
||||
retry_count = 0 # 成功收到数据重置重试计数器
|
||||
except queue.Empty:
|
||||
retry_count += 1
|
||||
yield b'' # 保持连接
|
||||
finally:
|
||||
# 确保流结束后关闭会话
|
||||
if session:
|
||||
# 延迟关闭会话,确保所有数据已发送
|
||||
time.sleep(5) # 等待5秒确保流结束
|
||||
stream_manager.close_session(session_id)
|
||||
logging.info(f"Session {session_id} closed.")
|
||||
|
||||
resp = Response(generate(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
return resp
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
|
||||
def dialog_tts_get(chat_id, audio_stream_id):
|
||||
with cache_lock:
|
||||
tts_info = audio_text_cache.pop(audio_stream_id, None) # 取出即删除
|
||||
try:
|
||||
req = tts_info
|
||||
if not req:
|
||||
return get_error_data_result(message="Audio stream not found or expired.")
|
||||
audio_stream = req.get('audio_stream')
|
||||
tenant_id = req.get('tenant_id')
|
||||
chat_id = req.get('chat_id')
|
||||
text = req.get('text', "..")
|
||||
model_name = req.get('model_name')
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
||||
ensure_ascii=False)).encode('utf-8')
|
||||
|
||||
def generate():
|
||||
data = audio_stream.read(1024)
|
||||
while data:
|
||||
yield data
|
||||
data = audio_stream.read(1024)
|
||||
|
||||
if audio_stream:
|
||||
# 确保流的位置在开始处
|
||||
audio_stream.seek(0)
|
||||
resp = Response(generate(), mimetype="audio/mpeg")
|
||||
else:
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
return resp
|
||||
except Exception as e:
|
||||
logging.error(f"音频流传输错误: {str(e)}", exc_info=True)
|
||||
return get_error_data_result(message="音频流传输失败")
|
||||
finally:
|
||||
# 确保资源释放
|
||||
if tts_info.get('audio_stream') and not tts_info['audio_stream'].closed:
|
||||
tts_info['audio_stream'].close()
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
|
||||
@token_required
|
||||
def dialog_tts_post(tenant_id, chat_id):
|
||||
try:
|
||||
req = request.json
|
||||
if not req.get("text"):
|
||||
return get_error_data_result(message="Please input your question.")
|
||||
delay_gen_audio = req.get('delay_gen_audio', False)
|
||||
# text = extract_text_from_markdown(req.get('text'))
|
||||
text = req.get('text')
|
||||
model_name = req.get('model_name')
|
||||
audio_stream_id = req.get('audio_stream_id', None)
|
||||
if audio_stream_id is None:
|
||||
audio_stream_id = str(uuid.uuid4())
|
||||
# 在这里生成音频流并存储到内存中
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
if delay_gen_audio:
|
||||
audio_stream = None
|
||||
else:
|
||||
audio_stream = io.BytesIO()
|
||||
# 结构化缓存数据
|
||||
tts_info = {
|
||||
'text': text,
|
||||
'tenant_id': tenant_id,
|
||||
'chat_id': chat_id,
|
||||
'created_at': time.time(),
|
||||
'audio_stream': audio_stream, # 维持原有逻辑
|
||||
'model_name': req.get('model_name'),
|
||||
'delay_gen_audio': delay_gen_audio, # 明确存储状态
|
||||
audio_stream_id: audio_stream_id
|
||||
}
|
||||
|
||||
with cache_lock:
|
||||
audio_text_cache[audio_stream_id] = tts_info
|
||||
|
||||
if delay_gen_audio is False:
|
||||
try:
|
||||
"""
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
try:
|
||||
if txt is None or txt.strip() == "":
|
||||
continue
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
continue
|
||||
"""
|
||||
if text is None or text.strip() == "":
|
||||
audio_stream.write(b'\x00' * 100)
|
||||
else:
|
||||
# 确保在流的末尾写入
|
||||
audio_stream.seek(0, io.SEEK_END)
|
||||
for chunk in tts_mdl.tts(text):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
logging.info(f"--error:{e}")
|
||||
with cache_lock:
|
||||
audio_text_cache.pop(audio_stream_id, None)
|
||||
return get_error_data_result(message="get tts audio stream error.")
|
||||
|
||||
# 构建音频流URL
|
||||
audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
|
||||
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
|
||||
# 返回音频流URL
|
||||
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"请求处理失败: {str(e)}", exc_info=True)
|
||||
return get_error_data_result(message="服务器内部错误")
|
||||
|
||||
@@ -19,22 +19,22 @@ import logging
|
||||
from copy import deepcopy
|
||||
from uuid import uuid4
|
||||
from api.db import LLMType
|
||||
from flask import request, Response, jsonify
|
||||
from flask import request, Response, jsonify, stream_with_context
|
||||
from api.db.services.dialog_service import ask
|
||||
from agent.canvas import Canvas
|
||||
from api.db import StatusEnum
|
||||
from api.db.db_models import API4Conversation
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.dialog_service import DialogService, ConversationService, chat
|
||||
from api.db.services.dialog_service import DialogService, ConversationService, chat,stream_manager
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.utils import get_uuid
|
||||
from api.utils.api_utils import get_error_data_result
|
||||
from api.utils.api_utils import get_result, token_required
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
import uuid
|
||||
import queue
|
||||
|
||||
import queue,time
|
||||
from threading import Lock,Thread
|
||||
|
||||
@manager.route('/chats/<chat_id>/sessions', methods=['POST'])
|
||||
@token_required
|
||||
@@ -239,186 +239,6 @@ def completion(tenant_id, chat_id): # chat_id 和 别的文件中的dialog_id
|
||||
break
|
||||
return get_result(data=answer)
|
||||
|
||||
|
||||
# 全角字符到半角字符的映射
|
||||
|
||||
|
||||
def fullwidth_to_halfwidth(s):
|
||||
full_to_half_map = {
|
||||
'!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
|
||||
'(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
|
||||
'/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
|
||||
'@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
|
||||
'{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
|
||||
'」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
|
||||
'、': '、', '・': '・', ':': ':'
|
||||
}
|
||||
return ''.join(full_to_half_map.get(char, char) for char in s)
|
||||
|
||||
|
||||
def is_dale(s):
|
||||
full_to_half_map = {
|
||||
'!': '!', '"': '"', '#': '#', '$': '$', '%': '%', '&': '&', ''': "'",
|
||||
'(': '(', ')': ')', '*': '*', '+': '+', ',': ',', '-': '-', '.': '.',
|
||||
'/': '/', ':': ':', ';': ';', '<': '<', '=': '=', '>': '>', '?': '?',
|
||||
'@': '@', '[': '[', '\': '\\', ']': ']', '^': '^', '_': '_', '`': '`',
|
||||
'{': '{', '|': '|', '}': '}', '~': '~', '⦅': '⦅', '⦆': '⦆', '「': '「',
|
||||
'」': '」', '、': ',', '・': '.', 'ー': '-', '。': '.', '「': '「', '」': '」',
|
||||
'、': '、', '・': '・', ':': ':', '。': '.'
|
||||
}
|
||||
|
||||
|
||||
def extract_text_from_markdown(markdown_text):
|
||||
# 移除Markdown标题
|
||||
text = re.sub(r'#\s*[^#]+', '', markdown_text)
|
||||
# 移除内联代码块
|
||||
text = re.sub(r'`[^`]+`', '', text)
|
||||
# 移除代码块
|
||||
text = re.sub(r'```[\s\S]*?```', '', text)
|
||||
# 移除加粗和斜体
|
||||
text = re.sub(r'[*_]{1,3}(?=\S)(.*?\S[*_]{1,3})', '', text)
|
||||
# 移除链接
|
||||
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
||||
# 移除图片
|
||||
text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
|
||||
# 移除HTML标签
|
||||
text = re.sub(r'<[^>]+>', '', text)
|
||||
# 转换标点符号
|
||||
# text = re.sub(r'[^\w\s]', '', text)
|
||||
text = fullwidth_to_halfwidth(text)
|
||||
# 移除多余的空格
|
||||
text = re.sub(r'\s+', ' ', text).strip()
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def split_text_at_punctuation(text, chunk_size=100):
|
||||
# 使用正则表达式找到所有的标点符号和特殊字符
|
||||
punctuation_pattern = r'[\s,.!?;:\-\—\(\)\[\]{}"\'\\\/]+'
|
||||
tokens = re.split(punctuation_pattern, text)
|
||||
|
||||
# 移除空字符串
|
||||
tokens = [token for token in tokens if token]
|
||||
|
||||
# 存储最终的文本块
|
||||
chunks = []
|
||||
current_chunk = ''
|
||||
|
||||
for token in tokens:
|
||||
if len(current_chunk) + len(token) <= chunk_size:
|
||||
# 如果添加当前token后长度不超过chunk_size,则添加到当前块
|
||||
current_chunk += (token + ' ')
|
||||
else:
|
||||
# 如果长度超过chunk_size,则将当前块添加到chunks列表,并开始新块
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = token + ' '
|
||||
|
||||
# 添加最后一个块(如果有剩余)
|
||||
if current_chunk:
|
||||
chunks.append(current_chunk.strip())
|
||||
|
||||
return chunks
|
||||
|
||||
audio_text_cache = {}
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts/<audio_stream_id>', methods=['GET'])
|
||||
def dialog_tts_get(chat_id, audio_stream_id):
|
||||
tts_info = audio_text_cache.pop(audio_stream_id, None)
|
||||
req = tts_info
|
||||
if not req:
|
||||
return get_error_data_result(message="Audio stream not found or expired.")
|
||||
audio_stream = req.get('audio_stream')
|
||||
tenant_id = req.get('tenant_id')
|
||||
chat_id = req.get('chat_id')
|
||||
text = req.get('text', "..")
|
||||
model_name = req.get('model_name')
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(message="You do not own the chat")
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
|
||||
def stream_audio():
|
||||
try:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
yield ("data:" + json.dumps({"code": 500, "message": str(e),
|
||||
"data": {"answer": "**ERROR**: " + str(e)}},
|
||||
ensure_ascii=False)).encode('utf-8')
|
||||
|
||||
def generate():
|
||||
data = audio_stream.read(1024)
|
||||
while data:
|
||||
yield data
|
||||
data = audio_stream.read(1024)
|
||||
|
||||
if audio_stream:
|
||||
# 确保流的位置在开始处
|
||||
audio_stream.seek(0)
|
||||
resp = Response(generate(), mimetype="audio/mpeg")
|
||||
else:
|
||||
resp = Response(stream_audio(), mimetype="audio/mpeg")
|
||||
resp.headers.add_header("Cache-Control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/chats/<chat_id>/tts', methods=['POST'])
|
||||
@token_required
|
||||
def dialog_tts_post(tenant_id, chat_id):
|
||||
req = request.json
|
||||
if not req.get("text"):
|
||||
return get_error_data_result(message="Please input your question.")
|
||||
delay_gen_audio = req.get('delay_gen_audio', False)
|
||||
# text = extract_text_from_markdown(req.get('text'))
|
||||
text = req.get('text')
|
||||
audio_stream_id = req.get('audio_stream_id')
|
||||
# logging.info(f"request tts audio url:{text} audio_stream_id:{audio_stream_id} ")
|
||||
if audio_stream_id is None:
|
||||
audio_stream_id = str(uuid.uuid4())
|
||||
# 在这里生成音频流并存储到内存中
|
||||
model_name = req.get('model_name')
|
||||
dia = DialogService.get(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
|
||||
tts_model_name = dia.tts_id
|
||||
if model_name: tts_model_name = model_name
|
||||
logging.info(f"---tts {tts_model_name}")
|
||||
tts_mdl = LLMBundle(dia.tenant_id, LLMType.TTS, tts_model_name) # dia.tts_id)
|
||||
if delay_gen_audio:
|
||||
audio_stream = None
|
||||
else:
|
||||
audio_stream = io.BytesIO()
|
||||
audio_text_cache[audio_stream_id] = {'text': text, 'chat_id': chat_id, "tenant_id": tenant_id,
|
||||
'audio_stream': audio_stream,'model_name':model_name} # 缓存文本以便后续生成音频流
|
||||
if delay_gen_audio is False:
|
||||
try:
|
||||
"""
|
||||
for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
|
||||
try:
|
||||
if txt is None or txt.strip() == "":
|
||||
continue
|
||||
for chunk in tts_mdl.tts(txt):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
continue
|
||||
"""
|
||||
if text is None or text.strip() == "":
|
||||
audio_stream.write(b'\x00' * 100)
|
||||
else:
|
||||
for chunk in tts_mdl.tts(text):
|
||||
audio_stream.write(chunk)
|
||||
except Exception as e:
|
||||
return get_error_data_result(message="get tts audio stream error.")
|
||||
|
||||
# 构建音频流URL
|
||||
audio_stream_url = f"/chats/{chat_id}/tts/{audio_stream_id}"
|
||||
logging.info(f"--return request tts audio url {audio_stream_id} {audio_stream_url}")
|
||||
# 返回音频流URL
|
||||
return jsonify({"tts_url": audio_stream_url, "audio_stream_id": audio_stream_id})
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/completions', methods=['POST'])
|
||||
@token_required
|
||||
def agent_completion(tenant_id, agent_id):
|
||||
|
||||
Reference in New Issue
Block a user