Be more specific in error messages (#1409)

### What problem does this PR solve?

#918 

### Type of change

- [x] Refactoring
This commit is contained in:
KevinHuSh
2024-07-08 09:32:44 +08:00
committed by GitHub
parent dcb3fb2073
commit b3ebc66b13
9 changed files with 126 additions and 61 deletions

View File

@@ -95,14 +95,16 @@ def run():
final_ans = {"reference": [], "content": ""} final_ans = {"reference": [], "content": ""}
try: try:
canvas = Canvas(cvs.dsl, current_user.id) canvas = Canvas(cvs.dsl, current_user.id)
print(canvas)
if "message" in req: if "message" in req:
canvas.messages.append({"role": "user", "content": req["message"]}) canvas.messages.append({"role": "user", "content": req["message"]})
canvas.add_user_input(req["message"]) canvas.add_user_input(req["message"])
answer = canvas.run(stream=stream) answer = canvas.run(stream=stream)
print(canvas)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
assert answer, "Nothing. Is it over?"
if stream: if stream:
assert isinstance(answer, partial) assert isinstance(answer, partial)
@@ -116,7 +118,7 @@ def run():
yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n" yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
if "reference" in final_ans: if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"]) canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas)) cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict()) UserCanvasService.update_by_id(req["id"], cvs.to_dict())
@@ -134,7 +136,7 @@ def run():
return resp return resp
canvas.messages.append({"role": "assistant", "content": final_ans["content"]}) canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
if "reference" in final_ans: if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"]) canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas)) cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict()) UserCanvasService.update_by_id(req["id"], cvs.to_dict())

View File

@@ -121,7 +121,6 @@ class Canvas(ABC):
if desc["to"] not in cpn["downstream"]: if desc["to"] not in cpn["downstream"]:
cpn["downstream"].append(desc["to"]) cpn["downstream"].append(desc["to"])
self.path = self.dsl["path"] self.path = self.dsl["path"]
self.history = self.dsl["history"] self.history = self.dsl["history"]
self.messages = self.dsl["messages"] self.messages = self.dsl["messages"]
@@ -136,9 +135,21 @@ class Canvas(ABC):
self.dsl["answer"] = self.answer self.dsl["answer"] = self.answer
self.dsl["reference"] = self.reference self.dsl["reference"] = self.reference
self.dsl["embed_id"] = self._embed_id self.dsl["embed_id"] = self._embed_id
dsl = deepcopy(self.dsl) dsl = {
"components": {}
}
for k in self.dsl.keys():
if k in ["components"]:continue
dsl[k] = deepcopy(self.dsl[k])
for k, cpn in self.components.items(): for k, cpn in self.components.items():
dsl["components"][k]["obj"] = json.loads(str(cpn["obj"])) if k not in dsl["components"]:
dsl["components"][k] = {}
for c in cpn.keys():
if c == "obj":
dsl["components"][k][c] = json.loads(str(cpn["obj"]))
continue
dsl["components"][k][c] = deepcopy(cpn[c])
return json.dumps(dsl, ensure_ascii=False) return json.dumps(dsl, ensure_ascii=False)
def reset(self): def reset(self):
@@ -161,6 +172,9 @@ class Canvas(ABC):
except Exception as e: except Exception as e:
ans = ComponentBase.be_output(str(e)) ans = ComponentBase.be_output(str(e))
self.path[-1].append(cpn_id) self.path[-1].append(cpn_id)
if kwargs.get("stream"):
assert isinstance(ans, partial)
return ans
self.history.append(("assistant", ans.to_dict("records"))) self.history.append(("assistant", ans.to_dict("records")))
return ans return ans
@@ -190,6 +204,8 @@ class Canvas(ABC):
cpn = self.get_component(cpn_id) cpn = self.get_component(cpn_id)
if not cpn["downstream"]: break if not cpn["downstream"]: break
if self._find_loop(): raise OverflowError("Too much loops!")
if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]: if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0] switch_out = cpn["obj"].output()[1].iloc[0, 0]
assert switch_out in self.components, \ assert switch_out in self.components, \
@@ -249,3 +265,27 @@ class Canvas(ABC):
def get_embedding_model(self): def get_embedding_model(self):
return self._embed_id return self._embed_id
def _find_loop(self, max_loops=2):
path = self.path[-1][::-1]
if len(path) < 2: return False
for i in range(len(path)):
if path[i].lower().find("answer") >= 0:
path = path[:i]
break
if len(path) < 2: return False
for l in range(1, len(path) // 2):
pat = ",".join(path[0:l])
path_str = ",".join(path)
if len(pat) >= len(path_str): return False
path_str = path_str[len(pat):]
loop = max_loops
while path_str.find(pat) >= 0 and loop >= 0:
loop -= 1
path_str = path_str[len(pat):]
if loop < 0: return True
return False

View File

@@ -19,7 +19,7 @@ import json
import os import os
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import List, Dict from typing import List, Dict, Tuple, Union
import pandas as pd import pandas as pd
@@ -246,7 +246,7 @@ class ComponentParamBase(ABC):
def check_empty(param, descr): def check_empty(param, descr):
if not param: if not param:
raise ValueError( raise ValueError(
descr + " {} not supported empty value." descr + " does not support empty value."
) )
@staticmethod @staticmethod
@@ -411,13 +411,24 @@ class ComponentBase(ABC):
def _run(self, history, **kwargs): def _run(self, history, **kwargs):
raise NotImplementedError() raise NotImplementedError()
def output(self) -> pd.DataFrame: def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name) o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame): if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
if not isinstance(o, list): o = [o] if not isinstance(o, list): o = [o]
o = pd.DataFrame(o) o = pd.DataFrame(o)
if allow_partial or not isinstance(o, partial):
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
return pd.DataFrame(o if isinstance(o, list) else [o])
return self._param.output_var_name, o return self._param.output_var_name, o
outs = None
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
else: outs = oo
return self._param.output_var_name, outs
def reset(self): def reset(self):
setattr(self._param, self._param.output_var_name, None) setattr(self._param, self._param.output_var_name, None)
@@ -446,7 +457,7 @@ class ComponentBase(ABC):
if self.component_name.lower().find("answer") >= 0: if self.component_name.lower().find("answer") >= 0:
if self.get_component_name(u) in ["relevant"]: continue if self.get_component_name(u) in ["relevant"]: continue
upstream_outs.append(self._canvas.get_component(u)["obj"].output()[1]) else: upstream_outs.append(self._canvas.get_component(u)["obj"].output(allow_partial=False)[1])
break break
return pd.concat(upstream_outs, ignore_index=False) return pd.concat(upstream_outs, ignore_index=False)

