Added support for Baichuan LLM (#934)
### What problem does this PR solve? - Added support for Baichuan LLM ### Type of change - [x] New Feature (non-breaking change which adds functionality) Co-authored-by: 海贼宅 <stu_xyx@163.com>
This commit is contained in:
@@ -95,6 +95,84 @@ class DeepSeekChat(Base):
|
||||
super().__init__(key, model_name, base_url)
|
||||
|
||||
|
||||
class BaiChuanChat(Base):
    """Chat client for Baichuan's OpenAI-compatible endpoint.

    Reuses the OpenAI SDK client set up by ``Base`` and always enables
    Baichuan's built-in ``web_search`` tool (performance-first mode) via
    ``extra_body``.
    """

    def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
        # Callers may pass an empty/None base_url; fall back to the official endpoint.
        if not base_url:
            base_url = "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, base_url)

    @staticmethod
    def _format_params(params):
        # Keep only the sampling parameters Baichuan accepts, with the
        # defaults this integration ships (0.3 / 2048 / 0.85).
        return {
            "temperature": params.get("temperature", 0.3),
            "max_tokens": params.get("max_tokens", 2048),
            "top_p": params.get("top_p", 0.85),
        }

    def chat(self, system, history, gen_conf):
        """Run one non-streaming completion.

        :param system: optional system prompt, prepended to ``history`` in place.
        :param history: list of ``{"role", "content"}`` messages (mutated when
            ``system`` is given — consistent with the sibling chat classes).
        :param gen_conf: generation config dict; see ``_format_params``.
        :return: ``(answer_text, total_tokens)``; on API failure the answer is
            an ``**ERROR**:`` string and the token count is 0.
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                **self._format_params(gen_conf))
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                # Truncated by max_tokens — tell the user, in their language.
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Run one streaming completion, yielding the growing answer text.

        Yields the accumulated answer string after each content chunk and,
        as the final item, the total token count (an ``int``).
        On error, yields the partial answer with an ``**ERROR**:`` suffix.
        """
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                extra_body={
                    "tools": [{
                        "type": "web_search",
                        "web_search": {
                            "enable": True,
                            "search_mode": "performance_first"
                        }
                    }]
                },
                stream=True,
                **self._format_params(gen_conf))
            for resp in response:
                # BUG FIX (two issues in the original):
                # 1. `resp.usage` is a CompletionUsage object in the OpenAI SDK,
                #    not a dict — `.get('total_tokens', 0)` raised AttributeError,
                #    which the broad except below masked as an **ERROR** yield.
                # 2. The usage read sat below the empty-content `continue`, and the
                #    final "stop" chunk usually has usage but no delta content, so
                #    total_tokens was never recorded. Record it first, guarded,
                #    since usage may be absent on intermediate chunks.
                if resp.choices[0].finish_reason == "stop" and resp.usage:
                    total_tokens = resp.usage.total_tokens
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans

        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
|
||||
|
||||
|
||||
class QWenChat(Base):
|
||||
def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
|
||||
import dashscope
|
||||
|
||||
Reference in New Issue
Block a user