Add support for TTS model (#2095)
### What problem does this PR solve?

Add support for TTS model #1853

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <cenzhedong2@126.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@@ -20,7 +20,7 @@ from api.utils.api_utils import server_error_response, get_data_error_result, va
|
||||
from api.db import StatusEnum, LLMType
|
||||
from api.db.db_models import TenantLLM
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel,CvModel
|
||||
from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel
|
||||
import requests
|
||||
import ast
|
||||
|
||||
@@ -142,6 +142,10 @@ def add_llm():
|
||||
llm_name = req["llm_name"]
|
||||
api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \
|
||||
f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}'
|
||||
elif factory == "Fish Audio":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = '{' + f'"fish_audio_ak": "{req.get("fish_audio_ak", "")}", ' \
|
||||
f'"fish_audio_refid": "{req.get("fish_audio_refid", "59cb5986671546eaa6ca8ae6f29f6d22")}"' + '}'
|
||||
else:
|
||||
llm_name = req["llm_name"]
|
||||
api_key = req.get("api_key","xxxxxxxxxxxxxxx")
|
||||
@@ -215,6 +219,15 @@ def add_llm():
|
||||
pass
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
||||
elif llm["model_type"] == LLMType.TTS:
|
||||
mdl = TTSModel[factory](
|
||||
key=llm["api_key"], model_name=llm["llm_name"], base_url=llm["api_base"]
|
||||
)
|
||||
try:
|
||||
for resp in mdl.transcription("Hello~ Ragflower!"):
|
||||
pass
|
||||
except RuntimeError as e:
|
||||
msg += f"\nFail to access model({llm['llm_name']})." + str(e)
|
||||
else:
|
||||
# TODO: check other type of models
|
||||
pass
|
||||
|
||||
@@ -410,7 +410,7 @@ def tenant_info():
|
||||
|
||||
@manager.route("/set_tenant_info", methods=["POST"])
|
||||
@login_required
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id")
|
||||
@validate_request("tenant_id", "asr_id", "embd_id", "img2txt_id", "llm_id", "tts_id")
|
||||
def set_tenant_info():
|
||||
req = request.json
|
||||
try:
|
||||
|
||||
@@ -55,6 +55,7 @@ class LLMType(StrEnum):
|
||||
SPEECH2TEXT = 'speech2text'
|
||||
IMAGE2TEXT = 'image2text'
|
||||
RERANK = 'rerank'
|
||||
TTS = 'tts'
|
||||
|
||||
|
||||
class ChatStyle(StrEnum):
|
||||
|
||||
@@ -449,6 +449,11 @@ class Tenant(DataBaseModel):
|
||||
null=False,
|
||||
help_text="default rerank model ID",
|
||||
index=True)
|
||||
tts_id = CharField(
|
||||
max_length=256,
|
||||
null=True,
|
||||
help_text="default tts model ID",
|
||||
index=True)
|
||||
parser_ids = CharField(
|
||||
max_length=256,
|
||||
null=False,
|
||||
@@ -958,6 +963,13 @@ def migrate_db():
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column("tenant","tts_id",
|
||||
CharField(max_length=256,null=True,help_text="default tts model ID",index=True))
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column('api_4_conversation', 'source',
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.settings import database_logger
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel
|
||||
from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
|
||||
from api.db import LLMType
|
||||
from api.db.db_models import DB, UserTenant
|
||||
from api.db.db_models import LLMFactories, LLM, TenantLLM
|
||||
@@ -75,6 +75,8 @@ class TenantLLMService(CommonService):
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.TTS:
|
||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
@@ -127,6 +129,14 @@ class TenantLLMService(CommonService):
|
||||
model_config["api_key"], model_config["llm_name"], lang,
|
||||
base_url=model_config["api_base"]
|
||||
)
|
||||
if llm_type == LLMType.TTS:
|
||||
if model_config["llm_factory"] not in TTSModel:
|
||||
return
|
||||
return TTSModel[model_config["llm_factory"]](
|
||||
model_config["api_key"],
|
||||
model_config["llm_name"],
|
||||
base_url=model_config["api_base"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
@@ -144,7 +154,9 @@ class TenantLLMService(CommonService):
|
||||
elif llm_type == LLMType.CHAT.value:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.RERANK:
|
||||
mdlnm = tenant.llm_id if not llm_name else llm_name
|
||||
mdlnm = tenant.rerank_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.TTS:
|
||||
mdlnm = tenant.tts_id if not llm_name else llm_name
|
||||
else:
|
||||
assert False, "LLM type error"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user