add Moonshot, debug my_llm (#126)
This commit is contained in:
@@ -309,13 +309,13 @@ def use_sql(question, field_map, tenant_id, chat_mdl):
|
||||
# compose markdown table
|
||||
clmns = "|"+"|".join([re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in clmn_idx]) + ("|原文|" if docid_idx and docid_idx else "|")
|
||||
line = "|"+"|".join(["------" for _ in range(len(clmn_idx))]) + ("|------|" if docid_idx and docid_idx else "")
|
||||
line = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}\|", "|", line)
|
||||
rows = ["|"+"|".join([rmSpace(str(r[i])) for i in clmn_idx]).replace("None", " ") + "|" for r in tbl["rows"]]
|
||||
if not docid_idx or not docnm_idx:
|
||||
chat_logger.warning("SQL missing field: " + sql)
|
||||
return "\n".join([clmns, line, "\n".join(rows)]), []
|
||||
|
||||
rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)])
|
||||
rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows)
|
||||
docid_idx = list(docid_idx)[0]
|
||||
docnm_idx = list(docnm_idx)[0]
|
||||
return "\n".join([clmns, line, rows]), [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]]
|
||||
|
||||
@@ -39,36 +39,40 @@ def factories():
|
||||
def set_api_key():
|
||||
req = request.json
|
||||
# test if api key works
|
||||
chat_passed = False
|
||||
factory = req["llm_factory"]
|
||||
msg = ""
|
||||
for llm in LLMService.query(fid=req["llm_factory"]):
|
||||
for llm in LLMService.query(fid=factory):
|
||||
if llm.model_type == LLMType.EMBEDDING.value:
|
||||
mdl = EmbeddingModel[req["llm_factory"]](
|
||||
mdl = EmbeddingModel[factory](
|
||||
req["api_key"], llm.llm_name)
|
||||
try:
|
||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
||||
if len(arr[0]) == 0 or tc ==0: raise Exception("Fail")
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access embedding model({llm.llm_name}) using this api key."
|
||||
elif llm.model_type == LLMType.CHAT.value:
|
||||
mdl = ChatModel[req["llm_factory"]](
|
||||
elif not chat_passed and llm.model_type == LLMType.CHAT.value:
|
||||
mdl = ChatModel[factory](
|
||||
req["api_key"], llm.llm_name)
|
||||
try:
|
||||
m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
||||
if not tc: raise Exception(m)
|
||||
chat_passed = True
|
||||
except Exception as e:
|
||||
msg += f"\nFail to access model({llm.llm_name}) using this api key." + str(e)
|
||||
|
||||
if msg: return get_data_error_result(retmsg=msg)
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": req["llm_factory"],
|
||||
"api_key": req["api_key"]
|
||||
}
|
||||
for n in ["model_type", "llm_name"]:
|
||||
if n in req: llm[n] = req[n]
|
||||
|
||||
TenantLLMService.filter_update([TenantLLM.tenant_id==llm["tenant_id"], TenantLLM.llm_factory==llm["llm_factory"]], llm)
|
||||
if not TenantLLMService.filter_update([TenantLLM.tenant_id==current_user.id, TenantLLM.llm_factory==factory], llm):
|
||||
for llm in LLMService.query(fid=factory):
|
||||
TenantLLMService.save(tenant_id=current_user.id, llm_factory=factory, llm_name=llm.llm_name, model_type=llm.model_type, api_key=req["api_key"])
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user