rename vision, add layour and tsr recognizer (#70)
* rename vision, add layour and tsr recognizer * trivial fixing
This commit is contained in:
4
deepdoc/vision/__init__.py
Normal file
4
deepdoc/vision/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .ocr import OCR
|
||||
from .recognizer import Recognizer
|
||||
from .layout_recognizer import LayoutRecognizer
|
||||
from .table_structure_recognizer import TableStructureRecognizer
|
||||
119
deepdoc/vision/layout_recognizer.py
Normal file
119
deepdoc/vision/layout_recognizer.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from .recognizer import Recognizer
|
||||
|
||||
|
||||
class LayoutRecognizer(Recognizer):
|
||||
def __init__(self, domain):
|
||||
self.layout_labels = [
|
||||
"_background_",
|
||||
"Text",
|
||||
"Title",
|
||||
"Figure",
|
||||
"Figure caption",
|
||||
"Table",
|
||||
"Table caption",
|
||||
"Header",
|
||||
"Footer",
|
||||
"Reference",
|
||||
"Equation",
|
||||
]
|
||||
super().__init__(self.layout_labels, domain,
|
||||
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||
|
||||
def __call__(self, image_list, ocr_res, scale_factor=3, thr=0.7, batch_size=16):
|
||||
def __is_garbage(b):
|
||||
patt = [r"^•+$", r"(版权归©|免责条款|地址[::])", r"\.{3,}", "^[0-9]{1,2} / ?[0-9]{1,2}$",
|
||||
r"^[0-9]{1,2} of [0-9]{1,2}$", "^http://[^ ]{12,}",
|
||||
"(资料|数据)来源[::]", "[0-9a-z._-]+@[a-z0-9-]+\\.[a-z]{2,3}",
|
||||
"\\(cid *: *[0-9]+ *\\)"
|
||||
]
|
||||
return any([re.search(p, b["text"]) for p in patt])
|
||||
|
||||
layouts = super().__call__(image_list, thr, batch_size)
|
||||
# save_results(image_list, layouts, self.layout_labels, output_dir='output/', threshold=0.7)
|
||||
assert len(image_list) == len(ocr_res)
|
||||
# Tag layout type
|
||||
boxes = []
|
||||
assert len(image_list) == len(layouts)
|
||||
garbages = {}
|
||||
page_layout = []
|
||||
for pn, lts in enumerate(layouts):
|
||||
bxs = ocr_res[pn]
|
||||
lts = [{"type": b["type"],
|
||||
"score": float(b["score"]),
|
||||
"x0": b["bbox"][0] / scale_factor, "x1": b["bbox"][2] / scale_factor,
|
||||
"top": b["bbox"][1] / scale_factor, "bottom": b["bbox"][-1] / scale_factor,
|
||||
"page_number": pn,
|
||||
} for b in lts]
|
||||
lts = self.sort_Y_firstly(lts, np.mean([l["bottom"]-l["top"] for l in lts]) / 2)
|
||||
lts = self.layouts_cleanup(bxs, lts)
|
||||
page_layout.append(lts)
|
||||
|
||||
# Tag layout type, layouts are ready
|
||||
def findLayout(ty):
|
||||
nonlocal bxs, lts, self
|
||||
lts_ = [lt for lt in lts if lt["type"] == ty]
|
||||
i = 0
|
||||
while i < len(bxs):
|
||||
if bxs[i].get("layout_type"):
|
||||
i += 1
|
||||
continue
|
||||
if __is_garbage(bxs[i]):
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
ii = self.find_overlapped_with_threashold(bxs[i], lts_,
|
||||
thr=0.4)
|
||||
if ii is None: # belong to nothing
|
||||
bxs[i]["layout_type"] = ""
|
||||
i += 1
|
||||
continue
|
||||
lts_[ii]["visited"] = True
|
||||
if lts_[ii]["type"] in ["footer", "header", "reference"]:
|
||||
if lts_[ii]["type"] not in garbages:
|
||||
garbages[lts_[ii]["type"]] = []
|
||||
garbages[lts_[ii]["type"]].append(bxs[i]["text"])
|
||||
bxs.pop(i)
|
||||
continue
|
||||
|
||||
bxs[i]["layoutno"] = f"{ty}-{ii}"
|
||||
bxs[i]["layout_type"] = lts_[ii]["type"]
|
||||
i += 1
|
||||
|
||||
for lt in ["footer", "header", "reference", "figure caption",
|
||||
"table caption", "title", "text", "table", "figure", "equation"]:
|
||||
findLayout(lt)
|
||||
|
||||
# add box to figure layouts which has not text box
|
||||
for i, lt in enumerate(
|
||||
[lt for lt in lts if lt["type"] == "figure"]):
|
||||
if lt.get("visited"):
|
||||
continue
|
||||
lt = deepcopy(lt)
|
||||
del lt["type"]
|
||||
lt["text"] = ""
|
||||
lt["layout_type"] = "figure"
|
||||
lt["layoutno"] = f"figure-{i}"
|
||||
bxs.append(lt)
|
||||
|
||||
boxes.extend(bxs)
|
||||
|
||||
ocr_res = boxes
|
||||
|
||||
garbag_set = set()
|
||||
for k in garbages.keys():
|
||||
garbages[k] = Counter(garbages[k])
|
||||
for g, c in garbages[k].items():
|
||||
if c > 1:
|
||||
garbag_set.add(g)
|
||||
|
||||
ocr_res = [b for b in ocr_res if b["text"].strip() not in garbag_set]
|
||||
return ocr_res, page_layout
|
||||
|
||||
561
deepdoc/vision/ocr.py
Normal file
561
deepdoc/vision/ocr.py
Normal file
@@ -0,0 +1,561 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import copy
|
||||
import time
|
||||
import os
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from .operators import *
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from .postprocess import build_post_process
|
||||
from rag.settings import cron_logger
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
""" transform """
|
||||
if ops is None:
|
||||
ops = []
|
||||
for op in ops:
|
||||
data = op(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
|
||||
|
||||
def create_operators(op_param_list, global_config=None):
|
||||
"""
|
||||
create operators based on the config
|
||||
|
||||
Args:
|
||||
params(list): a dict list, used to create some operators
|
||||
"""
|
||||
assert isinstance(
|
||||
op_param_list, list), ('operator config should be a list')
|
||||
ops = []
|
||||
for operator in op_param_list:
|
||||
assert isinstance(operator,
|
||||
dict) and len(operator) == 1, "yaml format error"
|
||||
op_name = list(operator)[0]
|
||||
param = {} if operator[op_name] is None else operator[op_name]
|
||||
if global_config is not None:
|
||||
param.update(global_config)
|
||||
op = eval(op_name)(**param)
|
||||
ops.append(op)
|
||||
return ops
|
||||
|
||||
|
||||
def load_model(model_dir, nm):
|
||||
model_file_path = os.path.join(model_dir, nm + ".onnx")
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(
|
||||
model_file_path))
|
||||
sess = ort.InferenceSession(model_file_path)
|
||||
return sess, sess.get_inputs()[0]
|
||||
|
||||
|
||||
class TextRecognizer(object):
|
||||
def __init__(self, model_dir):
|
||||
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
||||
self.rec_batch_num = 16
|
||||
postprocess_params = {
|
||||
'name': 'CTCLabelDecode',
|
||||
"character_dict_path": os.path.join(os.path.dirname(os.path.realpath(__file__)), "ocr.res"),
|
||||
"use_space_char": True
|
||||
}
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor = load_model(model_dir, 'rec')
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
|
||||
assert imgC == img.shape[2]
|
||||
imgW = int((imgH * max_wh_ratio))
|
||||
w = self.input_tensor.shape[3:][0]
|
||||
if isinstance(w, str):
|
||||
pass
|
||||
elif w is not None and w > 0:
|
||||
imgW = w
|
||||
h, w = img.shape[:2]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def resize_norm_img_vl(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
img = img[:, :, ::-1] # bgr2rgb
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_srn(self, img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
img_black = np.zeros((imgH, imgW))
|
||||
im_hei = img.shape[0]
|
||||
im_wid = img.shape[1]
|
||||
|
||||
if im_wid <= im_hei * 1:
|
||||
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||
elif im_wid <= im_hei * 2:
|
||||
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||
elif im_wid <= im_hei * 3:
|
||||
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||
else:
|
||||
img_new = cv2.resize(img, (imgW, imgH))
|
||||
|
||||
img_np = np.asarray(img_new)
|
||||
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||
img_black[:, 0:img_np.shape[1]] = img_np
|
||||
img_black = img_black[:, :, np.newaxis]
|
||||
|
||||
row, col, c = img_black.shape
|
||||
c = 1
|
||||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(
|
||||
gsrm_slf_attn_bias1,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(
|
||||
gsrm_slf_attn_bias2,
|
||||
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [
|
||||
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
|
||||
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
|
||||
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
||||
def resize_norm_img_sar(self, img, image_shape,
|
||||
width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
def resize_norm_img_spin(self, img):
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
# return padding_im
|
||||
img = cv2.resize(img, tuple([100, 32]), cv2.INTER_CUBIC)
|
||||
img = np.array(img, np.float32)
|
||||
img = np.expand_dims(img, -1)
|
||||
img = img.transpose((2, 0, 1))
|
||||
mean = [127.5]
|
||||
std = [127.5]
|
||||
mean = np.array(mean, dtype=np.float32)
|
||||
std = np.array(std, dtype=np.float32)
|
||||
mean = np.float32(mean.reshape(1, -1))
|
||||
stdinv = 1 / np.float32(std.reshape(1, -1))
|
||||
img -= mean
|
||||
img *= stdinv
|
||||
return img
|
||||
|
||||
def resize_norm_img_svtr(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
def resize_norm_img_abinet(self, img, image_shape):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image / 255.
|
||||
|
||||
mean = np.array([0.485, 0.456, 0.406])
|
||||
std = np.array([0.229, 0.224, 0.225])
|
||||
resized_image = (
|
||||
resized_image - mean[None, None, ...]) / std[None, None, ...]
|
||||
resized_image = resized_image.transpose((2, 0, 1))
|
||||
resized_image = resized_image.astype('float32')
|
||||
|
||||
return resized_image
|
||||
|
||||
def norm_img_can(self, img, image_shape):
|
||||
|
||||
img = cv2.cvtColor(
|
||||
img, cv2.COLOR_BGR2GRAY) # CAN only predict gray scale image
|
||||
|
||||
if self.rec_image_shape[0] == 1:
|
||||
h, w = img.shape
|
||||
_, imgH, imgW = self.rec_image_shape
|
||||
if h < imgH or w < imgW:
|
||||
padding_h = max(imgH - h, 0)
|
||||
padding_w = max(imgW - w, 0)
|
||||
img_padded = np.pad(img, ((0, padding_h), (0, padding_w)),
|
||||
'constant',
|
||||
constant_values=(255))
|
||||
img = img_padded
|
||||
|
||||
img = np.expand_dims(img, 0) / 255.0 # h,w,c -> c,h,w
|
||||
img = img.astype('float32')
|
||||
|
||||
return img
|
||||
|
||||
def __call__(self, img_list):
|
||||
img_num = len(img_list)
|
||||
# Calculate the aspect ratio of all text bars
|
||||
width_list = []
|
||||
for img in img_list:
|
||||
width_list.append(img.shape[1] / float(img.shape[0]))
|
||||
# Sorting can speed up the recognition process
|
||||
indices = np.argsort(np.array(width_list))
|
||||
rec_res = [['', 0.0]] * img_num
|
||||
batch_num = self.rec_batch_num
|
||||
st = time.time()
|
||||
|
||||
for beg_img_no in range(0, img_num, batch_num):
|
||||
end_img_no = min(img_num, beg_img_no + batch_num)
|
||||
norm_img_batch = []
|
||||
imgC, imgH, imgW = self.rec_image_shape[:3]
|
||||
max_wh_ratio = imgW / imgH
|
||||
# max_wh_ratio = 0
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
h, w = img_list[indices[ino]].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
for ino in range(beg_img_no, end_img_no):
|
||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||
max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = norm_img_batch
|
||||
outputs = self.predictor.run(None, input_dict)
|
||||
preds = outputs[0]
|
||||
rec_result = self.postprocess_op(preds)
|
||||
for rno in range(len(rec_result)):
|
||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||
|
||||
return rec_res, time.time() - st
|
||||
|
||||
|
||||
class TextDetector(object):
|
||||
def __init__(self, model_dir):
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
'limit_side_len': 960,
|
||||
'limit_type': "max",
|
||||
}
|
||||
}, {
|
||||
'NormalizeImage': {
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'scale': '1./255.',
|
||||
'order': 'hwc'
|
||||
}
|
||||
}, {
|
||||
'ToCHWImage': None
|
||||
}, {
|
||||
'KeepKeys': {
|
||||
'keep_keys': ['image', 'shape']
|
||||
}
|
||||
}]
|
||||
postprocess_params = {"name": "DBPostProcess", "thresh": 0.3, "box_thresh": 0.6, "max_candidates": 1000,
|
||||
"unclip_ratio": 1.5, "use_dilation": False, "score_mode": "fast", "box_type": "quad"}
|
||||
|
||||
self.postprocess_op = build_post_process(postprocess_params)
|
||||
self.predictor, self.input_tensor = load_model(model_dir, 'det')
|
||||
|
||||
img_h, img_w = self.input_tensor.shape[2:]
|
||||
if isinstance(img_h, str) or isinstance(img_w, str):
|
||||
pass
|
||||
elif img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
|
||||
pre_process_list[0] = {
|
||||
'DetResizeForTest': {
|
||||
'image_shape': [img_h, img_w]
|
||||
}
|
||||
}
|
||||
self.preprocess_op = create_operators(pre_process_list)
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
rect[0] = pts[np.argmin(s)]
|
||||
rect[2] = pts[np.argmax(s)]
|
||||
tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
|
||||
diff = np.diff(np.array(tmp), axis=1)
|
||||
rect[1] = tmp[np.argmin(diff)]
|
||||
rect[3] = tmp[np.argmax(diff)]
|
||||
return rect
|
||||
|
||||
def clip_det_res(self, points, img_height, img_width):
|
||||
for pno in range(points.shape[0]):
|
||||
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
|
||||
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
|
||||
return points
|
||||
|
||||
def filter_tag_det_res(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if isinstance(box, list):
|
||||
box = np.array(box)
|
||||
box = self.order_points_clockwise(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
rect_width = int(np.linalg.norm(box[0] - box[1]))
|
||||
rect_height = int(np.linalg.norm(box[0] - box[3]))
|
||||
if rect_width <= 3 or rect_height <= 3:
|
||||
continue
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
|
||||
img_height, img_width = image_shape[0:2]
|
||||
dt_boxes_new = []
|
||||
for box in dt_boxes:
|
||||
if isinstance(box, list):
|
||||
box = np.array(box)
|
||||
box = self.clip_det_res(box, img_height, img_width)
|
||||
dt_boxes_new.append(box)
|
||||
dt_boxes = np.array(dt_boxes_new)
|
||||
return dt_boxes
|
||||
|
||||
def __call__(self, img):
|
||||
ori_im = img.copy()
|
||||
data = {'image': img}
|
||||
|
||||
st = time.time()
|
||||
data = transform(data, self.preprocess_op)
|
||||
img, shape_list = data
|
||||
if img is None:
|
||||
return None, 0
|
||||
img = np.expand_dims(img, axis=0)
|
||||
shape_list = np.expand_dims(shape_list, axis=0)
|
||||
img = img.copy()
|
||||
input_dict = {}
|
||||
input_dict[self.input_tensor.name] = img
|
||||
outputs = self.predictor.run(None, input_dict)
|
||||
|
||||
post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
|
||||
dt_boxes = post_result[0]['points']
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
||||
return dt_boxes, time.time() - st
|
||||
|
||||
|
||||
class OCR(object):
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not model_dir:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/deepdoc")
|
||||
|
||||
self.text_detector = TextDetector(model_dir)
|
||||
self.text_recognizer = TextRecognizer(model_dir)
|
||||
self.drop_score = 0.5
|
||||
self.crop_image_res_index = 0
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
'''
|
||||
img_height, img_width = img.shape[0:2]
|
||||
left = int(np.min(points[:, 0]))
|
||||
right = int(np.max(points[:, 0]))
|
||||
top = int(np.min(points[:, 1]))
|
||||
bottom = int(np.max(points[:, 1]))
|
||||
img_crop = img[top:bottom, left:right, :].copy()
|
||||
points[:, 0] = points[:, 0] - left
|
||||
points[:, 1] = points[:, 1] - top
|
||||
'''
|
||||
assert len(points) == 4, "shape of points must be 4*2"
|
||||
img_crop_width = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[1]),
|
||||
np.linalg.norm(points[2] - points[3])))
|
||||
img_crop_height = int(
|
||||
max(
|
||||
np.linalg.norm(points[0] - points[3]),
|
||||
np.linalg.norm(points[1] - points[2])))
|
||||
pts_std = np.float32([[0, 0], [img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height]])
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img,
|
||||
M, (img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
dst_img = np.rot90(dst_img)
|
||||
return dst_img
|
||||
|
||||
def sorted_boxes(self, dt_boxes):
|
||||
"""
|
||||
Sort text boxes in order from top to bottom, left to right
|
||||
args:
|
||||
dt_boxes(array):detected text boxes with shape [4, 2]
|
||||
return:
|
||||
sorted boxes(array) with shape [4, 2]
|
||||
"""
|
||||
num_boxes = dt_boxes.shape[0]
|
||||
sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
|
||||
_boxes = list(sorted_boxes)
|
||||
|
||||
for i in range(num_boxes - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
|
||||
(_boxes[j + 1][0][0] < _boxes[j][0][0]):
|
||||
tmp = _boxes[j]
|
||||
_boxes[j] = _boxes[j + 1]
|
||||
_boxes[j + 1] = tmp
|
||||
else:
|
||||
break
|
||||
return _boxes
|
||||
|
||||
def __call__(self, img, cls=True):
|
||||
time_dict = {'det': 0, 'rec': 0, 'cls': 0, 'all': 0}
|
||||
|
||||
if img is None:
|
||||
return None, None, time_dict
|
||||
|
||||
start = time.time()
|
||||
ori_im = img.copy()
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
time_dict['det'] = elapse
|
||||
|
||||
if dt_boxes is None:
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
return None, None, time_dict
|
||||
else:
|
||||
cron_logger.debug("dt_boxes num : {}, elapsed : {}".format(
|
||||
len(dt_boxes), elapse))
|
||||
img_crop_list = []
|
||||
|
||||
dt_boxes = self.sorted_boxes(dt_boxes)
|
||||
|
||||
for bno in range(len(dt_boxes)):
|
||||
tmp_box = copy.deepcopy(dt_boxes[bno])
|
||||
img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
|
||||
img_crop_list.append(img_crop)
|
||||
|
||||
rec_res, elapse = self.text_recognizer(img_crop_list)
|
||||
time_dict['rec'] = elapse
|
||||
cron_logger.debug("rec_res num : {}, elapsed : {}".format(
|
||||
len(rec_res), elapse))
|
||||
|
||||
filter_boxes, filter_rec_res = [], []
|
||||
for box, rec_result in zip(dt_boxes, rec_res):
|
||||
text, score = rec_result
|
||||
if score >= self.drop_score:
|
||||
filter_boxes.append(box)
|
||||
filter_rec_res.append(rec_result)
|
||||
end = time.time()
|
||||
time_dict['all'] = end - start
|
||||
|
||||
#for bno in range(len(img_crop_list)):
|
||||
# print(f"{bno}, {rec_res[bno]}")
|
||||
|
||||
return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))
|
||||
6623
deepdoc/vision/ocr.res
Normal file
6623
deepdoc/vision/ocr.res
Normal file
File diff suppressed because it is too large
Load Diff
710
deepdoc/vision/operators.py
Normal file
710
deepdoc/vision/operators.py
Normal file
@@ -0,0 +1,710 @@
|
||||
#
|
||||
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
""" decode image """
|
||||
|
||||
def __init__(self,
|
||||
img_mode='RGB',
|
||||
channel_first=False,
|
||||
ignore_orientation=False,
|
||||
**kwargs):
|
||||
self.img_mode = img_mode
|
||||
self.channel_first = channel_first
|
||||
self.ignore_orientation = ignore_orientation
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if six.PY2:
|
||||
assert isinstance(img, str) and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
else:
|
||||
assert isinstance(img, bytes) and len(
|
||||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
if self.ignore_orientation:
|
||||
img = cv2.imdecode(img, cv2.IMREAD_IGNORE_ORIENTATION |
|
||||
cv2.IMREAD_COLOR)
|
||||
else:
|
||||
img = cv2.imdecode(img, 1)
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
|
||||
img.shape)
|
||||
img = img[:, :, ::-1]
|
||||
|
||||
if self.channel_first:
|
||||
img = img.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
class StandardizeImage(object):
|
||||
"""normalize image
|
||||
Args:
|
||||
mean (list): im - mean
|
||||
std (list): im / std
|
||||
is_scale (bool): whether need im / 255
|
||||
norm_type (str): type in ['mean_std', 'none']
|
||||
"""
|
||||
|
||||
def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
|
||||
self.mean = mean
|
||||
self.std = std
|
||||
self.is_scale = is_scale
|
||||
self.norm_type = norm_type
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.astype(np.float32, copy=False)
|
||||
if self.is_scale:
|
||||
scale = 1.0 / 255.0
|
||||
im *= scale
|
||||
|
||||
if self.norm_type == 'mean_std':
|
||||
mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
|
||||
std = np.array(self.std)[np.newaxis, np.newaxis, :]
|
||||
im -= mean
|
||||
im /= std
|
||||
return im, im_info
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
""" normalize image such as substract mean, divide std
|
||||
"""
|
||||
|
||||
def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
|
||||
if isinstance(scale, str):
|
||||
scale = eval(scale)
|
||||
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
|
||||
mean = mean if mean is not None else [0.485, 0.456, 0.406]
|
||||
std = std if std is not None else [0.229, 0.224, 0.225]
|
||||
|
||||
shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
|
||||
self.mean = np.array(mean).reshape(shape).astype('float32')
|
||||
self.std = np.array(std).reshape(shape).astype('float32')
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
from PIL import Image
|
||||
if isinstance(img, Image.Image):
|
||||
img = np.array(img)
|
||||
data['image'] = img.transpose((2, 0, 1))
|
||||
return data
|
||||
|
||||
|
||||
class Fasttext(object):
|
||||
def __init__(self, path="None", **kwargs):
|
||||
import fasttext
|
||||
self.fast_model = fasttext.load_model(path)
|
||||
|
||||
def __call__(self, data):
|
||||
label = data['label']
|
||||
fast_label = self.fast_model[label]
|
||||
data['fast_label'] = fast_label
|
||||
return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
def __call__(self, data):
|
||||
data_list = []
|
||||
for key in self.keep_keys:
|
||||
data_list.append(data[key])
|
||||
return data_list
|
||||
|
||||
|
||||
class Pad(object):
|
||||
def __init__(self, size=None, size_div=32, **kwargs):
|
||||
if size is not None and not isinstance(size, (int, list, tuple)):
|
||||
raise TypeError("Type of target_size is invalid. Now is {}".format(
|
||||
type(size)))
|
||||
if isinstance(size, int):
|
||||
size = [size, size]
|
||||
self.size = size
|
||||
self.size_div = size_div
|
||||
|
||||
def __call__(self, data):
|
||||
|
||||
img = data['image']
|
||||
img_h, img_w = img.shape[0], img.shape[1]
|
||||
if self.size:
|
||||
resize_h2, resize_w2 = self.size
|
||||
assert (
|
||||
img_h < resize_h2 and img_w < resize_w2
|
||||
), '(h, w) of target size should be greater than (img_h, img_w)'
|
||||
else:
|
||||
resize_h2 = max(
|
||||
int(math.ceil(img.shape[0] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
resize_w2 = max(
|
||||
int(math.ceil(img.shape[1] / self.size_div) * self.size_div),
|
||||
self.size_div)
|
||||
img = cv2.copyMakeBorder(
|
||||
img,
|
||||
0,
|
||||
resize_h2 - img_h,
|
||||
0,
|
||||
resize_w2 - img_w,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=0)
|
||||
data['image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class LinearResize(object):
|
||||
"""resize image by target_size and max_size
|
||||
Args:
|
||||
target_size (int): the target size of image
|
||||
keep_ratio (bool): whether keep_ratio or not, default true
|
||||
interp (int): method of resize
|
||||
"""
|
||||
|
||||
def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
|
||||
if isinstance(target_size, int):
|
||||
target_size = [target_size, target_size]
|
||||
self.target_size = target_size
|
||||
self.keep_ratio = keep_ratio
|
||||
self.interp = interp
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
assert len(self.target_size) == 2
|
||||
assert self.target_size[0] > 0 and self.target_size[1] > 0
|
||||
im_channel = im.shape[2]
|
||||
im_scale_y, im_scale_x = self.generate_scale(im)
|
||||
im = cv2.resize(
|
||||
im,
|
||||
None,
|
||||
None,
|
||||
fx=im_scale_x,
|
||||
fy=im_scale_y,
|
||||
interpolation=self.interp)
|
||||
im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
|
||||
im_info['scale_factor'] = np.array(
|
||||
[im_scale_y, im_scale_x]).astype('float32')
|
||||
return im, im_info
|
||||
|
||||
def generate_scale(self, im):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
Returns:
|
||||
im_scale_x: the resize ratio of X
|
||||
im_scale_y: the resize ratio of Y
|
||||
"""
|
||||
origin_shape = im.shape[:2]
|
||||
im_c = im.shape[2]
|
||||
if self.keep_ratio:
|
||||
im_size_min = np.min(origin_shape)
|
||||
im_size_max = np.max(origin_shape)
|
||||
target_size_min = np.min(self.target_size)
|
||||
target_size_max = np.max(self.target_size)
|
||||
im_scale = float(target_size_min) / float(im_size_min)
|
||||
if np.round(im_scale * im_size_max) > target_size_max:
|
||||
im_scale = float(target_size_max) / float(im_size_max)
|
||||
im_scale_x = im_scale
|
||||
im_scale_y = im_scale
|
||||
else:
|
||||
resize_h, resize_w = self.target_size
|
||||
im_scale_y = resize_h / float(origin_shape[0])
|
||||
im_scale_x = resize_w / float(origin_shape[1])
|
||||
return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class Resize(object):
|
||||
def __init__(self, size=(640, 640), **kwargs):
|
||||
self.size = size
|
||||
|
||||
def resize_image(self, img):
|
||||
resize_h, resize_w = self.size
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'polys' in data:
|
||||
text_polys = data['polys']
|
||||
|
||||
img_resize, [ratio_h, ratio_w] = self.resize_image(img)
|
||||
if 'polys' in data:
|
||||
new_boxes = []
|
||||
for box in text_polys:
|
||||
new_box = []
|
||||
for cord in box:
|
||||
new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
|
||||
new_boxes.append(new_box)
|
||||
data['polys'] = np.array(new_boxes, dtype=np.float32)
|
||||
data['image'] = img_resize
|
||||
return data
|
||||
|
||||
|
||||
class DetResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
self.keep_ratio = False
|
||||
if 'image_shape' in kwargs:
|
||||
self.image_shape = kwargs['image_shape']
|
||||
self.resize_type = 1
|
||||
if 'keep_ratio' in kwargs:
|
||||
self.keep_ratio = kwargs['keep_ratio']
|
||||
elif 'limit_side_len' in kwargs:
|
||||
self.limit_side_len = kwargs['limit_side_len']
|
||||
self.limit_type = kwargs.get('limit_type', 'min')
|
||||
elif 'resize_long' in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get('resize_long', 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = 'min'
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if sum([src_h, src_w]) < 64:
|
||||
img = self.image_padding(img)
|
||||
|
||||
if self.resize_type == 0:
|
||||
# img, shape = self.resize_image_type0(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
# img, shape = self.resize_image_type1(img)
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
data['image'] = img
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def image_padding(self, im, value=0):
|
||||
h, w, c = im.shape
|
||||
im_pad = np.zeros((max(32, h), max(32, w), c), np.uint8) + value
|
||||
im_pad[:h, :w, :] = im
|
||||
return im_pad
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
if self.keep_ratio is True:
|
||||
resize_w = ori_w * resize_h / ori_h
|
||||
N = math.ceil(resize_w / 32)
|
||||
resize_w = N * 32
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
# return img, np.array([ori_h, ori_w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
resize image to a size multiple of 32 which is required by the network
|
||||
args:
|
||||
img(array): array with shape [h, w, c]
|
||||
return(tuple):
|
||||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
if max(h, w) > limit_side_len:
|
||||
if h > w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h, w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
||||
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
return None, (None, None)
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
except BaseException:
|
||||
print(img.shape, resize_w, resize_h)
|
||||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
self.valid_set = kwargs['valid_set']
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if self.valid_set == 'totaltext':
|
||||
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
||||
img, max_side_len=self.max_side_len)
|
||||
else:
|
||||
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
||||
img, max_side_len=self.max_side_len)
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_for_totaltext(self, im, max_side_len=512):
|
||||
|
||||
h, w, _ = im.shape
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image(self, im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
class KieResize(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(KieResize, self).__init__()
|
||||
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
||||
'img_scale'][1]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
points = data['points']
|
||||
src_h, src_w, _ = img.shape
|
||||
im_resized, scale_factor, [ratio_h, ratio_w
|
||||
], [new_h, new_w] = self.resize_image(img)
|
||||
resize_points = self.resize_boxes(img, points, scale_factor)
|
||||
data['ori_image'] = img
|
||||
data['ori_boxes'] = points
|
||||
data['points'] = resize_points
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([new_h, new_w])
|
||||
return data
|
||||
|
||||
def resize_image(self, img):
|
||||
norm_img = np.zeros([1024, 1024, 3], dtype='float32')
|
||||
scale = [512, 1024]
|
||||
h, w = img.shape[:2]
|
||||
max_long_edge = max(scale)
|
||||
max_short_edge = min(scale)
|
||||
scale_factor = min(max_long_edge / max(h, w),
|
||||
max_short_edge / min(h, w))
|
||||
resize_w, resize_h = int(w * float(scale_factor) + 0.5), int(h * float(
|
||||
scale_factor) + 0.5)
|
||||
max_stride = 32
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(img, (resize_w, resize_h))
|
||||
new_h, new_w = im.shape[:2]
|
||||
w_scale = new_w / w
|
||||
h_scale = new_h / h
|
||||
scale_factor = np.array(
|
||||
[w_scale, h_scale, w_scale, h_scale], dtype=np.float32)
|
||||
norm_img[:new_h, :new_w, :] = im
|
||||
return norm_img, scale_factor, [h_scale, w_scale], [new_h, new_w]
|
||||
|
||||
def resize_boxes(self, im, points, scale_factor):
|
||||
points = points * scale_factor
|
||||
img_shape = im.shape[:2]
|
||||
points[:, 0::2] = np.clip(points[:, 0::2], 0, img_shape[1])
|
||||
points[:, 1::2] = np.clip(points[:, 1::2], 0, img_shape[0])
|
||||
return points
|
||||
|
||||
|
||||
class SRResize(object):
|
||||
def __init__(self,
|
||||
imgH=32,
|
||||
imgW=128,
|
||||
down_sample_scale=4,
|
||||
keep_ratio=False,
|
||||
min_ratio=1,
|
||||
mask=False,
|
||||
infer_mode=False,
|
||||
**kwargs):
|
||||
self.imgH = imgH
|
||||
self.imgW = imgW
|
||||
self.keep_ratio = keep_ratio
|
||||
self.min_ratio = min_ratio
|
||||
self.down_sample_scale = down_sample_scale
|
||||
self.mask = mask
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
imgH = self.imgH
|
||||
imgW = self.imgW
|
||||
images_lr = data["image_lr"]
|
||||
transform2 = ResizeNormalize(
|
||||
(imgW // self.down_sample_scale, imgH // self.down_sample_scale))
|
||||
images_lr = transform2(images_lr)
|
||||
data["img_lr"] = images_lr
|
||||
if self.infer_mode:
|
||||
return data
|
||||
|
||||
images_HR = data["image_hr"]
|
||||
label_strs = data["label"]
|
||||
transform = ResizeNormalize((imgW, imgH))
|
||||
images_HR = transform(images_HR)
|
||||
data["img_hr"] = images_HR
|
||||
return data
|
||||
|
||||
|
||||
class ResizeNormalize(object):
|
||||
def __init__(self, size, interpolation=Image.BICUBIC):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img):
|
||||
img = img.resize(self.size, self.interpolation)
|
||||
img_numpy = np.array(img).astype("float32")
|
||||
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
|
||||
return img_numpy
|
||||
|
||||
|
||||
class GrayImageChannelFormat(object):
|
||||
"""
|
||||
format gray scale image's channel: (3,h,w) -> (1,h,w)
|
||||
Args:
|
||||
inverse: inverse gray image
|
||||
"""
|
||||
|
||||
def __init__(self, inverse=False, **kwargs):
|
||||
self.inverse = inverse
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
img_expanded = np.expand_dims(img_single_channel, 0)
|
||||
|
||||
if self.inverse:
|
||||
data['image'] = np.abs(img_expanded - 1)
|
||||
else:
|
||||
data['image'] = img_expanded
|
||||
|
||||
data['src_image'] = img
|
||||
return data
|
||||
|
||||
|
||||
class Permute(object):
|
||||
"""permute image
|
||||
Args:
|
||||
to_bgr (bool): whether convert RGB to BGR
|
||||
channel_first (bool): whether convert HWC to CHW
|
||||
"""
|
||||
|
||||
def __init__(self, ):
|
||||
super(Permute, self).__init__()
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
im = im.transpose((2, 0, 1)).copy()
|
||||
return im, im_info
|
||||
|
||||
|
||||
class PadStride(object):
|
||||
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
|
||||
Args:
|
||||
stride (bool): model with FPN need image shape % stride == 0
|
||||
"""
|
||||
|
||||
def __init__(self, stride=0):
|
||||
self.coarsest_stride = stride
|
||||
|
||||
def __call__(self, im, im_info):
|
||||
"""
|
||||
Args:
|
||||
im (np.ndarray): image (np.ndarray)
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
coarsest_stride = self.coarsest_stride
|
||||
if coarsest_stride <= 0:
|
||||
return im, im_info
|
||||
im_c, im_h, im_w = im.shape
|
||||
pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
|
||||
pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
|
||||
padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = im
|
||||
return padding_im, im_info
|
||||
|
||||
|
||||
def decode_image(im_file, im_info):
|
||||
"""read rgb image
|
||||
Args:
|
||||
im_file (str|np.ndarray): input can be image path or np.ndarray
|
||||
im_info (dict): info of image
|
||||
Returns:
|
||||
im (np.ndarray): processed image (np.ndarray)
|
||||
im_info (dict): info of processed image
|
||||
"""
|
||||
if isinstance(im_file, str):
|
||||
with open(im_file, 'rb') as f:
|
||||
im_read = f.read()
|
||||
data = np.frombuffer(im_read, dtype='uint8')
|
||||
im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
|
||||
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
|
||||
else:
|
||||
im = im_file
|
||||
im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
|
||||
im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
|
||||
return im, im_info
|
||||
|
||||
|
||||
def preprocess(im, preprocess_ops):
|
||||
# process image by preprocess_ops
|
||||
im_info = {
|
||||
'scale_factor': np.array(
|
||||
[1., 1.], dtype=np.float32),
|
||||
'im_shape': None,
|
||||
}
|
||||
im, im_info = decode_image(im, im_info)
|
||||
for operator in preprocess_ops:
|
||||
im, im_info = operator(im, im_info)
|
||||
return im, im_info
|
||||
354
deepdoc/vision/postprocess.py
Normal file
354
deepdoc/vision/postprocess.py
Normal file
@@ -0,0 +1,354 @@
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
import paddle
|
||||
from shapely.geometry import Polygon
|
||||
import pyclipper
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
support_dict = ['DBPostProcess', 'CTCLabelDecode']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
if module_name == "None":
|
||||
return
|
||||
if global_config is not None:
|
||||
config.update(global_config)
|
||||
assert module_name in support_dict, Exception(
|
||||
'post process only support {}'.format(support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
||||
|
||||
class DBPostProcess(object):
|
||||
"""
|
||||
The post process for Differentiable Binarization (DB).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
thresh=0.3,
|
||||
box_thresh=0.7,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
box_type='quad',
|
||||
**kwargs):
|
||||
self.thresh = thresh
|
||||
self.box_thresh = box_thresh
|
||||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.score_mode = score_mode
|
||||
self.box_type = box_type
|
||||
assert score_mode in [
|
||||
"slow", "fast"
|
||||
], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
|
||||
|
||||
self.dilation_kernel = None if not use_dilation else np.array(
|
||||
[[1, 1], [1, 1]])
|
||||
|
||||
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
|
||||
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
|
||||
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
for contour in contours[:self.max_candidates]:
|
||||
epsilon = 0.002 * cv2.arcLength(contour, True)
|
||||
approx = cv2.approxPolyDP(contour, epsilon, True)
|
||||
points = approx.reshape((-1, 2))
|
||||
if points.shape[0] < 4:
|
||||
continue
|
||||
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
if points.shape[0] > 2:
|
||||
box = self.unclip(points, self.unclip_ratio)
|
||||
if len(box) > 1:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
box = box.reshape(-1, 2)
|
||||
|
||||
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
|
||||
box = np.array(box)
|
||||
box[:, 0] = np.clip(
|
||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(
|
||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.tolist())
|
||||
scores.append(score)
|
||||
return boxes, scores
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
'''
|
||||
|
||||
bitmap = _bitmap
|
||||
height, width = bitmap.shape
|
||||
|
||||
outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
|
||||
cv2.CHAIN_APPROX_SIMPLE)
|
||||
if len(outs) == 3:
|
||||
img, contours, _ = outs[0], outs[1], outs[2]
|
||||
elif len(outs) == 2:
|
||||
contours, _ = outs[0], outs[1]
|
||||
|
||||
num_contours = min(len(contours), self.max_candidates)
|
||||
|
||||
boxes = []
|
||||
scores = []
|
||||
for index in range(num_contours):
|
||||
contour = contours[index]
|
||||
points, sside = self.get_mini_boxes(contour)
|
||||
if sside < self.min_size:
|
||||
continue
|
||||
points = np.array(points)
|
||||
if self.score_mode == "fast":
|
||||
score = self.box_score_fast(pred, points.reshape(-1, 2))
|
||||
else:
|
||||
score = self.box_score_slow(pred, contour)
|
||||
if self.box_thresh > score:
|
||||
continue
|
||||
|
||||
box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
|
||||
box, sside = self.get_mini_boxes(box)
|
||||
if sside < self.min_size + 2:
|
||||
continue
|
||||
box = np.array(box)
|
||||
|
||||
box[:, 0] = np.clip(
|
||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
||||
box[:, 1] = np.clip(
|
||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
||||
boxes.append(box.astype("int32"))
|
||||
scores.append(score)
|
||||
return np.array(boxes, dtype="int32"), scores
|
||||
|
||||
def unclip(self, box, unclip_ratio):
|
||||
poly = Polygon(box)
|
||||
distance = poly.area * unclip_ratio / poly.length
|
||||
offset = pyclipper.PyclipperOffset()
|
||||
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
|
||||
expanded = np.array(offset.Execute(distance))
|
||||
return expanded
|
||||
|
||||
def get_mini_boxes(self, contour):
|
||||
bounding_box = cv2.minAreaRect(contour)
|
||||
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
|
||||
|
||||
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
|
||||
if points[1][1] > points[0][1]:
|
||||
index_1 = 0
|
||||
index_4 = 1
|
||||
else:
|
||||
index_1 = 1
|
||||
index_4 = 0
|
||||
if points[3][1] > points[2][1]:
|
||||
index_2 = 2
|
||||
index_3 = 3
|
||||
else:
|
||||
index_2 = 3
|
||||
index_3 = 2
|
||||
|
||||
box = [
|
||||
points[index_1], points[index_2], points[index_3], points[index_4]
|
||||
]
|
||||
return box, min(bounding_box[1])
|
||||
|
||||
def box_score_fast(self, bitmap, _box):
|
||||
'''
|
||||
box_score_fast: use bbox mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
box = _box.copy()
|
||||
xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
|
||||
xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
|
||||
ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
|
||||
ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
box[:, 0] = box[:, 0] - xmin
|
||||
box[:, 1] = box[:, 1] - ymin
|
||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def box_score_slow(self, bitmap, contour):
|
||||
'''
|
||||
box_score_slow: use polyon mean score as the mean score
|
||||
'''
|
||||
h, w = bitmap.shape[:2]
|
||||
contour = contour.copy()
|
||||
contour = np.reshape(contour, (-1, 2))
|
||||
|
||||
xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
|
||||
xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
|
||||
ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
|
||||
ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
|
||||
|
||||
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
|
||||
|
||||
contour[:, 0] = contour[:, 0] - xmin
|
||||
contour[:, 1] = contour[:, 1] - ymin
|
||||
|
||||
cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict['maps']
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
boxes_batch = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
||||
if self.dilation_kernel is not None:
|
||||
mask = cv2.dilate(
|
||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||
self.dilation_kernel)
|
||||
else:
|
||||
mask = segmentation[batch_index]
|
||||
if self.box_type == 'poly':
|
||||
boxes, scores = self.polygons_from_bitmap(pred[batch_index],
|
||||
mask, src_w, src_h)
|
||||
elif self.box_type == 'quad':
|
||||
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
|
||||
src_w, src_h)
|
||||
else:
|
||||
raise ValueError(
|
||||
"box_type can only be one of ['quad', 'poly']")
|
||||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.reverse = False
|
||||
self.character_str = []
|
||||
|
||||
if character_dict_path is None:
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
if 'arabic' in character_dict_path:
|
||||
self.reverse = True
|
||||
|
||||
dict_character = self.add_special_char(dict_character)
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def pred_reverse(self, pred):
|
||||
pred_re = []
|
||||
c_current = ''
|
||||
for c in pred:
|
||||
if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
pred_re.append(c)
|
||||
c_current = ''
|
||||
else:
|
||||
c_current += c
|
||||
if c_current != '':
|
||||
pred_re.append(c_current)
|
||||
|
||||
return ''.join(pred_re[::-1])
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
selection = np.ones(len(text_index[batch_idx]), dtype=bool)
|
||||
if is_remove_duplicate:
|
||||
selection[1:] = text_index[batch_idx][1:] != text_index[
|
||||
batch_idx][:-1]
|
||||
for ignored_token in ignored_tokens:
|
||||
selection &= text_index[batch_idx] != ignored_token
|
||||
|
||||
char_list = [
|
||||
self.character[text_id]
|
||||
for text_id in text_index[batch_idx][selection]
|
||||
]
|
||||
if text_prob is not None:
|
||||
conf_list = text_prob[batch_idx][selection]
|
||||
else:
|
||||
conf_list = [1] * len(selection)
|
||||
if len(conf_list) == 0:
|
||||
conf_list = [0]
|
||||
|
||||
text = ''.join(char_list)
|
||||
|
||||
if self.reverse: # for arabic rec
|
||||
text = self.pred_reverse(text)
|
||||
|
||||
result_list.append((text, np.mean(conf_list).tolist()))
|
||||
return result_list
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [0] # for ctc blank
|
||||
|
||||
|
||||
class CTCLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False,
|
||||
**kwargs):
|
||||
super(CTCLabelDecode, self).__init__(character_dict_path,
|
||||
use_space_char)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, tuple) or isinstance(preds, list):
|
||||
preds = preds[-1]
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank'] + dict_character
|
||||
return dict_character
|
||||
327
deepdoc/vision/recognizer.py
Normal file
327
deepdoc/vision/recognizer.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from . import seeit
|
||||
from .operators import *
|
||||
from rag.settings import cron_logger
|
||||
|
||||
|
||||
class Recognizer(object):
|
||||
def __init__(self, label_list, task_name, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
if not model_dir:
|
||||
model_dir = snapshot_download(repo_id="InfiniFlow/ocr")
|
||||
|
||||
model_file_path = os.path.join(model_dir, task_name + ".onnx")
|
||||
if not os.path.exists(model_file_path):
|
||||
raise ValueError("not find model file path {}".format(
|
||||
model_file_path))
|
||||
if ort.get_device() == "GPU":
|
||||
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CUDAExecutionProvider'])
|
||||
else:
|
||||
self.ort_sess = ort.InferenceSession(model_file_path, providers=['CPUExecutionProvider'])
|
||||
self.label_list = label_list
|
||||
|
||||
@staticmethod
|
||||
def sort_Y_firstly(arr, threashold):
|
||||
# sort using y1 first and then x1
|
||||
arr = sorted(arr, key=lambda r: (r["top"], r["x0"]))
|
||||
for i in range(len(arr) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
# restore the order using th
|
||||
if abs(arr[j + 1]["top"] - arr[j]["top"]) < threashold \
|
||||
and arr[j + 1]["x0"] < arr[j]["x0"]:
|
||||
tmp = deepcopy(arr[j])
|
||||
arr[j] = deepcopy(arr[j + 1])
|
||||
arr[j + 1] = deepcopy(tmp)
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def sort_X_firstly(arr, threashold, copy=True):
|
||||
# sort using y1 first and then x1
|
||||
arr = sorted(arr, key=lambda r: (r["x0"], r["top"]))
|
||||
for i in range(len(arr) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
# restore the order using th
|
||||
if abs(arr[j + 1]["x0"] - arr[j]["x0"]) < threashold \
|
||||
and arr[j + 1]["top"] < arr[j]["top"]:
|
||||
tmp = deepcopy(arr[j]) if copy else arr[j]
|
||||
arr[j] = deepcopy(arr[j + 1]) if copy else arr[j + 1]
|
||||
arr[j + 1] = deepcopy(tmp) if copy else tmp
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def sort_C_firstly(arr, thr=0):
|
||||
# sort using y1 first and then x1
|
||||
# sorted(arr, key=lambda r: (r["x0"], r["top"]))
|
||||
arr = Recognizer.sort_X_firstly(arr, thr)
|
||||
for i in range(len(arr) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
# restore the order using th
|
||||
if "C" not in arr[j] or "C" not in arr[j + 1]:
|
||||
continue
|
||||
if arr[j + 1]["C"] < arr[j]["C"] \
|
||||
or (
|
||||
arr[j + 1]["C"] == arr[j]["C"]
|
||||
and arr[j + 1]["top"] < arr[j]["top"]
|
||||
):
|
||||
tmp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = tmp
|
||||
return arr
|
||||
|
||||
return sorted(arr, key=lambda r: (r.get("C", r["x0"]), r["top"]))
|
||||
|
||||
@staticmethod
|
||||
def sort_R_firstly(arr, thr=0):
|
||||
# sort using y1 first and then x1
|
||||
# sorted(arr, key=lambda r: (r["top"], r["x0"]))
|
||||
arr = Recognizer.sort_Y_firstly(arr, thr)
|
||||
for i in range(len(arr) - 1):
|
||||
for j in range(i, -1, -1):
|
||||
if "R" not in arr[j] or "R" not in arr[j + 1]:
|
||||
continue
|
||||
if arr[j + 1]["R"] < arr[j]["R"] \
|
||||
or (
|
||||
arr[j + 1]["R"] == arr[j]["R"]
|
||||
and arr[j + 1]["x0"] < arr[j]["x0"]
|
||||
):
|
||||
tmp = arr[j]
|
||||
arr[j] = arr[j + 1]
|
||||
arr[j + 1] = tmp
|
||||
return arr
|
||||
|
||||
@staticmethod
|
||||
def overlapped_area(a, b, ratio=True):
|
||||
tp, btm, x0, x1 = a["top"], a["bottom"], a["x0"], a["x1"]
|
||||
if b["x0"] > x1 or b["x1"] < x0:
|
||||
return 0
|
||||
if b["bottom"] < tp or b["top"] > btm:
|
||||
return 0
|
||||
x0_ = max(b["x0"], x0)
|
||||
x1_ = min(b["x1"], x1)
|
||||
assert x0_ <= x1_, "Fuckedup! T:{},B:{},X0:{},X1:{} ==> {}".format(
|
||||
tp, btm, x0, x1, b)
|
||||
tp_ = max(b["top"], tp)
|
||||
btm_ = min(b["bottom"], btm)
|
||||
assert tp_ <= btm_, "Fuckedup! T:{},B:{},X0:{},X1:{} => {}".format(
|
||||
tp, btm, x0, x1, b)
|
||||
ov = (btm_ - tp_) * (x1_ - x0_) if x1 - \
|
||||
x0 != 0 and btm - tp != 0 else 0
|
||||
if ov > 0 and ratio:
|
||||
ov /= (x1 - x0) * (btm - tp)
|
||||
return ov
|
||||
|
||||
@staticmethod
|
||||
def layouts_cleanup(boxes, layouts, far=2, thr=0.7):
|
||||
def notOverlapped(a, b):
|
||||
return any([a["x1"] < b["x0"],
|
||||
a["x0"] > b["x1"],
|
||||
a["bottom"] < b["top"],
|
||||
a["top"] > b["bottom"]])
|
||||
|
||||
i = 0
|
||||
while i + 1 < len(layouts):
|
||||
j = i + 1
|
||||
while j < min(i + far, len(layouts)) \
|
||||
and (layouts[i].get("type", "") != layouts[j].get("type", "")
|
||||
or notOverlapped(layouts[i], layouts[j])):
|
||||
j += 1
|
||||
if j >= min(i + far, len(layouts)):
|
||||
i += 1
|
||||
continue
|
||||
if Recognizer.overlapped_area(layouts[i], layouts[j]) < thr \
|
||||
and Recognizer.overlapped_area(layouts[j], layouts[i]) < thr:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if layouts[i].get("score") and layouts[j].get("score"):
|
||||
if layouts[i]["score"] > layouts[j]["score"]:
|
||||
layouts.pop(j)
|
||||
else:
|
||||
layouts.pop(i)
|
||||
continue
|
||||
|
||||
area_i, area_i_1 = 0, 0
|
||||
for b in boxes:
|
||||
if not notOverlapped(b, layouts[i]):
|
||||
area_i += Recognizer.overlapped_area(b, layouts[i], False)
|
||||
if not notOverlapped(b, layouts[j]):
|
||||
area_i_1 += Recognizer.overlapped_area(b, layouts[j], False)
|
||||
|
||||
if area_i > area_i_1:
|
||||
layouts.pop(j)
|
||||
else:
|
||||
layouts.pop(i)
|
||||
|
||||
return layouts
|
||||
|
||||
def create_inputs(self, imgs, im_info):
|
||||
"""generate input for different model type
|
||||
Args:
|
||||
imgs (list(numpy)): list of images (np.ndarray)
|
||||
im_info (list(dict)): list of image info
|
||||
Returns:
|
||||
inputs (dict): input of model
|
||||
"""
|
||||
inputs = {}
|
||||
|
||||
im_shape = []
|
||||
scale_factor = []
|
||||
if len(imgs) == 1:
|
||||
inputs['image'] = np.array((imgs[0],)).astype('float32')
|
||||
inputs['im_shape'] = np.array(
|
||||
(im_info[0]['im_shape'],)).astype('float32')
|
||||
inputs['scale_factor'] = np.array(
|
||||
(im_info[0]['scale_factor'],)).astype('float32')
|
||||
return inputs
|
||||
|
||||
for e in im_info:
|
||||
im_shape.append(np.array((e['im_shape'],)).astype('float32'))
|
||||
scale_factor.append(np.array((e['scale_factor'],)).astype('float32'))
|
||||
|
||||
inputs['im_shape'] = np.concatenate(im_shape, axis=0)
|
||||
inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
|
||||
|
||||
imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
|
||||
max_shape_h = max([e[0] for e in imgs_shape])
|
||||
max_shape_w = max([e[1] for e in imgs_shape])
|
||||
padding_imgs = []
|
||||
for img in imgs:
|
||||
im_c, im_h, im_w = img.shape[:]
|
||||
padding_im = np.zeros(
|
||||
(im_c, max_shape_h, max_shape_w), dtype=np.float32)
|
||||
padding_im[:, :im_h, :im_w] = img
|
||||
padding_imgs.append(padding_im)
|
||||
inputs['image'] = np.stack(padding_imgs, axis=0)
|
||||
return inputs
|
||||
|
||||
@staticmethod
|
||||
def find_overlapped(box, boxes_sorted_by_y, naive=False):
|
||||
if not boxes_sorted_by_y:
|
||||
return
|
||||
bxs = boxes_sorted_by_y
|
||||
s, e, ii = 0, len(bxs), 0
|
||||
while s < e and not naive:
|
||||
ii = (e + s) // 2
|
||||
pv = bxs[ii]
|
||||
if box["bottom"] < pv["top"]:
|
||||
e = ii
|
||||
continue
|
||||
if box["top"] > pv["bottom"]:
|
||||
s = ii + 1
|
||||
continue
|
||||
break
|
||||
while s < ii:
|
||||
if box["top"] > bxs[s]["bottom"]:
|
||||
s += 1
|
||||
break
|
||||
while e - 1 > ii:
|
||||
if box["bottom"] < bxs[e - 1]["top"]:
|
||||
e -= 1
|
||||
break
|
||||
|
||||
max_overlaped_i, max_overlaped = None, 0
|
||||
for i in range(s, e):
|
||||
ov = Recognizer.overlapped_area(bxs[i], box)
|
||||
if ov <= max_overlaped:
|
||||
continue
|
||||
max_overlaped_i = i
|
||||
max_overlaped = ov
|
||||
|
||||
return max_overlaped_i
|
||||
|
||||
@staticmethod
|
||||
def find_overlapped_with_threashold(box, boxes, thr=0.3):
|
||||
if not boxes:
|
||||
return
|
||||
max_overlaped_i, max_overlaped, _max_overlaped = None, thr, 0
|
||||
s, e = 0, len(boxes)
|
||||
for i in range(s, e):
|
||||
ov = Recognizer.overlapped_area(box, boxes[i])
|
||||
_ov = Recognizer.overlapped_area(boxes[i], box)
|
||||
if (ov, _ov) < (max_overlaped, _max_overlaped):
|
||||
continue
|
||||
max_overlaped_i = i
|
||||
max_overlaped = ov
|
||||
_max_overlaped = _ov
|
||||
|
||||
return max_overlaped_i
|
||||
|
||||
def preprocess(self, image_list):
|
||||
preprocess_ops = []
|
||||
for op_info in [
|
||||
{'interp': 2, 'keep_ratio': False, 'target_size': [800, 608], 'type': 'LinearResize'},
|
||||
{'is_scale': True, 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225], 'type': 'StandardizeImage'},
|
||||
{'type': 'Permute'},
|
||||
{'stride': 32, 'type': 'PadStride'}
|
||||
]:
|
||||
new_op_info = op_info.copy()
|
||||
op_type = new_op_info.pop('type')
|
||||
preprocess_ops.append(eval(op_type)(**new_op_info))
|
||||
|
||||
inputs = []
|
||||
for im_path in image_list:
|
||||
im, im_info = preprocess(im_path, preprocess_ops)
|
||||
inputs.append({"image": np.array((im,)).astype('float32'), "scale_factor": np.array((im_info["scale_factor"],)).astype('float32')})
|
||||
return inputs
|
||||
|
||||
def __call__(self, image_list, thr=0.7, batch_size=16):
|
||||
res = []
|
||||
imgs = []
|
||||
for i in range(len(image_list)):
|
||||
if not isinstance(image_list[i], np.ndarray):
|
||||
imgs.append(np.array(image_list[i]))
|
||||
else: imgs.append(image_list[i])
|
||||
|
||||
batch_loop_cnt = math.ceil(float(len(imgs)) / batch_size)
|
||||
for i in range(batch_loop_cnt):
|
||||
start_index = i * batch_size
|
||||
end_index = min((i + 1) * batch_size, len(imgs))
|
||||
batch_image_list = imgs[start_index:end_index]
|
||||
inputs = self.preprocess(batch_image_list)
|
||||
for ins in inputs:
|
||||
bb = []
|
||||
for b in self.ort_sess.run(None, ins)[0]:
|
||||
clsid, bbox, score = int(b[0]), b[2:], b[1]
|
||||
if score < thr:
|
||||
continue
|
||||
if clsid >= len(self.label_list):
|
||||
cron_logger.warning(f"bad category id")
|
||||
continue
|
||||
bb.append({
|
||||
"type": self.label_list[clsid].lower(),
|
||||
"bbox": [float(t) for t in bbox.tolist()],
|
||||
"score": float(score)
|
||||
})
|
||||
res.append(bb)
|
||||
|
||||
#seeit.save_results(image_list, res, self.label_list, threshold=thr)
|
||||
|
||||
return res
|
||||
83
deepdoc/vision/seeit.py
Normal file
83
deepdoc/vision/seeit.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
import PIL
|
||||
from PIL import ImageDraw
|
||||
|
||||
|
||||
def save_results(image_list, results, labels, output_dir='output/', threshold=0.5):
|
||||
if not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
for idx, im in enumerate(image_list):
|
||||
im = draw_box(im, results[idx], labels, threshold=threshold)
|
||||
|
||||
out_path = os.path.join(output_dir, f"{idx}.jpg")
|
||||
im.save(out_path, quality=95)
|
||||
print("save result to: " + out_path)
|
||||
|
||||
|
||||
def draw_box(im, result, lables, threshold=0.5):
|
||||
draw_thickness = min(im.size) // 320
|
||||
draw = ImageDraw.Draw(im)
|
||||
color_list = get_color_map_list(len(lables))
|
||||
clsid2color = {n.lower():color_list[i] for i,n in enumerate(lables)}
|
||||
result = [r for r in result if r["score"] >= threshold]
|
||||
|
||||
for dt in result:
|
||||
color = tuple(clsid2color[dt["type"]])
|
||||
xmin, ymin, xmax, ymax = dt["bbox"]
|
||||
draw.line(
|
||||
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
|
||||
(xmin, ymin)],
|
||||
width=draw_thickness,
|
||||
fill=color)
|
||||
|
||||
# draw label
|
||||
text = "{} {:.4f}".format(dt["type"], dt["score"])
|
||||
tw, th = imagedraw_textsize_c(draw, text)
|
||||
draw.rectangle(
|
||||
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
|
||||
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
|
||||
return im
|
||||
|
||||
|
||||
def get_color_map_list(num_classes):
|
||||
"""
|
||||
Args:
|
||||
num_classes (int): number of class
|
||||
Returns:
|
||||
color_map (list): RGB color list
|
||||
"""
|
||||
color_map = num_classes * [0, 0, 0]
|
||||
for i in range(0, num_classes):
|
||||
j = 0
|
||||
lab = i
|
||||
while lab:
|
||||
color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
|
||||
color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
|
||||
color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
|
||||
j += 1
|
||||
lab >>= 3
|
||||
color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
|
||||
return color_map
|
||||
|
||||
|
||||
def imagedraw_textsize_c(draw, text):
|
||||
if int(PIL.__version__.split('.')[0]) < 10:
|
||||
tw, th = draw.textsize(text)
|
||||
else:
|
||||
left, top, right, bottom = draw.textbbox((0, 0), text)
|
||||
tw, th = right - left, bottom - top
|
||||
|
||||
return tw, th
|
||||
556
deepdoc/vision/table_structure_recognizer.py
Normal file
556
deepdoc/vision/table_structure_recognizer.py
Normal file
@@ -0,0 +1,556 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from copy import deepcopy
|
||||
|
||||
import numpy as np
|
||||
|
||||
from api.utils.file_utils import get_project_base_directory
|
||||
from rag.nlp import huqie
|
||||
from .recognizer import Recognizer
|
||||
|
||||
|
||||
class TableStructureRecognizer(Recognizer):
|
||||
def __init__(self):
|
||||
self.labels = [
|
||||
"table",
|
||||
"table column",
|
||||
"table row",
|
||||
"table column header",
|
||||
"table projected row header",
|
||||
"table spanning cell",
|
||||
]
|
||||
super().__init__(self.labels, "tsr",
|
||||
os.path.join(get_project_base_directory(), "rag/res/deepdoc/"))
|
||||
|
||||
def __call__(self, images, thr=0.5):
|
||||
tbls = super().__call__(images, thr)
|
||||
res = []
|
||||
# align left&right for rows, align top&bottom for columns
|
||||
for tbl in tbls:
|
||||
lts = [{"label": b["type"],
|
||||
"score": b["score"],
|
||||
"x0": b["bbox"][0], "x1": b["bbox"][2],
|
||||
"top": b["bbox"][1], "bottom": b["bbox"][-1]
|
||||
} for b in tbl]
|
||||
if not lts:
|
||||
continue
|
||||
|
||||
left = [b["x0"] for b in lts if b["label"].find(
|
||||
"row") > 0 or b["label"].find("header") > 0]
|
||||
right = [b["x1"] for b in lts if b["label"].find(
|
||||
"row") > 0 or b["label"].find("header") > 0]
|
||||
if not left:
|
||||
continue
|
||||
left = np.median(left) if len(left) > 4 else np.min(left)
|
||||
right = np.median(right) if len(right) > 4 else np.max(right)
|
||||
for b in lts:
|
||||
if b["label"].find("row") > 0 or b["label"].find("header") > 0:
|
||||
if b["x0"] > left:
|
||||
b["x0"] = left
|
||||
if b["x1"] < right:
|
||||
b["x1"] = right
|
||||
|
||||
top = [b["top"] for b in lts if b["label"] == "table column"]
|
||||
bottom = [b["bottom"] for b in lts if b["label"] == "table column"]
|
||||
if not top:
|
||||
res.append(lts)
|
||||
continue
|
||||
top = np.median(top) if len(top) > 4 else np.min(top)
|
||||
bottom = np.median(bottom) if len(bottom) > 4 else np.max(bottom)
|
||||
for b in lts:
|
||||
if b["label"] == "table column":
|
||||
if b["top"] > top:
|
||||
b["top"] = top
|
||||
if b["bottom"] < bottom:
|
||||
b["bottom"] = bottom
|
||||
|
||||
res.append(lts)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def is_caption(bx):
|
||||
patt = [
|
||||
r"[图表]+[ 0-9::]{2,}"
|
||||
]
|
||||
if any([re.match(p, bx["text"].strip()) for p in patt]) \
|
||||
or bx["layout_type"].find("caption") >= 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __blockType(self, b):
|
||||
patt = [
|
||||
("^(20|19)[0-9]{2}[年/-][0-9]{1,2}[月/-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[年-][0-9]{1,2}月*$", "Dt"),
|
||||
("^[0-9]{1,2}[月-][0-9]{1,2}日*$", "Dt"),
|
||||
(r"^第*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}年*[一二三四1-4]季度$", "Dt"),
|
||||
(r"^(20|19)[0-9]{2}[ABCDE]$", "Dt"),
|
||||
("^[0-9.,+%/ -]+$", "Nu"),
|
||||
(r"^[0-9A-Z/\._~-]+$", "Ca"),
|
||||
(r"^[A-Z]*[a-z' -]+$", "En"),
|
||||
(r"^[0-9.,+-]+[0-9A-Za-z/$¥%<>()()' -]+$", "NE"),
|
||||
(r"^.{1}$", "Sg")
|
||||
]
|
||||
for p, n in patt:
|
||||
if re.search(p, b["text"].strip()):
|
||||
return n
|
||||
tks = [t for t in huqie.qie(b["text"]).split(" ") if len(t) > 1]
|
||||
if len(tks) > 3:
|
||||
if len(tks) < 12:
|
||||
return "Tx"
|
||||
else:
|
||||
return "Lx"
|
||||
|
||||
if len(tks) == 1 and huqie.tag(tks[0]) == "nr":
|
||||
return "Nr"
|
||||
|
||||
return "Ot"
|
||||
|
||||
def construct_table(self, boxes, is_english=False, html=False):
|
||||
cap = ""
|
||||
i = 0
|
||||
while i < len(boxes):
|
||||
if self.is_caption(boxes[i]):
|
||||
cap += boxes[i]["text"]
|
||||
boxes.pop(i)
|
||||
i -= 1
|
||||
i += 1
|
||||
|
||||
if not boxes:
|
||||
return []
|
||||
for b in boxes:
|
||||
b["btype"] = self.__blockType(b)
|
||||
max_type = Counter([b["btype"] for b in boxes]).items()
|
||||
max_type = max(max_type, key=lambda x: x[1])[0] if max_type else ""
|
||||
logging.debug("MAXTYPE: " + max_type)
|
||||
|
||||
rowh = [b["R_bott"] - b["R_top"] for b in boxes if "R" in b]
|
||||
rowh = np.min(rowh) if rowh else 0
|
||||
boxes = self.sort_R_firstly(boxes, rowh / 2)
|
||||
boxes[0]["rn"] = 0
|
||||
rows = [[boxes[0]]]
|
||||
btm = boxes[0]["bottom"]
|
||||
for b in boxes[1:]:
|
||||
b["rn"] = len(rows) - 1
|
||||
lst_r = rows[-1]
|
||||
if lst_r[-1].get("R", "") != b.get("R", "") \
|
||||
or (b["top"] >= btm - 3 and lst_r[-1].get("R", "-1") != b.get("R", "-2")
|
||||
): # new row
|
||||
btm = b["bottom"]
|
||||
b["rn"] += 1
|
||||
rows.append([b])
|
||||
continue
|
||||
btm = (btm + b["bottom"]) / 2.
|
||||
rows[-1].append(b)
|
||||
|
||||
colwm = [b["C_right"] - b["C_left"] for b in boxes if "C" in b]
|
||||
colwm = np.min(colwm) if colwm else 0
|
||||
crosspage = len(set([b["page_number"] for b in boxes])) > 1
|
||||
if crosspage:
|
||||
boxes = self.sort_X_firstly(boxes, colwm / 2, False)
|
||||
else:
|
||||
boxes = self.sort_C_firstly(boxes, colwm / 2)
|
||||
boxes[0]["cn"] = 0
|
||||
cols = [[boxes[0]]]
|
||||
right = boxes[0]["x1"]
|
||||
for b in boxes[1:]:
|
||||
b["cn"] = len(cols) - 1
|
||||
lst_c = cols[-1]
|
||||
if (int(b.get("C", "1")) - int(lst_c[-1].get("C", "1")) == 1 and b["page_number"] == lst_c[-1][
|
||||
"page_number"]) \
|
||||
or (b["x0"] >= right and lst_c[-1].get("C", "-1") != b.get("C", "-2")): # new col
|
||||
right = b["x1"]
|
||||
b["cn"] += 1
|
||||
cols.append([b])
|
||||
continue
|
||||
right = (right + b["x1"]) / 2.
|
||||
cols[-1].append(b)
|
||||
|
||||
tbl = [[[] for _ in range(len(cols))] for _ in range(len(rows))]
|
||||
for b in boxes:
|
||||
tbl[b["rn"]][b["cn"]].append(b)
|
||||
|
||||
if len(rows) >= 4:
|
||||
# remove single in column
|
||||
j = 0
|
||||
while j < len(tbl[0]):
|
||||
e, ii = 0, 0
|
||||
for i in range(len(tbl)):
|
||||
if tbl[i][j]:
|
||||
e += 1
|
||||
ii = i
|
||||
if e > 1:
|
||||
break
|
||||
if e > 1:
|
||||
j += 1
|
||||
continue
|
||||
f = (j > 0 and tbl[ii][j - 1] and tbl[ii]
|
||||
[j - 1][0].get("text")) or j == 0
|
||||
ff = (j + 1 < len(tbl[ii]) and tbl[ii][j + 1] and tbl[ii]
|
||||
[j + 1][0].get("text")) or j + 1 >= len(tbl[ii])
|
||||
if f and ff:
|
||||
j += 1
|
||||
continue
|
||||
bx = tbl[ii][j][0]
|
||||
logging.debug("Relocate column single: " + bx["text"])
|
||||
# j column only has one value
|
||||
left, right = 100000, 100000
|
||||
if j > 0 and not f:
|
||||
for i in range(len(tbl)):
|
||||
if tbl[i][j - 1]:
|
||||
left = min(left, np.min(
|
||||
[bx["x0"] - a["x1"] for a in tbl[i][j - 1]]))
|
||||
if j + 1 < len(tbl[0]) and not ff:
|
||||
for i in range(len(tbl)):
|
||||
if tbl[i][j + 1]:
|
||||
right = min(right, np.min(
|
||||
[a["x0"] - bx["x1"] for a in tbl[i][j + 1]]))
|
||||
assert left < 100000 or right < 100000
|
||||
if left < right:
|
||||
for jj in range(j, len(tbl[0])):
|
||||
for i in range(len(tbl)):
|
||||
for a in tbl[i][jj]:
|
||||
a["cn"] -= 1
|
||||
if tbl[ii][j - 1]:
|
||||
tbl[ii][j - 1].extend(tbl[ii][j])
|
||||
else:
|
||||
tbl[ii][j - 1] = tbl[ii][j]
|
||||
for i in range(len(tbl)):
|
||||
tbl[i].pop(j)
|
||||
|
||||
else:
|
||||
for jj in range(j + 1, len(tbl[0])):
|
||||
for i in range(len(tbl)):
|
||||
for a in tbl[i][jj]:
|
||||
a["cn"] -= 1
|
||||
if tbl[ii][j + 1]:
|
||||
tbl[ii][j + 1].extend(tbl[ii][j])
|
||||
else:
|
||||
tbl[ii][j + 1] = tbl[ii][j]
|
||||
for i in range(len(tbl)):
|
||||
tbl[i].pop(j)
|
||||
cols.pop(j)
|
||||
assert len(cols) == len(tbl[0]), "Column NO. miss matched: %d vs %d" % (
|
||||
len(cols), len(tbl[0]))
|
||||
|
||||
if len(cols) >= 4:
|
||||
# remove single in row
|
||||
i = 0
|
||||
while i < len(tbl):
|
||||
e, jj = 0, 0
|
||||
for j in range(len(tbl[i])):
|
||||
if tbl[i][j]:
|
||||
e += 1
|
||||
jj = j
|
||||
if e > 1:
|
||||
break
|
||||
if e > 1:
|
||||
i += 1
|
||||
continue
|
||||
f = (i > 0 and tbl[i - 1][jj] and tbl[i - 1]
|
||||
[jj][0].get("text")) or i == 0
|
||||
ff = (i + 1 < len(tbl) and tbl[i + 1][jj] and tbl[i + 1]
|
||||
[jj][0].get("text")) or i + 1 >= len(tbl)
|
||||
if f and ff:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
bx = tbl[i][jj][0]
|
||||
logging.debug("Relocate row single: " + bx["text"])
|
||||
# i row only has one value
|
||||
up, down = 100000, 100000
|
||||
if i > 0 and not f:
|
||||
for j in range(len(tbl[i - 1])):
|
||||
if tbl[i - 1][j]:
|
||||
up = min(up, np.min(
|
||||
[bx["top"] - a["bottom"] for a in tbl[i - 1][j]]))
|
||||
if i + 1 < len(tbl) and not ff:
|
||||
for j in range(len(tbl[i + 1])):
|
||||
if tbl[i + 1][j]:
|
||||
down = min(down, np.min(
|
||||
[a["top"] - bx["bottom"] for a in tbl[i + 1][j]]))
|
||||
assert up < 100000 or down < 100000
|
||||
if up < down:
|
||||
for ii in range(i, len(tbl)):
|
||||
for j in range(len(tbl[ii])):
|
||||
for a in tbl[ii][j]:
|
||||
a["rn"] -= 1
|
||||
if tbl[i - 1][jj]:
|
||||
tbl[i - 1][jj].extend(tbl[i][jj])
|
||||
else:
|
||||
tbl[i - 1][jj] = tbl[i][jj]
|
||||
tbl.pop(i)
|
||||
|
||||
else:
|
||||
for ii in range(i + 1, len(tbl)):
|
||||
for j in range(len(tbl[ii])):
|
||||
for a in tbl[ii][j]:
|
||||
a["rn"] -= 1
|
||||
if tbl[i + 1][jj]:
|
||||
tbl[i + 1][jj].extend(tbl[i][jj])
|
||||
else:
|
||||
tbl[i + 1][jj] = tbl[i][jj]
|
||||
tbl.pop(i)
|
||||
rows.pop(i)
|
||||
|
||||
# which rows are headers
|
||||
hdset = set([])
|
||||
for i in range(len(tbl)):
|
||||
cnt, h = 0, 0
|
||||
for j, arr in enumerate(tbl[i]):
|
||||
if not arr:
|
||||
continue
|
||||
cnt += 1
|
||||
if max_type == "Nu" and arr[0]["btype"] == "Nu":
|
||||
continue
|
||||
if any([a.get("H") for a in arr]) \
|
||||
or (max_type == "Nu" and arr[0]["btype"] != "Nu"):
|
||||
h += 1
|
||||
if h / cnt > 0.5:
|
||||
hdset.add(i)
|
||||
|
||||
if html:
|
||||
return [self.__html_table(cap, hdset,
|
||||
self.__cal_spans(boxes, rows,
|
||||
cols, tbl, True)
|
||||
)]
|
||||
|
||||
return self.__desc_table(cap, hdset,
|
||||
self.__cal_spans(boxes, rows, cols, tbl, False),
|
||||
is_english)
|
||||
|
||||
def __html_table(self, cap, hdset, tbl):
|
||||
# constrcut HTML
|
||||
html = "<table>"
|
||||
if cap:
|
||||
html += f"<caption>{cap}</caption>"
|
||||
for i in range(len(tbl)):
|
||||
row = "<tr>"
|
||||
txts = []
|
||||
for j, arr in enumerate(tbl[i]):
|
||||
if arr is None:
|
||||
continue
|
||||
if not arr:
|
||||
row += "<td></td>" if i not in hdset else "<th></th>"
|
||||
continue
|
||||
txt = ""
|
||||
if arr:
|
||||
h = min(np.min([c["bottom"] - c["top"] for c in arr]) / 2, 10)
|
||||
txt = "".join([c["text"]
|
||||
for c in self.sort_Y_firstly(arr, h)])
|
||||
txts.append(txt)
|
||||
sp = ""
|
||||
if arr[0].get("colspan"):
|
||||
sp = "colspan={}".format(arr[0]["colspan"])
|
||||
if arr[0].get("rowspan"):
|
||||
sp += " rowspan={}".format(arr[0]["rowspan"])
|
||||
if i in hdset:
|
||||
row += f"<th {sp} >" + txt + "</th>"
|
||||
else:
|
||||
row += f"<td {sp} >" + txt + "</td>"
|
||||
|
||||
if i in hdset:
|
||||
if all([t in hdset for t in txts]):
|
||||
continue
|
||||
for t in txts:
|
||||
hdset.add(t)
|
||||
|
||||
if row != "<tr>":
|
||||
row += "</tr>"
|
||||
else:
|
||||
row = ""
|
||||
html += "\n" + row
|
||||
html += "\n</table>"
|
||||
return html
|
||||
|
||||
def __desc_table(self, cap, hdr_rowno, tbl, is_english):
|
||||
# get text of every colomn in header row to become header text
|
||||
clmno = len(tbl[0])
|
||||
rowno = len(tbl)
|
||||
headers = {}
|
||||
hdrset = set()
|
||||
lst_hdr = []
|
||||
de = "的" if not is_english else " for "
|
||||
for r in sorted(list(hdr_rowno)):
|
||||
headers[r] = ["" for _ in range(clmno)]
|
||||
for i in range(clmno):
|
||||
if not tbl[r][i]:
|
||||
continue
|
||||
txt = "".join([a["text"].strip() for a in tbl[r][i]])
|
||||
headers[r][i] = txt
|
||||
hdrset.add(txt)
|
||||
if all([not t for t in headers[r]]):
|
||||
del headers[r]
|
||||
hdr_rowno.remove(r)
|
||||
continue
|
||||
for j in range(clmno):
|
||||
if headers[r][j]:
|
||||
continue
|
||||
if j >= len(lst_hdr):
|
||||
break
|
||||
headers[r][j] = lst_hdr[j]
|
||||
lst_hdr = headers[r]
|
||||
for i in range(rowno):
|
||||
if i not in hdr_rowno:
|
||||
continue
|
||||
for j in range(i + 1, rowno):
|
||||
if j not in hdr_rowno:
|
||||
break
|
||||
for k in range(clmno):
|
||||
if not headers[j - 1][k]:
|
||||
continue
|
||||
if headers[j][k].find(headers[j - 1][k]) >= 0:
|
||||
continue
|
||||
if len(headers[j][k]) > len(headers[j - 1][k]):
|
||||
headers[j][k] += (de if headers[j][k]
|
||||
else "") + headers[j - 1][k]
|
||||
else:
|
||||
headers[j][k] = headers[j - 1][k] \
|
||||
+ (de if headers[j - 1][k] else "") \
|
||||
+ headers[j][k]
|
||||
|
||||
logging.debug(
|
||||
f">>>>>>>>>>>>>>>>>{cap}:SIZE:{rowno}X{clmno} Header: {hdr_rowno}")
|
||||
row_txt = []
|
||||
for i in range(rowno):
|
||||
if i in hdr_rowno:
|
||||
continue
|
||||
rtxt = []
|
||||
|
||||
def append(delimer):
|
||||
nonlocal rtxt, row_txt
|
||||
rtxt = delimer.join(rtxt)
|
||||
if row_txt and len(row_txt[-1]) + len(rtxt) < 64:
|
||||
row_txt[-1] += "\n" + rtxt
|
||||
else:
|
||||
row_txt.append(rtxt)
|
||||
|
||||
r = 0
|
||||
if len(headers.items()):
|
||||
_arr = [(i - r, r) for r, _ in headers.items() if r < i]
|
||||
if _arr:
|
||||
_, r = min(_arr, key=lambda x: x[0])
|
||||
|
||||
if r not in headers and clmno <= 2:
|
||||
for j in range(clmno):
|
||||
if not tbl[i][j]:
|
||||
continue
|
||||
txt = "".join([a["text"].strip() for a in tbl[i][j]])
|
||||
if txt:
|
||||
rtxt.append(txt)
|
||||
if rtxt:
|
||||
append(":")
|
||||
continue
|
||||
|
||||
for j in range(clmno):
|
||||
if not tbl[i][j]:
|
||||
continue
|
||||
txt = "".join([a["text"].strip() for a in tbl[i][j]])
|
||||
if not txt:
|
||||
continue
|
||||
ctt = headers[r][j] if r in headers else ""
|
||||
if ctt:
|
||||
ctt += ":"
|
||||
ctt += txt
|
||||
if ctt:
|
||||
rtxt.append(ctt)
|
||||
|
||||
if rtxt:
|
||||
row_txt.append("; ".join(rtxt))
|
||||
|
||||
if cap:
|
||||
if is_english:
|
||||
from_ = " in "
|
||||
else:
|
||||
from_ = "来自"
|
||||
row_txt = [t + f"\t——{from_}“{cap}”" for t in row_txt]
|
||||
return row_txt
|
||||
|
||||
def __cal_spans(self, boxes, rows, cols, tbl, html=True):
|
||||
# caculate span
|
||||
clft = [np.mean([c.get("C_left", c["x0"]) for c in cln])
|
||||
for cln in cols]
|
||||
crgt = [np.mean([c.get("C_right", c["x1"]) for c in cln])
|
||||
for cln in cols]
|
||||
rtop = [np.mean([c.get("R_top", c["top"]) for c in row])
|
||||
for row in rows]
|
||||
rbtm = [np.mean([c.get("R_btm", c["bottom"])
|
||||
for c in row]) for row in rows]
|
||||
for b in boxes:
|
||||
if "SP" not in b:
|
||||
continue
|
||||
b["colspan"] = [b["cn"]]
|
||||
b["rowspan"] = [b["rn"]]
|
||||
# col span
|
||||
for j in range(0, len(clft)):
|
||||
if j == b["cn"]:
|
||||
continue
|
||||
if clft[j] + (crgt[j] - clft[j]) / 2 < b["H_left"]:
|
||||
continue
|
||||
if crgt[j] - (crgt[j] - clft[j]) / 2 > b["H_right"]:
|
||||
continue
|
||||
b["colspan"].append(j)
|
||||
# row span
|
||||
for j in range(0, len(rtop)):
|
||||
if j == b["rn"]:
|
||||
continue
|
||||
if rtop[j] + (rbtm[j] - rtop[j]) / 2 < b["H_top"]:
|
||||
continue
|
||||
if rbtm[j] - (rbtm[j] - rtop[j]) / 2 > b["H_bott"]:
|
||||
continue
|
||||
b["rowspan"].append(j)
|
||||
|
||||
def join(arr):
|
||||
if not arr:
|
||||
return ""
|
||||
return "".join([t["text"] for t in arr])
|
||||
|
||||
# rm the spaning cells
|
||||
for i in range(len(tbl)):
|
||||
for j, arr in enumerate(tbl[i]):
|
||||
if not arr:
|
||||
continue
|
||||
if all(["rowspan" not in a and "colspan" not in a for a in arr]):
|
||||
continue
|
||||
rowspan, colspan = [], []
|
||||
for a in arr:
|
||||
if isinstance(a.get("rowspan", 0), list):
|
||||
rowspan.extend(a["rowspan"])
|
||||
if isinstance(a.get("colspan", 0), list):
|
||||
colspan.extend(a["colspan"])
|
||||
rowspan, colspan = set(rowspan), set(colspan)
|
||||
if len(rowspan) < 2 and len(colspan) < 2:
|
||||
for a in arr:
|
||||
if "rowspan" in a:
|
||||
del a["rowspan"]
|
||||
if "colspan" in a:
|
||||
del a["colspan"]
|
||||
continue
|
||||
rowspan, colspan = sorted(rowspan), sorted(colspan)
|
||||
rowspan = list(range(rowspan[0], rowspan[-1] + 1))
|
||||
colspan = list(range(colspan[0], colspan[-1] + 1))
|
||||
assert i in rowspan, rowspan
|
||||
assert j in colspan, colspan
|
||||
arr = []
|
||||
for r in rowspan:
|
||||
for c in colspan:
|
||||
arr_txt = join(arr)
|
||||
if tbl[r][c] and join(tbl[r][c]) != arr_txt:
|
||||
arr.extend(tbl[r][c])
|
||||
tbl[r][c] = None if html else arr
|
||||
for a in arr:
|
||||
if len(rowspan) > 1:
|
||||
a["rowspan"] = len(rowspan)
|
||||
elif "rowspan" in a:
|
||||
del a["rowspan"]
|
||||
if len(colspan) > 1:
|
||||
a["colspan"] = len(colspan)
|
||||
elif "colspan" in a:
|
||||
del a["colspan"]
|
||||
tbl[rowspan[0]][colspan[0]] = arr
|
||||
|
||||
return tbl
|
||||
|
||||
Reference in New Issue
Block a user