View File

@@ -35,7 +35,10 @@ class CategorizeParam(GenerateParam):
def check(self): def check(self):
super().check() super().check()
self.check_empty(self.category_description, "Category examples") self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k: raise ValueError(f"[Categorize] Category name can not be empty!")
if not v["to"]: raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
def get_prompt(self): def get_prompt(self):
cate_lines = [] cate_lines = []

View File

@@ -33,31 +33,31 @@ class GenerateParam(ComponentParamBase):
super().__init__() super().__init__()
self.llm_id = "" self.llm_id = ""
self.prompt = "" self.prompt = ""
self.max_tokens = 256 self.max_tokens = 0
self.temperature = 0.1 self.temperature = 0
self.top_p = 0.3 self.top_p = 0
self.presence_penalty = 0.4 self.presence_penalty = 0
self.frequency_penalty = 0.7 self.frequency_penalty = 0
self.cite = True self.cite = True
#self.parameters = [] self.parameters = []
def check(self): def check(self):
self.check_decimal_float(self.temperature, "Temperature") self.check_decimal_float(self.temperature, "[Generate] Temperature")
self.check_decimal_float(self.presence_penalty, "Presence penalty") self.check_decimal_float(self.presence_penalty, "[Generate] Presence penalty")
self.check_decimal_float(self.frequency_penalty, "Frequency penalty") self.check_decimal_float(self.frequency_penalty, "[Generate] Frequency penalty")
self.check_positive_number(self.max_tokens, "Max tokens") self.check_nonnegative_number(self.max_tokens, "[Generate] Max tokens")
self.check_decimal_float(self.top_p, "Top P") self.check_decimal_float(self.top_p, "[Generate] Top P")
self.check_empty(self.llm_id, "LLM") self.check_empty(self.llm_id, "[Generate] LLM")
# self.check_defined_type(self.parameters, "Parameters", ["list"]) # self.check_defined_type(self.parameters, "Parameters", ["list"])
def gen_conf(self): def gen_conf(self):
return { conf = {}
"max_tokens": self.max_tokens, if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
"temperature": self.temperature, if self.temperature > 0: conf["temperature"] = self.temperature
"top_p": self.top_p, if self.top_p > 0: conf["top_p"] = self.top_p
"presence_penalty": self.presence_penalty, if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
"frequency_penalty": self.frequency_penalty, if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
} return conf
class Generate(ComponentBase): class Generate(ComponentBase):
@@ -69,7 +69,10 @@ class Generate(ComponentBase):
retrieval_res = self.get_input() retrieval_res = self.get_input()
input = "\n- ".join(retrieval_res["content"]) input = "\n- ".join(retrieval_res["content"])
for para in self._param.parameters:
cpn = self._canvas.get_component(para["component_id"])["obj"]
_, out = cpn.output(allow_partial=False)
kwargs[para["key"]] = "\n - ".join(out["content"])
kwargs["input"] = input kwargs["input"] = input
for n, v in kwargs.items(): for n, v in kwargs.items():
@@ -82,7 +85,8 @@ class Generate(ComponentBase):
if "empty_response" in retrieval_res.columns: if "empty_response" in retrieval_res.columns:
return Generate.be_output(input) return Generate.be_output(input)
ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()) ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
self._param.gen_conf())
if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns: if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
ans, idx = retrievaler.insert_citations(ans, ans, idx = retrievaler.insert_citations(ans,
@@ -90,7 +94,8 @@ class Generate(ComponentBase):
for _, ck in retrieval_res.iterrows()], for _, ck in retrieval_res.iterrows()],
[ck["vector"] [ck["vector"]
for _, ck in retrieval_res.iterrows()], for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()), LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()),
tkweight=0.7, tkweight=0.7,
vtweight=0.3) vtweight=0.3)
del retrieval_res["vector"] del retrieval_res["vector"]
@@ -116,7 +121,8 @@ class Generate(ComponentBase):
return return
answer = "" answer = ""
for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()): for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
self._param.gen_conf()):
res = {"content": ans, "reference": []} res = {"content": ans, "reference": []}
answer = ans answer = ans
yield res yield res
@@ -127,7 +133,8 @@ class Generate(ComponentBase):
for _, ck in retrieval_res.iterrows()], for _, ck in retrieval_res.iterrows()],
[ck["vector"] [ck["vector"]
for _, ck in retrieval_res.iterrows()], for _, ck in retrieval_res.iterrows()],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()), LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()),
tkweight=0.7, tkweight=0.7,
vtweight=0.3) vtweight=0.3)
doc_ids = set([]) doc_ids = set([])
@@ -152,5 +159,3 @@ class Generate(ComponentBase):
yield res yield res
self.set_output(res) self.set_output(res)

