@@ -2,7 +2,7 @@
import json
import re
from elasticsearch_dsl import Q , Search , A
from typing import List , Optional , Tuple , Dict , Union
from typing import List , Optional , Dict , Union
from dataclasses import dataclass
from rag . settings import es_logger
@@ -20,6 +20,8 @@ class Dealer:
self . qryr . flds = [
" title_tks^10 " ,
" title_sm_tks^5 " ,
" important_kwd^30 " ,
" important_tks^20 " ,
" content_ltks^2 " ,
" content_sm_ltks " ]
self . es = es
@@ -38,10 +40,10 @@ class Dealer:
def _vector ( self , txt , emb_mdl , sim = 0.8 , topk = 10 ) :
qv , c = emb_mdl . encode_queries ( txt )
return {
" field " : " q_ %d _vec " % len ( qv ) ,
" field " : " q_ %d _vec " % len ( qv ) ,
" k " : topk ,
" similarity " : sim ,
" num_candidates " : topk * 2 ,
" num_candidates " : topk * 2 ,
" query_vector " : qv
}
@@ -53,16 +55,18 @@ class Dealer:
if req . get ( " doc_ids " ) :
bqry . filter . append ( Q ( " terms " , doc_id = req [ " doc_ids " ] ) )
if " available_int " in req :
if req [ " available_int " ] == 0 : bqry . filter . append ( Q ( " range " , available_int = { " lt " : 1 } ) )
else : bqry . filter . append ( Q ( " bool " , must_not = Q ( " range" , available_int = { " lt " : 1 } ) ) )
if req [ " available_int " ] == 0 :
bqry . filter . append ( Q ( " range " , available_int = { " lt " : 1 } ) )
else :
bqry . filter . append ( Q ( " bool " , must_not = Q ( " range " , available_int = { " lt " : 1 } ) ) )
bqry . boost = 0.05
s = Search ( )
pg = int ( req . get ( " page " , 1 ) ) - 1
ps = int ( req . get ( " size " , 1000 ) )
src = req . get ( " fields " , [ " docnm_kwd " , " content_ltks " , " kb_id " , " img_id " ,
" image_id " , " doc_id " , " q_512_vec " , " q_768_vec " ,
" q_1024_vec " , " q_1536_vec " , " available_int " ] )
src = req . get ( " fields " , [ " docnm_kwd " , " content_ltks " , " kb_id " , " img_id " ,
" image_id " , " doc_id " , " q_512_vec " , " q_768_vec " ,
" q_1024_vec " , " q_1536_vec " , " available_int " ] )
s = s . query ( bqry ) [ pg * ps : ( pg + 1 ) * ps ]
s = s . highlight ( " content_ltks " )
@@ -171,74 +175,106 @@ class Dealer:
def trans2floats ( txt ) :
return [ float ( t ) for t in txt . split ( " \t " ) ]
def insert_citations ( self , ans , top_idx , sres , emb_mdl ,
vfield = " q_vec " , cfield = " content_ltks " ) :
def insert_citations ( self , answer , chunks , chunk_v , embd _mdl , tkweight = 0.3 , vtweight = 0.7 ) :
pieces = re . split ( r " ([;。?!! \ n]|[a-z][.?;!][ \ n]) " , answer )
for i in range ( 1 , len ( pieces ) ) :
if re . match ( r " [a-z][.?;!][ \ n] " , pieces [ i ] ) :
pieces [ i - 1 ] + = pieces [ i ] [ 0 ]
pieces [ i ] = pieces [ i ] [ 1 : ]
idx = [ ]
pieces_ = [ ]
for i , t in enumerate ( pieces ) :
if len ( t ) < 5 : continue
idx . append ( i )
pieces_ . append ( t )
if not pieces_ : return answer
i ns_embd = [ Dealer . trans2floats (
sres . field [ sres . ids [ i ] ] [ vfield ] ) for i in top_idx ]
ins_tw = [ sres . field [ sres . ids [ i ] ] [ cfield ] . split ( " " ) for i in top_idx ]
s = 0
e = 0
res = " "
a ns_v = embd_mdl . encode ( pieces_ )
assert len ( ans_v [ 0 ] ) == len ( chunk_v [ 0 ] ) , " The dimension of query and chunk do not match: {} vs. {} " . format (
len ( ans_v [ 0 ] ) , len ( chunk_v [ 0 ] ) )
def citeit ( ) :
nonlocal s , e , ans , res , emb_mdl
if not ins_embd :
return
embd = emb_mdl . encode ( ans [ s : e ] )
sim = self . qryr . hybrid_similar ity ( embd ,
ins_embd ,
huqie . qie ( ans [ s : e ] ) . split ( " " ) ,
ins_tw )
chunks_tks = [ huqie . qie ( ck ) . split ( " " ) for ck in chunks ]
cites = { }
for i , a in enumerate ( pieces_ ) :
sim , tksim , vtsim = self . qryr . hybrid_similarity ( ans_v [ i ] ,
chunk_v ,
huqie . qie ( pieces_ [ i ] ) . spl it ( " " ) ,
chunks_tks ,
tkweight , vtweight )
mx = np . max ( sim ) * 0.99
if mx < 0.55 :
return
cita = list ( set ( [ top_idx [ i ]
for i in range ( len ( ins_embd ) ) if sim [ i ] > mx ] ) ) [ : 4 ]
for i in cita :
res + = f " @? { i } ?@ "
if mx < 0.55 : continue
cites [ idx [ i ] ] = list ( set ( [ str ( i ) for i in range ( len ( chunk_v ) ) if sim [ i ] > mx ] ) ) [ : 4 ]
return cita
punct = set ( " ;。?!! " )
if not self . qryr . isChinese ( ans ) :
punct . add ( " ? " )
punct . add ( " . " )
while e < len ( ans ) :
if e - s < 12 or ans [ e ] not in punct :
e + = 1
continue
if ans [ e ] == " . " and e + \
1 < len ( ans ) and re . match ( r " [0-9] " , ans [ e + 1 ] ) :
e + = 1
continue
if ans [ e ] == " . " and e - 2 > = 0 and ans [ e - 2 ] == " \n " :
e + = 1
continue
res + = ans [ s : e ]
citeit ( )
res + = ans [ e ]
e + = 1
s = e
if s < len ( ans ) :
res + = ans [ s : ]
citeit ( )
res = " "
for i , p in enumerate ( pieces ) :
res + = p
if i not in idx : continue
if i not in cites : continue
res + = " ## %s $$ " % " $ " . join ( cites [ i ] )
return res
def rerank ( self , sres , query , tkweight = 0.3 , vtweight = 0.7 , cfield = " content_ltks " ) :
ins_embd = [
Dealer . trans2floats (
sres . field [ i ] [ " q_ %d _vec " % len ( sres . query_vector ) ] ) for i in sres . ids ]
sres . field [ i ] [ " q_ %d _vec " % len ( sres . query_vector ) ] ) for i in sres . ids ]
if not ins_embd :
return [ ]
ins_tw = [ sres . field [ i ] [ cfield ] . split ( " " ) for i in sres . ids ]
ins_tw = [ huqie . qie ( sres. field [ i ] [ cfield ] ) . split ( " " ) for i in sres . ids ]
sim , tksim , vtsim = self . qryr . hybrid_similarity ( sres . query_vector ,
ins_embd ,
huqie . qie ( query ) . split ( " " ) ,
ins_tw , tkweight , vtweight )
ins_embd ,
huqie . qie ( query ) . split ( " " ) ,
ins_tw , tkweight , vtweight )
return sim , tksim , vtsim
def hybrid_similarity ( self , ans_embd , ins_embd , ans , inst ) :
return self . qryr . hybrid_similarity ( ans_embd ,
ins_embd ,
huqie . qie ( ans ) . split ( " " ) ,
huqie . qie ( inst ) . split ( " " ) )
def retrieval ( self , question , embd_mdl , tenant_id , kb_ids , page , page_size , similarity_threshold = 0.2 ,
vector_similarity_weight = 0.3 , top = 1024 , doc_ids = None , aggs = True ) :
req = { " kb_ids " : kb_ids , " doc_ids " : doc_ids , " size " : top ,
" question " : question , " vector " : True ,
" similarity " : similarity_threshold }
sres = self . search ( req , index_name ( tenant_id ) , embd_mdl )
sim , tsim , vsim = self . rerank (
sres , question , 1 - vector_similarity_weight , vector_similarity_weight )
idx = np . argsort ( sim * - 1 )
ranks = { " total " : 0 , " chunks " : [ ] , " doc_aggs " : { } }
dim = len ( sres . query_vector )
start_idx = ( page - 1 ) * page_size
for i in idx :
ranks [ " total " ] + = 1
if sim [ i ] < similarity_threshold :
break
start_idx - = 1
if start_idx > = 0 :
continue
if len ( ranks [ " chunks " ] ) == page_size :
if aggs :
continue
break
id = sres . ids [ i ]
dnm = sres . field [ id ] [ " docnm_kwd " ]
d = {
" chunk_id " : id ,
" content_ltks " : sres . field [ id ] [ " content_ltks " ] ,
" doc_id " : sres . field [ id ] [ " doc_id " ] ,
" docnm_kwd " : dnm ,
" kb_id " : sres . field [ id ] [ " kb_id " ] ,
" important_kwd " : sres . field [ id ] . get ( " important_kwd " , [ ] ) ,
" img_id " : sres . field [ id ] . get ( " img_id " , " " ) ,
" similarity " : sim [ i ] ,
" vector_similarity " : vsim [ i ] ,
" term_similarity " : tsim [ i ] ,
" vector " : self . trans2floats ( sres . field [ id ] . get ( " q_ %d _vec " % dim , " \t " . join ( [ " 0 " ] * dim ) ) )
}
ranks [ " chunks " ] . append ( d )
if dnm not in ranks [ " doc_aggs " ] :
ranks [ " doc_aggs " ] [ dnm ] = 0
ranks [ " doc_aggs " ] [ dnm ] + = 1
return ranks