refactor retieval_test, add SQl retrieval methods (#61)

This commit is contained in:
KevinHuSh
2024-02-08 17:01:01 +08:00
committed by GitHub
parent 0a903c7714
commit 5e0a689c43
16 changed files with 238 additions and 74 deletions

View File

@@ -19,7 +19,6 @@ import dashscope
from openai import OpenAI
from FlagEmbedding import FlagModel
import torch
import os
import numpy as np
from rag.utils import num_tokens_from_string
@@ -114,4 +113,21 @@ class QWenEmbed(Base):
input=text[:2048],
text_type="query"
)
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
return np.array(resp["output"]["embeddings"][0]["embedding"]), resp["usage"]["input_tokens"]
from zhipuai import ZhipuAI
class ZhipuEmbed(Base):
def __init__(self, key, model_name="embedding-2"):
self.client = ZhipuAI(api_key=key)
self.model_name = model_name
def encode(self, texts: list, batch_size=32):
res = self.client.embeddings.create(input=texts,
model=self.model_name)
return np.array([d.embedding for d in res.data]), res.usage.total_tokens
def encode_queries(self, text):
res = self.client.embeddings.create(input=text,
model=self.model_name)
return np.array(res["data"][0]["embedding"]), res.usage.total_tokens