Add Q&A and Book, fix task running bugs (#50)

KevinHuSh
2024-02-01 18:53:56 +08:00
committed by GitHub
parent 6224edcd1b
commit e6acaf6738
21 changed files with 628 additions and 276 deletions

View File

@@ -1,130 +1,138 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
import random
from timeit import default_timer as timer
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from rag.parser.pdf_parser import HuParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm
import pandas as pd
from api.db import FileType
from api.db.services.document_service import DocumentService
from api.settings import database_logger
from api.utils import get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory

def collect(tm):
    docs = DocumentService.get_newly_uploaded(tm)
    if len(docs) == 0:
        return pd.DataFrame()
    docs = pd.DataFrame(docs)
    mtm = docs["update_time"].max()
    cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
    return docs

def set_dispatching(docid):
    try:
        DocumentService.update_by_id(
            docid, {"progress": random.randint(0, 3) / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
    except Exception as e:
        cron_logger.error("set_dispatching:({}), {}".format(docid, str(e)))

def dispatch():
    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm")
    tm = findMaxTm(tm_fnm)
    rows = collect(tm)
    if len(rows) == 0:
        return

    tmf = open(tm_fnm, "a+")
    for _, r in rows.iterrows():
        try:
            tsks = TaskService.query(doc_id=r["id"])
            if tsks:
                for t in tsks:
                    TaskService.delete_by_id(t.id)
        except Exception as e:
            cron_logger.error("delete task exception:" + str(e))

        def new_task():
            nonlocal r
            return {
                "id": get_uuid(),
                "doc_id": r["id"]
            }

        tsks = []
        if r["type"] == FileType.PDF.value:
            pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
            for p in range(0, pages, 10):
                task = new_task()
                task["from_page"] = p
                task["to_page"] = min(p + 10, pages)
                tsks.append(task)
        else:
            tsks.append(new_task())
        print(tsks)
        bulk_insert_into_db(Task, tsks, True)
        set_dispatching(r["id"])
        tmf.write(str(r["update_time"]) + "\n")
    tmf.close()

def update_progress():
    docs = DocumentService.get_unfinished_docs()
    for d in docs:
        try:
            tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
            if not tsks:
                continue
            msg = []
            prg = 0
            finished = True
            bad = 0
            for t in tsks:
                if 0 <= t.progress < 1:
                    finished = False
                prg += t.progress if t.progress >= 0 else 0
                msg.append(t.progress_msg)
                if t.progress == -1:
                    bad += 1
            prg /= len(tsks)
            if finished and bad:
                prg = -1
            msg = "\n".join(msg)
            DocumentService.update_by_id(d["id"], {"progress": prg, "progress_msg": msg, "process_duation": timer() - d["process_begin_at"].timestamp()})
        except Exception as e:
            cron_logger.error("fetch task exception:" + str(e))

if __name__ == "__main__":
    peewee_logger = logging.getLogger('peewee')
    peewee_logger.propagate = False
    peewee_logger.addHandler(database_logger.handlers[0])
    peewee_logger.setLevel(database_logger.level)

    while True:
        dispatch()
        time.sleep(3)
        update_progress()

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
import random
from datetime import datetime
from api.db.db_models import Task
from api.db.db_utils import bulk_insert_into_db
from api.db.services.task_service import TaskService
from rag.parser.pdf_parser import HuParser
from rag.settings import cron_logger
from rag.utils import MINIO
from rag.utils import findMaxTm
import pandas as pd
from api.db import FileType, TaskStatus
from api.db.services.document_service import DocumentService
from api.settings import database_logger
from api.utils import get_format_time, get_uuid
from api.utils.file_utils import get_project_base_directory

def collect(tm):
    docs = DocumentService.get_newly_uploaded(tm)
    if len(docs) == 0:
        return pd.DataFrame()
    docs = pd.DataFrame(docs)
    mtm = docs["update_time"].max()
    cron_logger.info("TOTAL:{}, To:{}".format(len(docs), mtm))
    return docs

def set_dispatching(docid):
    try:
        DocumentService.update_by_id(
            docid, {"progress": random.randint(0, 3) / 100.,
                    "progress_msg": "Task dispatched...",
                    "process_begin_at": get_format_time()
                    })
    except Exception as e:
        cron_logger.error("set_dispatching:({}), {}".format(docid, str(e)))

def dispatch():
    tm_fnm = os.path.join(get_project_base_directory(), "rag/res", f"broker.tm")
    tm = findMaxTm(tm_fnm)
    rows = collect(tm)
    if len(rows) == 0:
        return

    tmf = open(tm_fnm, "a+")
    for _, r in rows.iterrows():
        try:
            tsks = TaskService.query(doc_id=r["id"])
            if tsks:
                for t in tsks:
                    TaskService.delete_by_id(t.id)
        except Exception as e:
            cron_logger.error("delete task exception:" + str(e))

        def new_task():
            nonlocal r
            return {
                "id": get_uuid(),
                "doc_id": r["id"]
            }

        tsks = []
        if r["type"] == FileType.PDF.value:
            pages = HuParser.total_page_number(r["name"], MINIO.get(r["kb_id"], r["location"]))
            for p in range(0, pages, 10):
                task = new_task()
                task["from_page"] = p
                task["to_page"] = min(p + 10, pages)
                tsks.append(task)
        else:
            tsks.append(new_task())
        print(tsks)
        bulk_insert_into_db(Task, tsks, True)
        set_dispatching(r["id"])
        tmf.write(str(r["update_time"]) + "\n")
    tmf.close()

def update_progress():
    docs = DocumentService.get_unfinished_docs()
    for d in docs:
        try:
            tsks = TaskService.query(doc_id=d["id"], order_by=Task.create_time)
            if not tsks:
                continue
            msg = []
            prg = 0
            finished = True
            bad = 0
            status = TaskStatus.RUNNING.value
            for t in tsks:
                if 0 <= t.progress < 1:
                    finished = False
                prg += t.progress if t.progress >= 0 else 0
                msg.append(t.progress_msg)
                if t.progress == -1:
                    bad += 1
            prg /= len(tsks)
            if finished and bad:
                prg = -1
                status = TaskStatus.FAIL.value
            elif finished:
                status = TaskStatus.DONE.value
            msg = "\n".join(msg)
            info = {"process_duation": datetime.timestamp(datetime.now()) - d["process_begin_at"].timestamp(), "run": status}
            if prg != 0:
                info["progress"] = prg
            if msg:
                info["progress_msg"] = msg
            DocumentService.update_by_id(d["id"], info)
        except Exception as e:
            cron_logger.error("fetch task exception:" + str(e))

if __name__ == "__main__":
    peewee_logger = logging.getLogger('peewee')
    peewee_logger.propagate = False
    peewee_logger.addHandler(database_logger.handlers[0])
    peewee_logger.setLevel(database_logger.level)

    while True:
        dispatch()
        time.sleep(3)
        update_progress()
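
With documents now split into page-range tasks, update_progress() rolls the task rows back up into a single document status: document progress is the mean of the non-negative task progresses, and once every task has finished the document is marked DONE, or FAIL (progress -1) if any task ended with progress -1. A simplified, standalone sketch of that rule follows; it uses plain floats instead of Task rows, and aggregate() is an illustrative helper, not a function in the codebase.

# Simplified sketch of the aggregation rule in update_progress():
# each task progress is a float in [0, 1], or -1 for a failed task.
def aggregate(progresses):
    finished = all(not (0 <= p < 1) for p in progresses)   # no task still running
    bad = sum(1 for p in progresses if p == -1)             # failed task count
    prg = sum(p for p in progresses if p >= 0) / len(progresses)
    if finished and bad:
        return -1, "FAIL"      # everything ended, but at least one task failed
    if finished:
        return prg, "DONE"     # all tasks completed cleanly
    return prg, "RUNNING"      # partial progress, keep running

assert aggregate([1.0, 1.0]) == (1.0, "DONE")
assert aggregate([1.0, -1])[1] == "FAIL"
assert aggregate([0.5, 1.0])[1] == "RUNNING"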

View File

@@ -24,8 +24,9 @@ import sys
from functools import partial
from timeit import default_timer as timer
from elasticsearch_dsl import Q
from api.db.services.task_service import TaskService
from rag.llm import EmbeddingModel, CvModel
from rag.settings import cron_logger, DOC_MAXIMUM_SIZE
from rag.utils import ELASTICSEARCH
from rag.utils import MINIO
@@ -35,7 +36,7 @@ from rag.nlp import search
from io import BytesIO
import pandas as pd
from rag.app import laws, paper, presentation, manual
from rag.app import laws, paper, presentation, manual, qa
from api.db import LLMType, ParserType
from api.db.services.document_service import DocumentService
@@ -51,13 +52,14 @@ FACTORY = {
    ParserType.PRESENTATION.value: presentation,
    ParserType.MANUAL.value: manual,
    ParserType.LAWS.value: laws,
    ParserType.QA.value: qa,
}

def set_progress(task_id, from_page, to_page, prog=None, msg="Processing..."):
    cancel = TaskService.do_cancel(task_id)
    if cancel:
        msg = "Canceled."
        msg += " [Canceled]"
        prog = -1

    if to_page > 0: msg = f"Page({from_page}~{to_page}): " + msg
@@ -166,13 +168,16 @@ def init_kb(row):
def embedding(docs, mdl):
    tts, cnts = [d["docnm_kwd"] for d in docs], [d["content_with_weight"] for d in docs]
    tts, cnts = [d["docnm_kwd"] for d in docs if d.get("docnm_kwd")], [d["content_with_weight"] for d in docs]
    tk_count = 0
    tts, c = mdl.encode(tts)
    tk_count += c
    if len(tts) == len(cnts):
        tts, c = mdl.encode(tts)
        tk_count += c

    cnts, c = mdl.encode(cnts)
    tk_count += c
    vects = 0.1 * tts + 0.9 * cnts
    vects = (0.1 * tts + 0.9 * cnts) if len(tts) == len(cnts) else cnts

    assert len(vects) == len(docs)
    for i, d in enumerate(docs):
        v = vects[i].tolist()
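
The hunk above changes embedding() so the 10%/90% blend of title and content vectors is only applied when every chunk actually has a title (docnm_kwd); if the title batch comes up short, the content embeddings are used unchanged. A minimal numpy sketch of that rule, using toy arrays in place of the encoder's real output:

import numpy as np

# Illustrative only: toy arrays standing in for mdl.encode() output.
def combine(title_vecs: np.ndarray, content_vecs: np.ndarray) -> np.ndarray:
    # Titles contribute 10%, contents 90%, but only when they line up one-to-one;
    # otherwise fall back to the content embeddings alone.
    if len(title_vecs) == len(content_vecs):
        return 0.1 * title_vecs + 0.9 * content_vecs
    return content_vecs

titles = np.random.rand(3, 4)      # 3 chunks, 4-dim toy vectors
contents = np.random.rand(3, 4)
assert combine(titles, contents).shape == (3, 4)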
@@ -215,12 +220,14 @@ def main(comm, mod):
        callback(msg="Finished embedding! Start to build index!")
        init_kb(r)
        chunk_count = len(set([c["_id"] for c in cks]))
        callback(1., "Done!")
        es_r = ELASTICSEARCH.bulk(cks, search.index_name(r["tenant_id"]))
        if es_r:
            callback(-1, "Index failure!")
            cron_logger.error(str(es_r))
        else:
            if TaskService.do_cancel(r["id"]):
                ELASTICSEARCH.deleteByQuery(Q("match", doc_id=r["doc_id"]), idxnm=search.index_name(r["tenant_id"]))
            callback(1., "Done!")
            DocumentService.increment_chunk_num(r["doc_id"], r["kb_id"], tk_count, chunk_count, 0)
            cron_logger.info("Chunk doc({}), token({}), chunks({})".format(r["id"], tk_count, len(cks)))