View File

@@ -32,7 +32,7 @@ class MessageParam(ComponentParamBase):
self.messages = [] self.messages = []
def check(self): def check(self):
self.check_empty(self.messages, "Message") self.check_empty(self.messages, "[Message]")
return True return True

View File

@@ -33,6 +33,8 @@ class RelevantParam(GenerateParam):
def check(self): def check(self):
super().check() super().check()
self.check_empty(self.yes, "[Relevant] 'Yes'")
self.check_empty(self.no, "[Relevant] 'No'")
def get_prompt(self): def get_prompt(self):
self.prompt = """ self.prompt = """

View File

@@ -40,10 +40,10 @@ class RetrievalParam(ComponentParamBase):
self.empty_response = "" self.empty_response = ""
def check(self): def check(self):
self.check_decimal_float(self.similarity_threshold, "Similarity threshold") self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
self.check_decimal_float(self.keywords_similarity_weight, "Keywords similarity weight") self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keywords similarity weight")
self.check_positive_number(self.top_n, "Top N") self.check_positive_number(self.top_n, "[Retrieval] Top N")
self.check_empty(self.kb_ids, "Knowledge bases") self.check_empty(self.kb_ids, "[Retrieval] Knowledge bases")
class Retrieval(ComponentBase, ABC): class Retrieval(ComponentBase, ABC):

View File

@@ -44,8 +44,10 @@ class SwitchParam(ComponentParamBase):
self.default = "" self.default = ""
def check(self): def check(self):
self.check_empty(self.conditions, "Switch conditions") self.check_empty(self.conditions, "[Switch] conditions")
self.check_empty(self.default, "Default path") self.check_empty(self.default, "[Switch] Default path")
for cond in self.conditions:
if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
def operators(self, field, op, value): def operators(self, field, op, value):
if op == "gt": if op == "gt":