Add bce-embedding and fastembed (#383)
### What problem does this PR solve? Issue link:#326 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@@ -80,8 +80,12 @@ def chat(dialog, messages, **kwargs):
|
||||
raise LookupError("LLM(%s) not found" % dialog.llm_id)
|
||||
max_tokens = 1024
|
||||
else: max_tokens = llm[0].max_tokens
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
embd_nms = list(set([kb.embd_id for kb in kbs]))
|
||||
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
|
||||
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"]
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING)
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embd_nms[0])
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
|
||||
@@ -66,7 +66,7 @@ class TenantLLMService(CommonService):
|
||||
raise LookupError("Tenant not found")
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
mdlnm = tenant.embd_id
|
||||
mdlnm = tenant.embd_id if not llm_name else llm_name
|
||||
elif llm_type == LLMType.SPEECH2TEXT.value:
|
||||
mdlnm = tenant.asr_id
|
||||
elif llm_type == LLMType.IMAGE2TEXT.value:
|
||||
@@ -77,9 +77,14 @@ class TenantLLMService(CommonService):
|
||||
assert False, "LLM type error"
|
||||
|
||||
model_config = cls.get_api_key(tenant_id, mdlnm)
|
||||
if model_config: model_config = model_config.to_dict()
|
||||
if not model_config:
|
||||
raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
model_config = model_config.to_dict()
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
llm = LLMService.query(llm_name=llm_name)
|
||||
if llm and llm[0].fid in ["QAnything", "FastEmbed"]:
|
||||
model_config = {"llm_factory": llm[0].fid, "api_key":"", "llm_name": llm_name, "api_base": ""}
|
||||
if not model_config: raise LookupError("Model({}) not authorized".format(mdlnm))
|
||||
|
||||
if llm_type == LLMType.EMBEDDING.value:
|
||||
if model_config["llm_factory"] not in EmbeddingModel:
|
||||
return
|
||||
|
||||
@@ -41,7 +41,7 @@ class TaskService(CommonService):
|
||||
Document.size,
|
||||
Knowledgebase.tenant_id,
|
||||
Knowledgebase.language,
|
||||
Tenant.embd_id,
|
||||
Knowledgebase.embd_id,
|
||||
Tenant.img2txt_id,
|
||||
Tenant.asr_id,
|
||||
cls.model.update_time]
|
||||
|
||||
Reference in New Issue
Block a user