Refactor API for document and session (#2819)

### What problem does this PR solve?

Refactor API for document and session.

### Type of change


- [x] Refactoring

---------

Co-authored-by: liuhua <10215101452@stu.ecun.edu.cn>
This commit is contained in:
liuhua
2024-10-12 19:35:19 +08:00
committed by GitHub
parent 7d80fc474c
commit 6eed115723
9 changed files with 821 additions and 745 deletions

View File

@@ -19,6 +19,8 @@ import json
import re
from copy import deepcopy
from timeit import default_timer as timer
from api.db import LLMType, ParserType,StatusEnum
from api.db.db_models import Dialog, Conversation,DB
from api.db.services.common_service import CommonService
@@ -61,6 +63,22 @@ class DialogService(CommonService):
class ConversationService(CommonService):
model = Conversation
@classmethod
@DB.connection_context()
def get_list(cls,dialog_id,page_number, items_per_page, orderby, desc, id , name):
sessions = cls.model.select().where(cls.model.dialog_id ==dialog_id)
if id:
sessions = sessions.where(cls.model.id == id)
if name:
sessions = sessions.where(cls.model.name == name)
if desc:
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
sessions = sessions.paginate(page_number, items_per_page)
return list(sessions.dicts())
def message_fit_in(msg, max_length=4000):
def count():

View File

@@ -49,6 +49,29 @@ from rag.utils.redis_conn import REDIS_CONN
class DocumentService(CommonService):
model = Document
@classmethod
@DB.connection_context()
def get_list(cls, kb_id, page_number, items_per_page,
orderby, desc, keywords, id):
docs =cls.model.select().where(cls.model.kb_id==kb_id)
if id:
docs = docs.where(
cls.model.id== id )
if keywords:
docs = docs.where(
fn.LOWER(cls.model.name).contains(keywords.lower())
)
count = docs.count()
if desc:
docs = docs.order_by(cls.model.getter_by(orderby).desc())
else:
docs = docs.order_by(cls.model.getter_by(orderby).asc())
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
@@ -268,7 +291,7 @@ class DocumentService(CommonService):
@classmethod
@DB.connection_context()
def get_thumbnails(cls, docids):
fields = [cls.model.id, cls.model.kb_id, cls.model.thumbnail]
fields = [cls.model.id, cls.model.thumbnail]
return list(cls.model.select(
*fields).where(cls.model.id.in_(docids)).dicts())