refactor retieval_test, add SQl retrieval methods (#61)
This commit is contained in:
@@ -19,7 +19,6 @@ import dashscope
|
||||
from openai import OpenAI
|
||||
from FlagEmbedding import FlagModel
|
||||
import torch
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from rag.utils import num_tokens_from_string
|
||||
@@ -114,4 +113,21 @@ class QWenEmbed(Base):
|
||||
input=text[:2048],
|
||||
text_type="query"
|
||||
)
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
|
||||
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
|
||||
|
||||
|
||||
from zhipuai import ZhipuAI
|
||||
class ZhipuEmbed(Base):
|
||||
def __init__(self, key, model_name="embedding-2"):
|
||||
self.client = ZhipuAI(api_key=key)
|
||||
self.model_name = model_name
|
||||
|
||||
def encode(self, texts: list, batch_size=32):
|
||||
res = self.client.embeddings.create(input=texts,
|
||||
model=self.model_name)
|
||||
return np.array([d.embedding for d in res.data]), res.usage.total_tokens
|
||||
|
||||
def encode_queries(self, text):
|
||||
res = self.client.embeddings.create(input=text,
|
||||
model=self.model_name)
|
||||
return np.array(res["data"][0]["embedding"]), res.usage.total_tokens
|
||||
Reference in New Issue
Block a user