提交 b5268dc3 编写于 作者: H huangjun12

add centripetal text model

上级 31ed44d1
Global:
use_gpu: true
epoch_num: 600
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/det_ct/
save_epoch_step: 10
# evaluation is run every 2000 iterations
eval_batch_step: [0,1000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/ResNet18_vd_pretrained.pdparams
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img623.jpg
save_res_path: ./output/det_ct/predicts_ct.txt
Architecture:
model_type: det
algorithm: CT
Transform:
Backbone:
name: ResNet_vd
layers: 18
Neck:
name: CTFPN
Head:
name: CT_Head
in_channels: 512
hidden_dim: 128
num_classes: 3
Loss:
name: CTLoss
Optimizer:
name: Adam
lr: #PolynomialDecay
name: Linear
learning_rate: 0.001
end_lr: 0.
epochs: 600
step_each_epoch: 1254
power: 0.9
PostProcess:
name: CTPostProcess
box_type: poly
Metric:
name: CTMetric
main_indicator: f_score
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/total_text/train
label_file_list:
- ./train_data/total_text/train/train.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: RGB
channel_first: False
- CTLabelEncode: # Class handling label
- RandomScale:
- MakeShrink:
- GroupRandomHorizontalFlip:
- GroupRandomRotate:
- GroupRandomCropPadding:
- MakeCentripetalShift:
- ColorJitter:
brightness: 0.125
saturation: 0.5
- ToCHWImage:
- NormalizeImage:
- KeepKeys:
keep_keys: ['image', 'gt_kernel', 'training_mask', 'gt_instance', 'gt_kernel_instance', 'training_mask_distance', 'gt_distance'] # the order of the dataloader list
loader:
shuffle: True
drop_last: True
batch_size_per_card: 4
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/total_text/test
label_file_list:
- ./train_data/total_text/test/test.txt
ratio_list: [1.0]
transforms:
- DecodeImage:
img_mode: RGB
channel_first: False
- CTLabelEncode: # Class handling label
- ScaleAlignedShort:
- NormalizeImage:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'texts'] # the order of the dataloader list
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1
num_workers: 2
# CT
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945)
> Tao Sheng, Jie Chen, Zhouhui Lian
> NeurIPS, 2021
在Total-Text文本检测公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|precision|recall|Hmean|下载链接|
| --- | --- | --- | --- | --- | --- | --- |
|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
CT模型使用Total-Text文本检测公开数据集训练得到,数据集下载可参考 [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset), 我们将标签文件转成了paddleocr格式,转换好的标签文件下载参考[train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [text.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt)
请参考[文本检测训练教程](./detection.md)。PaddleOCR对代码进行了模块化,训练不同的检测模型只需要**更换配置文件**即可。
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将CT文本检测训练过程中保存的模型,转换成inference model。以基于Resnet18_vd骨干网络,在Total-Text英文数据集训练的模型为例( [模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar) ),可以使用如下命令进行转换:
```shell
python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct
```
CT文本检测模型推理,可以执行如下命令:
```shell
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT"
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'det_res'。结果示例如下:
![](../imgs_results/det_res_img623_ct.jpg)
<a name="4-2"></a>
### 4.2 C++推理
暂不支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@inproceedings{sheng2021centripetaltext,
title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection},
author={Tao Sheng and Jie Chen and Zhouhui Lian},
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021}
}
```
# CT
- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
- [3.1 Training](#3-1)
- [3.2 Evaluation](#3-2)
- [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
- [4.1 Python Inference](#4-1)
- [4.2 C++ Inference](#4-2)
- [4.3 Serving](#4-3)
- [4.4 More](#4-4)
- [5. FAQ](#5)
<a name="1"></a>
## 1. Introduction
Paper:
> [CentripetalText: An Efficient Text Instance Representation for Scene Text Detection](https://arxiv.org/abs/2107.05945)
> Tao Sheng, Jie Chen, Zhouhui Lian
> NeurIPS, 2021
On the Total-Text dataset, the text detection result is as follows:
|Model|Backbone|Configuration|Precision|Recall|Hmean|Download|
| --- | --- | --- | --- | --- | --- | --- |
|CT|ResNet18_vd|[configs/det/det_r18_vd_ct.yml](../../configs/det/det_r18_vd_ct.yml)|88.68%|81.70%|85.05%|[trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)|
<a name="2"></a>
## 2. Environment
Please prepare your environment referring to [prepare the environment](./environment_en.md) and [clone the repo](./clone_en.md).
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
The above CT model is trained using the Total-Text text detection public dataset. For the download of the dataset, please refer to [Total-Text-Dataset](https://github.com/cs-chan/Total-Text-Dataset/tree/master/Dataset). PaddleOCR format annotation download link [train.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/train.txt), [test.txt](https://paddleocr.bj.bcebos.com/dataset/ct_tipc/test.txt).
Please refer to [text detection training tutorial](./detection_en.md). PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different detection models.
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved in the CT text detection training process into an inference model. Taking the model based on the Resnet18_vd backbone network and trained on the Total Text English dataset as example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r18_ct_train.tar)), you can use the following command to convert:
```shell
python3 tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o Global.pretrained_model=./det_r18_ct_train/best_accuracy Global.save_inference_dir=./inference/det_ct
```
CT text detection model inference, you can execute the following command:
```shell
python3 tools/infer/predict_det.py --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_ct/" --det_algorithm="CT"
```
The visualized text detection results are saved to the `./inference_results` folder by default, and the name of the result file is prefixed with 'det_res'. Examples of results are as follows:
![](../imgs_results/det_res_img623_ct.jpg)
<a name="4-2"></a>
### 4.2 C++ Inference
Not supported
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@inproceedings{sheng2021centripetaltext,
title={CentripetalText: An Efficient Text Instance Representation for Scene Text Detection},
author={Tao Sheng and Jie Chen and Zhouhui Lian},
booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
year={2021}
}
```
...@@ -43,6 +43,7 @@ from .vqa import * ...@@ -43,6 +43,7 @@ from .vqa import *
from .fce_aug import * from .fce_aug import *
from .fce_targets import FCENetTargets from .fce_targets import FCENetTargets
from .ct_process import *
def transform(data, ops=None): def transform(data, ops=None):
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import random
import pyclipper
import paddle
import numpy as np
import Polygon as plg
import scipy.io as scio
from PIL import Image
import paddle.vision.transforms as transforms
class RandomScale():
def __init__(self, short_size=640, **kwargs):
self.short_size = short_size
def scale_aligned(self, img, scale):
oh, ow = img.shape[0:2]
h = int(oh * scale + 0.5)
w = int(ow * scale + 0.5)
if h % 32 != 0:
h = h + (32 - h % 32)
if w % 32 != 0:
w = w + (32 - w % 32)
img = cv2.resize(img, dsize=(w, h))
factor_h = h / oh
factor_w = w / ow
return img, factor_h, factor_w
def __call__(self, data):
img = data['image']
h, w = img.shape[0:2]
random_scale = np.array([0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3])
scale = (np.random.choice(random_scale) * self.short_size) / min(h, w)
img, factor_h, factor_w = self.scale_aligned(img, scale)
data['scale_factor'] = (factor_w, factor_h)
data['image'] = img
return data
class MakeShrink():
def __init__(self, kernel_scale=0.7, **kwargs):
self.kernel_scale = kernel_scale
def dist(self, a, b):
return np.linalg.norm((a - b), ord=2, axis=0)
def perimeter(self, bbox):
peri = 0.0
for i in range(bbox.shape[0]):
peri += self.dist(bbox[i], bbox[(i + 1) % bbox.shape[0]])
return peri
def shrink(self, bboxes, rate, max_shr=20):
rate = rate * rate
shrinked_bboxes = []
for bbox in bboxes:
area = plg.Polygon(bbox).area()
peri = self.perimeter(bbox)
try:
pco = pyclipper.PyclipperOffset()
pco.AddPath(bbox, pyclipper.JT_ROUND,
pyclipper.ET_CLOSEDPOLYGON)
offset = min(
int(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)
shrinked_bbox = pco.Execute(-offset)
if len(shrinked_bbox) == 0:
shrinked_bboxes.append(bbox)
continue
shrinked_bbox = np.array(shrinked_bbox[0])
if shrinked_bbox.shape[0] <= 2:
shrinked_bboxes.append(bbox)
continue
shrinked_bboxes.append(shrinked_bbox)
except Exception as e:
shrinked_bboxes.append(bbox)
return shrinked_bboxes
def __call__(self, data):
img = data['image']
bboxes = data['polys']
words = data['texts']
scale_factor = data['scale_factor']
gt_instance = np.zeros(img.shape[0:2], dtype='uint8') # h,w
training_mask = np.ones(img.shape[0:2], dtype='uint8')
training_mask_distance = np.ones(img.shape[0:2], dtype='uint8')
for i in range(len(bboxes)):
bboxes[i] = np.reshape(bboxes[i] * (
[scale_factor[0], scale_factor[1]] * (bboxes[i].shape[0] // 2)),
(bboxes[i].shape[0] // 2, 2)).astype('int32')
for i in range(len(bboxes)):
#different value for different bbox
cv2.drawContours(gt_instance, [bboxes[i]], -1, i + 1, -1)
# set training mask to 0
cv2.drawContours(training_mask, [bboxes[i]], -1, 0, -1)
# for not accurate annotation, use training_mask_distance
if words[i] == '###' or words[i] == '???':
cv2.drawContours(training_mask_distance, [bboxes[i]], -1, 0, -1)
# make shrink
gt_kernel_instance = np.zeros(img.shape[0:2], dtype='uint8')
kernel_bboxes = self.shrink(bboxes, self.kernel_scale)
for i in range(len(bboxes)):
cv2.drawContours(gt_kernel_instance, [kernel_bboxes[i]], -1, i + 1,
-1)
# for training mask, kernel and background= 1, box region=0
if words[i] != '###' and words[i] != '???':
cv2.drawContours(training_mask, [kernel_bboxes[i]], -1, 1, -1)
gt_kernel = gt_kernel_instance.copy()
# for gt_kernel, kernel = 1
gt_kernel[gt_kernel > 0] = 1
# shrink 2 times
tmp1 = gt_kernel_instance.copy()
erode_kernel = np.ones((3, 3), np.uint8)
tmp1 = cv2.erode(tmp1, erode_kernel, iterations=1)
tmp2 = tmp1.copy()
tmp2 = cv2.erode(tmp2, erode_kernel, iterations=1)
# compute text region
gt_kernel_inner = tmp1 - tmp2
# gt_instance: text instance, bg=0, diff word use diff value
# training_mask: text instance mask, word=0,kernel and bg=1
# gt_kernel_instance: text kernel instance, bg=0, diff word use diff value
# gt_kernel: text_kernel, bg=0,diff word use same value
# gt_kernel_inner: text kernel reference
# training_mask_distance: word without anno = 0, else 1
data['image'] = [
img, gt_instance, training_mask, gt_kernel_instance, gt_kernel,
gt_kernel_inner, training_mask_distance
]
return data
class GroupRandomHorizontalFlip():
def __init__(self, p=0.5, **kwargs):
self.p = p
def __call__(self, data):
imgs = data['image']
if random.random() < self.p:
for i in range(len(imgs)):
imgs[i] = np.flip(imgs[i], axis=1).copy()
data['image'] = imgs
return data
class GroupRandomRotate():
def __init__(self, **kwargs):
pass
def __call__(self, data):
imgs = data['image']
max_angle = 10
angle = random.random() * 2 * max_angle - max_angle
for i in range(len(imgs)):
img = imgs[i]
w, h = img.shape[:2]
rotation_matrix = cv2.getRotationMatrix2D((h / 2, w / 2), angle, 1)
img_rotation = cv2.warpAffine(
img, rotation_matrix, (h, w), flags=cv2.INTER_NEAREST)
imgs[i] = img_rotation
data['image'] = imgs
return data
class GroupRandomCropPadding():
def __init__(self, target_size=(640, 640), **kwargs):
self.target_size = target_size
def __call__(self, data):
imgs = data['image']
h, w = imgs[0].shape[0:2]
t_w, t_h = self.target_size
p_w, p_h = self.target_size
if w == t_w and h == t_h:
return data
t_h = t_h if t_h < h else h
t_w = t_w if t_w < w else w
if random.random() > 3.0 / 8.0 and np.max(imgs[1]) > 0:
# make sure to crop the text region
tl = np.min(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
tl[tl < 0] = 0
br = np.max(np.where(imgs[1] > 0), axis=1) - (t_h, t_w)
br[br < 0] = 0
br[0] = min(br[0], h - t_h)
br[1] = min(br[1], w - t_w)
i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
else:
i = random.randint(0, h - t_h) if h - t_h > 0 else 0
j = random.randint(0, w - t_w) if w - t_w > 0 else 0
n_imgs = []
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
s3_length = int(imgs[idx].shape[-1])
img = imgs[idx][i:i + t_h, j:j + t_w, :]
img_p = cv2.copyMakeBorder(
img,
0,
p_h - t_h,
0,
p_w - t_w,
borderType=cv2.BORDER_CONSTANT,
value=tuple(0 for i in range(s3_length)))
else:
img = imgs[idx][i:i + t_h, j:j + t_w]
img_p = cv2.copyMakeBorder(
img,
0,
p_h - t_h,
0,
p_w - t_w,
borderType=cv2.BORDER_CONSTANT,
value=(0, ))
n_imgs.append(img_p)
data['image'] = n_imgs
return data
class MakeCentripetalShift():
def __init__(self, **kwargs):
pass
def jaccard(self, As, Bs):
A = As.shape[0] # small
B = Bs.shape[0] # large
dis = np.sqrt(
np.sum((As[:, np.newaxis, :].repeat(
B, axis=1) - Bs[np.newaxis, :, :].repeat(
A, axis=0))**2,
axis=-1))
ind = np.argmin(dis, axis=-1)
return ind
def __call__(self, data):
imgs = data['image']
img, gt_instance, training_mask, gt_kernel_instance, gt_kernel, gt_kernel_inner, training_mask_distance = \
imgs[0], imgs[1], imgs[2], imgs[3], imgs[4], imgs[5], imgs[6]
max_instance = np.max(gt_instance) # num bbox
# make centripetal shift
gt_distance = np.zeros((2, *img.shape[0:2]), dtype=np.float32)
for i in range(1, max_instance + 1):
# kernel_reference
ind = (gt_kernel_inner == i)
if np.sum(ind) == 0:
training_mask[gt_instance == i] = 0
training_mask_distance[gt_instance == i] = 0
continue
kpoints = np.array(np.where(ind)).transpose(
(1, 0))[:, ::-1].astype('float32')
ind = (gt_instance == i) * (gt_kernel_instance == 0)
if np.sum(ind) == 0:
continue
pixels = np.where(ind)
points = np.array(pixels).transpose(
(1, 0))[:, ::-1].astype('float32')
bbox_ind = self.jaccard(points, kpoints)
offset_gt = kpoints[bbox_ind] - points
gt_distance[:, pixels[0], pixels[1]] = offset_gt.T * 0.1
img = Image.fromarray(img)
img = img.convert('RGB')
data["image"] = img
data["gt_kernel"] = gt_kernel.astype("int64")
data["training_mask"] = training_mask.astype("int64")
data["gt_instance"] = gt_instance.astype("int64")
data["gt_kernel_instance"] = gt_kernel_instance.astype("int64")
data["training_mask_distance"] = training_mask_distance.astype("int64")
data["gt_distance"] = gt_distance.astype("float32")
return data
class ScaleAlignedShort():
def __init__(self, short_size=640, **kwargs):
self.short_size = short_size
def __call__(self, data):
img = data['image']
org_img_shape = img.shape
h, w = img.shape[0:2]
scale = self.short_size * 1.0 / min(h, w)
h = int(h * scale + 0.5)
w = int(w * scale + 0.5)
if h % 32 != 0:
h = h + (32 - h % 32)
if w % 32 != 0:
w = w + (32 - w % 32)
img = cv2.resize(img, dsize=(w, h))
new_img_shape = img.shape
img_shape = np.array(org_img_shape + new_img_shape)
data['shape'] = img_shape
data['image'] = img
return data
\ No newline at end of file
...@@ -1395,3 +1395,29 @@ class VLLabelEncode(BaseRecLabelEncode): ...@@ -1395,3 +1395,29 @@ class VLLabelEncode(BaseRecLabelEncode):
data['label_res'] = np.array(label_res) data['label_res'] = np.array(label_res)
data['label_sub'] = np.array(label_sub) data['label_sub'] = np.array(label_sub)
return data return data
class CTLabelEncode(object):
def __init__(self, **kwargs):
pass
def __call__(self, data):
label = data['label']
label = json.loads(label)
nBox = len(label)
boxes, txts = [], []
for bno in range(0, nBox):
box = label[bno]['points']
box = np.array(box)
boxes.append(box)
txt = label[bno]['transcription']
txts.append(txt)
if len(boxes) == 0:
return None
data['polys'] = boxes
data['texts'] = txts
return data
\ No newline at end of file
...@@ -25,6 +25,7 @@ from .det_east_loss import EASTLoss ...@@ -25,6 +25,7 @@ from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss from .det_sast_loss import SASTLoss
from .det_pse_loss import PSELoss from .det_pse_loss import PSELoss
from .det_fce_loss import FCELoss from .det_fce_loss import FCELoss
from .det_ct_loss import CTLoss
# rec loss # rec loss
from .rec_ctc_loss import CTCLoss from .rec_ctc_loss import CTCLoss
...@@ -68,7 +69,7 @@ def build_loss(config): ...@@ -68,7 +69,7 @@ def build_loss(config):
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss', 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss', 'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
'SLALoss' 'SLALoss', 'CTLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/shengtao96/CentripetalText/tree/main/models/loss
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
import numpy as np
def ohem_single(score, gt_text, training_mask):
# online hard example mining
pos_num = int(paddle.sum(gt_text > 0.5)) - int(
paddle.sum((gt_text > 0.5) & (training_mask <= 0.5)))
if pos_num == 0:
# selected_mask = gt_text.copy() * 0 # may be not good
selected_mask = training_mask
selected_mask = paddle.cast(
selected_mask.reshape(
(1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
return selected_mask
neg_num = int(paddle.sum((gt_text <= 0.5) & (training_mask > 0.5)))
neg_num = int(min(pos_num * 3, neg_num))
if neg_num == 0:
selected_mask = training_mask
selected_mask = paddle.cast(
selected_mask.reshape(
(1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
return selected_mask
# hard example
neg_score = score[(gt_text <= 0.5) & (training_mask > 0.5)]
neg_score_sorted = paddle.sort(-neg_score)
threshold = -neg_score_sorted[neg_num - 1]
selected_mask = ((score >= threshold) |
(gt_text > 0.5)) & (training_mask > 0.5)
selected_mask = paddle.cast(
selected_mask.reshape(
(1, selected_mask.shape[0], selected_mask.shape[1])), "float32")
return selected_mask
def ohem_batch(scores, gt_texts, training_masks):
selected_masks = []
for i in range(scores.shape[0]):
selected_masks.append(
ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
i, :, :]))
selected_masks = paddle.cast(paddle.concat(selected_masks, 0), "float32")
return selected_masks
def iou_single(a, b, mask, n_class):
EPS = 1e-6
valid = mask == 1
a = a[valid]
b = b[valid]
miou = []
# iou of each class
for i in range(n_class):
inter = paddle.cast(((a == i) & (b == i)), "float32")
union = paddle.cast(((a == i) | (b == i)), "float32")
miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
miou = sum(miou) / len(miou)
return miou
def iou(a, b, mask, n_class=2, reduce=True):
batch_size = a.shape[0]
a = a.reshape((batch_size, -1))
b = b.reshape((batch_size, -1))
mask = mask.reshape((batch_size, -1))
iou = paddle.zeros((batch_size, ), dtype="float32")
for i in range(batch_size):
iou[i] = iou_single(a[i], b[i], mask[i], n_class)
if reduce:
iou = paddle.mean(iou)
return iou
class DiceLoss(nn.Layer):
def __init__(self, loss_weight=1.0):
super(DiceLoss, self).__init__()
self.loss_weight = loss_weight
def forward(self, input, target, mask, reduce=True):
batch_size = input.shape[0]
input = F.sigmoid(input) # scale to 0-1
input = input.reshape((batch_size, -1))
target = paddle.cast(target.reshape((batch_size, -1)), "float32")
mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
input = input * mask
target = target * mask
a = paddle.sum(input * target, axis=1)
b = paddle.sum(input * input, axis=1) + 0.001
c = paddle.sum(target * target, axis=1) + 0.001
d = (2 * a) / (b + c)
loss = 1 - d
loss = self.loss_weight * loss
if reduce:
loss = paddle.mean(loss)
return loss
class SmoothL1Loss(nn.Layer):
def __init__(self, beta=1.0, loss_weight=1.0):
super(SmoothL1Loss, self).__init__()
self.beta = beta
self.loss_weight = loss_weight
np_coord = np.zeros(shape=[640, 640, 2], dtype=np.int64)
for i in range(640):
for j in range(640):
np_coord[i, j, 0] = j
np_coord[i, j, 1] = i
np_coord = np_coord.reshape((-1, 2))
self.coord = self.create_parameter(
shape=[640 * 640, 2],
dtype="int32", # NOTE: not support "int64" before paddle 2.3.1
default_initializer=nn.initializer.Assign(value=np_coord))
self.coord.stop_gradient = True
def forward_single(self, input, target, mask, beta=1.0, eps=1e-6):
batch_size = input.shape[0]
diff = paddle.abs(input - target) * mask.unsqueeze(1)
loss = paddle.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
loss = paddle.cast(loss.reshape((batch_size, -1)), "float32")
mask = paddle.cast(mask.reshape((batch_size, -1)), "float32")
loss = paddle.sum(loss, axis=-1)
loss = loss / (mask.sum(axis=-1) + eps)
return loss
def select_single(self, distance, gt_instance, gt_kernel_instance,
training_mask):
with paddle.no_grad():
# paddle 2.3.1, paddle.slice not support:
# distance[:, self.coord[:, 1], self.coord[:, 0]]
select_distance_list = []
for i in range(2):
tmp1 = distance[i, :]
tmp2 = tmp1[self.coord[:, 1], self.coord[:, 0]]
select_distance_list.append(tmp2.unsqueeze(0))
select_distance = paddle.concat(select_distance_list, axis=0)
off_points = paddle.cast(
self.coord, "float32") + 10 * select_distance.transpose((1, 0))
off_points = paddle.cast(off_points, "int64")
off_points = paddle.clip(off_points, 0, distance.shape[-1] - 1)
selected_mask = (
gt_instance[self.coord[:, 1], self.coord[:, 0]] !=
gt_kernel_instance[off_points[:, 1], off_points[:, 0]])
selected_mask = paddle.cast(
selected_mask.reshape((1, -1, distance.shape[-1])), "int64")
selected_training_mask = selected_mask * training_mask
return selected_training_mask
def forward(self,
distances,
gt_instances,
gt_kernel_instances,
training_masks,
gt_distances,
reduce=True):
selected_training_masks = []
for i in range(distances.shape[0]):
selected_training_masks.append(
self.select_single(distances[i, :, :, :], gt_instances[i, :, :],
gt_kernel_instances[i, :, :], training_masks[
i, :, :]))
selected_training_masks = paddle.cast(
paddle.concat(selected_training_masks, 0), "float32")
loss = self.forward_single(distances, gt_distances,
selected_training_masks, self.beta)
loss = self.loss_weight * loss
with paddle.no_grad():
batch_size = distances.shape[0]
false_num = selected_training_masks.reshape((batch_size, -1))
false_num = false_num.sum(axis=-1)
total_num = paddle.cast(
training_masks.reshape((batch_size, -1)), "float32")
total_num = total_num.sum(axis=-1)
iou_text = (total_num - false_num) / (total_num + 1e-6)
if reduce:
loss = paddle.mean(loss)
return loss, iou_text
class CTLoss(nn.Layer):
def __init__(self):
super(CTLoss, self).__init__()
self.kernel_loss = DiceLoss()
self.loc_loss = SmoothL1Loss(beta=0.1, loss_weight=0.05)
def forward(self, preds, batch):
imgs = batch[0]
out = preds['maps']
gt_kernels, training_masks, gt_instances, gt_kernel_instances, training_mask_distances, gt_distances = batch[
1:]
kernels = out[:, 0, :, :]
distances = out[:, 1:, :, :]
# kernel loss
selected_masks = ohem_batch(kernels, gt_kernels, training_masks)
loss_kernel = self.kernel_loss(
kernels, gt_kernels, selected_masks, reduce=False)
iou_kernel = iou(paddle.cast((kernels > 0), "int64"),
gt_kernels,
training_masks,
reduce=False)
losses = dict(loss_kernels=loss_kernel, )
# loc loss
loss_loc, iou_text = self.loc_loss(
distances,
gt_instances,
gt_kernel_instances,
training_mask_distances,
gt_distances,
reduce=False)
losses.update(dict(loss_loc=loss_loc, ))
loss_all = loss_kernel + loss_loc
losses = {'loss': loss_all}
return losses
...@@ -31,12 +31,14 @@ from .kie_metric import KIEMetric ...@@ -31,12 +31,14 @@ from .kie_metric import KIEMetric
from .vqa_token_ser_metric import VQASerTokenMetric from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric from .vqa_token_re_metric import VQAReTokenMetric
from .sr_metric import SRMetric from .sr_metric import SRMetric
from .ct_metric import CTMetric
def build_metric(config): def build_metric(config):
support_dict = [ support_dict = [
"DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric", "DetMetric", "DetFCEMetric", "RecMetric", "ClsMetric", "E2EMetric",
"DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric', "DistillationMetric", "TableMetric", 'KIEMetric', 'VQASerTokenMetric',
'VQAReTokenMetric', 'SRMetric' 'VQAReTokenMetric', 'SRMetric', 'CTMetric'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from scipy import io
import numpy as np
from ppocr.utils.e2e_metric.Deteval import combine_results, get_score_C
class CTMetric(object):
def __init__(self, main_indicator, delimiter='\t', **kwargs):
self.delimiter = delimiter
self.main_indicator = main_indicator
self.reset()
def reset(self):
self.results = [] # clear results
def __call__(self, preds, batch, **kwargs):
# NOTE: only support bs=1 now, as the label length of different sample is Unequal
assert len(
preds) == 1, "CentripetalText test now only suuport batch_size=1."
label = batch[2]
text = batch[3]
pred = preds[0]['points']
result = get_score_C(label, text, pred)
self.results.append(result)
def get_metric(self):
"""
Input format: y0,x0, ..... yn,xn. Each detection is separated by the end of line token ('\n')'
"""
metrics = combine_results(self.results, rec_flag=False)
self.reset()
return metrics
...@@ -23,6 +23,7 @@ def build_head(config): ...@@ -23,6 +23,7 @@ def build_head(config):
from .det_pse_head import PSEHead from .det_pse_head import PSEHead
from .det_fce_head import FCEHead from .det_fce_head import FCEHead
from .e2e_pg_head import PGHead from .e2e_pg_head import PGHead
from .det_ct_head import CT_Head
# rec head # rec head
from .rec_ctc_head import CTCHead from .rec_ctc_head import CTCHead
...@@ -52,7 +53,7 @@ def build_head(config): ...@@ -52,7 +53,7 @@ def build_head(config):
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead', 'MultiHead', 'ABINetHead', 'TableMasterHead', 'SPINAttentionHead',
'VLHead', 'SLAHead', 'RobustScannerHead' 'VLHead', 'SLAHead', 'RobustScannerHead', 'CT_Head'
] ]
#table head #table head
......
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
import math
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
ones_ = Constant(value=1.)
zeros_ = Constant(value=0.)
class CT_Head(nn.Layer):
def __init__(self,
in_channels,
hidden_dim,
num_classes,
loss_kernel=None,
loss_loc=None):
super(CT_Head, self).__init__()
self.conv1 = nn.Conv2D(
in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2D(hidden_dim)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
hidden_dim, num_classes, kernel_size=1, stride=1, padding=0)
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
normal_(m.weight)
elif isinstance(m, nn.BatchNorm2D):
zeros_(m.bias)
ones_(m.weight)
def _upsample(self, x, scale=1):
return F.upsample(x, scale_factor=scale, mode='bilinear')
def forward(self, f, targets=None):
out = self.conv1(f)
out = self.relu1(self.bn1(out))
out = self.conv2(out)
if self.training:
out = self._upsample(out, scale=4)
return {'maps': out}
else:
score = F.sigmoid(out[:, 0, :, :])
return {'maps': out, 'score': score}
...@@ -26,13 +26,15 @@ def build_neck(config): ...@@ -26,13 +26,15 @@ def build_neck(config):
from .fce_fpn import FCEFPN from .fce_fpn import FCEFPN
from .pren_fpn import PRENFPN from .pren_fpn import PRENFPN
from .csp_pan import CSPPAN from .csp_pan import CSPPAN
from .ct_fpn import CTFPN
support_dict = [ support_dict = [
'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN', 'FPN', 'FCEFPN', 'LKPAN', 'DBFPN', 'RSEFPN', 'EASTFPN', 'SASTFPN',
'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN' 'SequenceEncoder', 'PGFPN', 'TableFPN', 'PRENFPN', 'CSPPAN', 'CTFPN'
] ]
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format( assert module_name in support_dict, Exception('neck only support {}'.format(
support_dict)) support_dict))
module_class = eval(module_name)(**config) module_class = eval(module_name)(**config)
return module_class return module_class
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
import os
import sys
import math
from paddle.nn.initializer import TruncatedNormal, Constant, Normal
ones_ = Constant(value=1.)
zeros_ = Constant(value=0.)
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))
class Conv_BN_ReLU(nn.Layer):
def __init__(self,
in_planes,
out_planes,
kernel_size=1,
stride=1,
padding=0):
super(Conv_BN_ReLU, self).__init__()
self.conv = nn.Conv2D(
in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias_attr=False)
self.bn = nn.BatchNorm2D(out_planes)
self.relu = nn.ReLU()
for m in self.sublayers():
if isinstance(m, nn.Conv2D):
n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
normal_(m.weight)
elif isinstance(m, nn.BatchNorm2D):
zeros_(m.bias)
ones_(m.weight)
def forward(self, x):
return self.relu(self.bn(self.conv(x)))
class FPEM(nn.Layer):
def __init__(self, in_channels, out_channels):
super(FPEM, self).__init__()
planes = out_channels
self.dwconv3_1 = nn.Conv2D(
planes,
planes,
kernel_size=3,
stride=1,
padding=1,
groups=planes,
bias_attr=False)
self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes)
self.dwconv2_1 = nn.Conv2D(
planes,
planes,
kernel_size=3,
stride=1,
padding=1,
groups=planes,
bias_attr=False)
self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes)
self.dwconv1_1 = nn.Conv2D(
planes,
planes,
kernel_size=3,
stride=1,
padding=1,
groups=planes,
bias_attr=False)
self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes)
self.dwconv2_2 = nn.Conv2D(
planes,
planes,
kernel_size=3,
stride=2,
padding=1,
groups=planes,
bias_attr=False)
self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes)
self.dwconv3_2 = nn.Conv2D(
planes,
planes,
kernel_size=3,
stride=2,
padding=1,
groups=planes,
bias_attr=False)
self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes)
self.dwconv4_2 = nn.Conv2D(
planes,
planes,
kernel_size=3,
stride=2,
padding=1,
groups=planes,
bias_attr=False)
self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes)
def _upsample_add(self, x, y):
return F.upsample(x, scale_factor=2, mode='bilinear') + y
def forward(self, f1, f2, f3, f4):
# up-down
f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3)))
f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2)))
f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1)))
# down-up
f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1)))
f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2)))
f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3)))
return f1, f2, f3, f4
class CTFPN(nn.Layer):
def __init__(self, in_channels, out_channel=128):
super(CTFPN, self).__init__()
self.out_channels = out_channel * 4
self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128)
self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128)
self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128)
self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128)
self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
def _upsample(self, x, scale=1):
return F.upsample(x, scale_factor=scale, mode='bilinear')
def forward(self, f):
# # reduce channel
f1 = self.reduce_layer1(f[0]) # N,64,160,160 --> N, 128, 160, 160
f2 = self.reduce_layer2(f[1]) # N, 128, 80, 80 --> N, 128, 80, 80
f3 = self.reduce_layer3(f[2]) # N, 256, 40, 40 --> N, 128, 40, 40
f4 = self.reduce_layer4(f[3]) # N, 512, 20, 20 --> N, 128, 20, 20
# FPEM
f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4)
f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1)
# FFM
f1 = f1_1 + f1_2
f2 = f2_1 + f2_2
f3 = f3_1 + f3_2
f4 = f4_1 + f4_2
f2 = self._upsample(f2, scale=2)
f3 = self._upsample(f3, scale=4)
f4 = self._upsample(f4, scale=8)
ff = paddle.concat((f1, f2, f3, f4), 1) # N,512, 160,160
return ff
...@@ -35,6 +35,7 @@ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, ...@@ -35,6 +35,7 @@ from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess,
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
...@@ -48,7 +49,7 @@ def build_post_process(config, global_config=None): ...@@ -48,7 +49,7 @@ def build_post_process(config, global_config=None):
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode', 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode', 'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess', 'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess' 'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refered from:
https://github.com/shengtao96/CentripetalText/blob/main/test.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import os.path as osp
import numpy as np
import cv2
import paddle
import pyclipper
class CTPostProcess(object):
"""
The post process for Centripetal Text (CT).
"""
def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs):
self.min_score = min_score
self.min_area = min_area
self.box_type = box_type
self.coord = np.zeros((2, 300, 300), dtype=np.int32)
for i in range(300):
for j in range(300):
self.coord[0, i, j] = j
self.coord[1, i, j] = i
def __call__(self, preds, batch):
outs = preds['maps']
out_scores = preds['score']
if isinstance(outs, paddle.Tensor):
outs = outs.numpy()
if isinstance(out_scores, paddle.Tensor):
out_scores = out_scores.numpy()
batch_size = outs.shape[0]
boxes_batch = []
for idx in range(batch_size):
bboxes = []
scores = []
img_shape = batch[idx]
org_img_size = img_shape[:3]
img_shape = img_shape[3:]
img_size = img_shape[:2]
out = np.expand_dims(outs[idx], axis=0)
outputs = dict()
score = np.expand_dims(out_scores[idx], axis=0)
kernel = out[:, 0, :, :] > 0.2
loc = out[:, 1:, :, :].astype("float32")
score = score[0].astype(np.float32)
kernel = kernel[0].astype(np.uint8)
loc = loc[0].astype(np.float32)
label_num, label_kernel = cv2.connectedComponents(
kernel, connectivity=4)
for i in range(1, label_num):
ind = (label_kernel == i)
if ind.sum(
) < 10: # pixel number less than 10, treated as background
label_kernel[ind] = 0
label = np.zeros_like(label_kernel)
h, w = label_kernel.shape
pixels = self.coord[:, :h, :w].reshape(2, -1)
points = pixels.transpose([1, 0]).astype(np.float32)
off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T
).astype(np.int32)
off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1)
off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1)
label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1],
off_points[:, 0]]
label[label_kernel > 0] = label_kernel[label_kernel > 0]
score_pocket = [0.0]
for i in range(1, label_num):
ind = (label_kernel == i)
if ind.sum() == 0:
score_pocket.append(0.0)
continue
score_i = np.mean(score[ind])
score_pocket.append(score_i)
label_num = np.max(label) + 1
label = cv2.resize(
label, (img_size[1], img_size[0]),
interpolation=cv2.INTER_NEAREST)
scale = (float(org_img_size[1]) / float(img_size[1]),
float(org_img_size[0]) / float(img_size[0]))
for i in range(1, label_num):
ind = (label == i)
points = np.array(np.where(ind)).transpose((1, 0))
if points.shape[0] < self.min_area:
continue
score_i = score_pocket[i]
if score_i < self.min_score:
continue
if self.box_type == 'rect':
rect = cv2.minAreaRect(points[:, ::-1])
bbox = cv2.boxPoints(rect) * scale
z = bbox.mean(0)
bbox = z + (bbox - z) * 0.85
elif self.box_type == 'poly':
binary = np.zeros(label.shape, dtype='uint8')
binary[ind] = 1
try:
_, contours, _ = cv2.findContours(
binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
except BaseException:
contours, _ = cv2.findContours(
binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
bbox = contours[0] * scale
bbox = bbox.astype('int32')
bboxes.append(bbox.reshape(-1, 2))
scores.append(score_i)
boxes_batch.append({'points': bboxes})
return boxes_batch
...@@ -12,8 +12,10 @@ ...@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
import numpy as np import numpy as np
import scipy.io as io import scipy.io as io
import Polygon as plg
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
...@@ -269,7 +271,124 @@ def get_socre_B(gt_dir, img_id, pred_dict): ...@@ -269,7 +271,124 @@ def get_socre_B(gt_dir, img_id, pred_dict):
return single_data return single_data
def combine_results(all_data): def get_score_C(gt_label, text, pred_bboxes):
"""
get score for CentripetalText (CT) prediction.
"""
def gt_reading_mod(gt_label, text):
"""This helper reads groundtruths from mat files"""
groundtruths = []
nbox = len(gt_label)
for i in range(nbox):
label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
groundtruths.append(label)
return groundtruths
def get_union(pD, pG):
areaA = pD.area()
areaB = pG.area()
return areaA + areaB - get_intersection(pD, pG)
def get_intersection(pD, pG):
pInt = pD & pG
if len(pInt) == 0:
return 0
return pInt.area()
def detection_filtering(detections, groundtruths, threshold=0.5):
for gt in groundtruths:
point_num = gt['points'].shape[1] // 2
if gt['transcription'] == '###' and (point_num > 1):
gt_p = np.array(gt['points']).reshape(point_num,
2).astype('int32')
gt_p = plg.Polygon(gt_p)
for det_id, detection in enumerate(detections):
det_y = detection[0::2]
det_x = detection[1::2]
det_p = np.concatenate((np.array(det_x), np.array(det_y)))
det_p = det_p.reshape(2, -1).transpose()
det_p = plg.Polygon(det_p)
try:
det_gt_iou = get_intersection(det_p,
gt_p) / det_p.area()
except:
print(det_x, det_y, gt_p)
if det_gt_iou > threshold:
detections[det_id] = []
detections[:] = [item for item in detections if item != []]
return detections
def sigma_calculation(det_p, gt_p):
"""
sigma = inter_area / gt_area
"""
if gt_p.area() == 0.:
return 0
return get_intersection(det_p, gt_p) / gt_p.area()
def tau_calculation(det_p, gt_p):
"""
tau = inter_area / det_area
"""
if det_p.area() == 0.:
return 0
return get_intersection(det_p, gt_p) / det_p.area()
detections = []
for item in pred_bboxes:
detections.append(item[:, ::-1].reshape(-1))
groundtruths = gt_reading_mod(gt_label, text)
detections = detection_filtering(
detections, groundtruths) # filters detections overlapping with DC area
for idx in range(len(groundtruths) - 1, -1, -1):
#NOTE: source code use 'orin' to indicate '#', here we use 'anno',
# which may cause slight drop in fscore, about 0.12
if groundtruths[idx]['transcription'] == '###':
groundtruths.pop(idx)
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
local_tau_table = np.zeros((len(groundtruths), len(detections)))
for gt_id, gt in enumerate(groundtruths):
if len(detections) > 0:
for det_id, detection in enumerate(detections):
point_num = gt['points'].shape[1] // 2
gt_p = np.array(gt['points']).reshape(point_num,
2).astype('int32')
gt_p = plg.Polygon(gt_p)
det_y = detection[0::2]
det_x = detection[1::2]
det_p = np.concatenate((np.array(det_x), np.array(det_y)))
det_p = det_p.reshape(2, -1).transpose()
det_p = plg.Polygon(det_p)
local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
gt_p)
local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
data = {}
data['sigma'] = local_sigma_table
data['global_tau'] = local_tau_table
data['global_pred_str'] = ''
data['global_gt_str'] = ''
return data
def combine_results(all_data, rec_flag=True):
tr = 0.7 tr = 0.7
tp = 0.6 tp = 0.6
fsc_k = 0.8 fsc_k = 0.8
...@@ -278,6 +397,7 @@ def combine_results(all_data): ...@@ -278,6 +397,7 @@ def combine_results(all_data):
global_tau = [] global_tau = []
global_pred_str = [] global_pred_str = []
global_gt_str = [] global_gt_str = []
for data in all_data: for data in all_data:
global_sigma.append(data['sigma']) global_sigma.append(data['sigma'])
global_tau.append(data['global_tau']) global_tau.append(data['global_tau'])
...@@ -294,7 +414,7 @@ def combine_results(all_data): ...@@ -294,7 +414,7 @@ def combine_results(all_data):
def one_to_one(local_sigma_table, local_tau_table, def one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision, local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy): gt_flag, det_flag, idy, rec_flag):
hit_str_num = 0 hit_str_num = 0
for gt_id in range(num_gt): for gt_id in range(num_gt):
gt_matching_qualified_sigma_candidates = np.where( gt_matching_qualified_sigma_candidates = np.where(
...@@ -328,9 +448,10 @@ def combine_results(all_data): ...@@ -328,9 +448,10 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr) matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
# recg start # recg start
if rec_flag:
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[ pred_str_cur = global_pred_str[idy][matched_det_id[0]
0]] .tolist()[0]]
if pred_str_cur == gt_str_cur: if pred_str_cur == gt_str_cur:
hit_str_num += 1 hit_str_num += 1
else: else:
...@@ -343,7 +464,7 @@ def combine_results(all_data): ...@@ -343,7 +464,7 @@ def combine_results(all_data):
def one_to_many(local_sigma_table, local_tau_table, def one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision, local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy): gt_flag, det_flag, idy, rec_flag):
hit_str_num = 0 hit_str_num = 0
for gt_id in range(num_gt): for gt_id in range(num_gt):
# skip the following if the groundtruth was matched # skip the following if the groundtruth was matched
...@@ -374,6 +495,7 @@ def combine_results(all_data): ...@@ -374,6 +495,7 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1 det_flag[0, qualified_tau_candidates] = 1
# recg start # recg start
if rec_flag:
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][ pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]] qualified_tau_candidates[0].tolist()[0]]
...@@ -388,6 +510,7 @@ def combine_results(all_data): ...@@ -388,6 +510,7 @@ def combine_results(all_data):
gt_flag[0, gt_id] = 1 gt_flag[0, gt_id] = 1
det_flag[0, qualified_tau_candidates] = 1 det_flag[0, qualified_tau_candidates] = 1
# recg start # recg start
if rec_flag:
gt_str_cur = global_gt_str[idy][gt_id] gt_str_cur = global_gt_str[idy][gt_id]
pred_str_cur = global_pred_str[idy][ pred_str_cur = global_pred_str[idy][
qualified_tau_candidates[0].tolist()[0]] qualified_tau_candidates[0].tolist()[0]]
...@@ -409,7 +532,7 @@ def combine_results(all_data): ...@@ -409,7 +532,7 @@ def combine_results(all_data):
def many_to_one(local_sigma_table, local_tau_table, def many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision, local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idy): gt_flag, det_flag, idy, rec_flag):
hit_str_num = 0 hit_str_num = 0
for det_id in range(num_det): for det_id in range(num_det):
# skip the following if the detection was matched # skip the following if the detection was matched
...@@ -440,11 +563,12 @@ def combine_results(all_data): ...@@ -440,11 +563,12 @@ def combine_results(all_data):
gt_flag[0, qualified_sigma_candidates] = 1 gt_flag[0, qualified_sigma_candidates] = 1
det_flag[0, det_id] = 1 det_flag[0, det_id] = 1
# recg start # recg start
if rec_flag:
pred_str_cur = global_pred_str[idy][det_id] pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0]) gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len): for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[ ele_gt_id = qualified_sigma_candidates[
idx] 0].tolist()[idx]
if ele_gt_id not in global_gt_str[idy]: if ele_gt_id not in global_gt_str[idy]:
continue continue
gt_str_cur = global_gt_str[idy][ele_gt_id] gt_str_cur = global_gt_str[idy][ele_gt_id]
...@@ -452,7 +576,8 @@ def combine_results(all_data): ...@@ -452,7 +576,8 @@ def combine_results(all_data):
hit_str_num += 1 hit_str_num += 1
break break
else: else:
if pred_str_cur.lower() == gt_str_cur.lower(): if pred_str_cur.lower() == gt_str_cur.lower(
):
hit_str_num += 1 hit_str_num += 1
break break
# recg end # recg end
...@@ -461,10 +586,12 @@ def combine_results(all_data): ...@@ -461,10 +586,12 @@ def combine_results(all_data):
det_flag[0, det_id] = 1 det_flag[0, det_id] = 1
gt_flag[0, qualified_sigma_candidates] = 1 gt_flag[0, qualified_sigma_candidates] = 1
# recg start # recg start
if rec_flag:
pred_str_cur = global_pred_str[idy][det_id] pred_str_cur = global_pred_str[idy][det_id]
gt_len = len(qualified_sigma_candidates[0]) gt_len = len(qualified_sigma_candidates[0])
for idx in range(gt_len): for idx in range(gt_len):
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx] ele_gt_id = qualified_sigma_candidates[0].tolist()[
idx]
if ele_gt_id not in global_gt_str[idy]: if ele_gt_id not in global_gt_str[idy]:
continue continue
gt_str_cur = global_gt_str[idy][ele_gt_id] gt_str_cur = global_gt_str[idy][ele_gt_id]
...@@ -504,7 +631,7 @@ def combine_results(all_data): ...@@ -504,7 +631,7 @@ def combine_results(all_data):
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table, gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision, local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx) gt_flag, det_flag, idx, rec_flag)
hit_str_count += hit_str_num hit_str_count += hit_str_num
#######then check for one-to-many case########## #######then check for one-to-many case##########
...@@ -512,14 +639,14 @@ def combine_results(all_data): ...@@ -512,14 +639,14 @@ def combine_results(all_data):
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table, gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision, local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx) gt_flag, det_flag, idx, rec_flag)
hit_str_count += hit_str_num hit_str_count += hit_str_num
#######then check for many-to-one case########## #######then check for many-to-one case##########
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \ local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table, gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
local_accumulative_recall, local_accumulative_precision, local_accumulative_recall, local_accumulative_precision,
global_accumulative_recall, global_accumulative_precision, global_accumulative_recall, global_accumulative_precision,
gt_flag, det_flag, idx) gt_flag, det_flag, idx, rec_flag)
hit_str_count += hit_str_num hit_str_count += hit_str_num
try: try:
......
===========================train_params===========================
model_name:det_r18_ct
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_lite_infer=4
Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./train_data/total_text/test/rgb/
null:null
##
trainer:norm_train
norm_train:tools/train.py -c configs/det/det_r18_vd_ct.yml -o Global.print_batch_step=1 Train.loader.shuffle=false
quant_export:null
fpgm_export:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c configs/det/det_r18_vd_ct.yml -o
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o
quant_export:null
fpgm_export:null
distill_export:null
export1:null
export2:null
##
train_model:./inference/det_r18_vd_ct/best_accuracy
infer_export:tools/export_model.py -c configs/det/det_r18_vd_ct.yml -o
infer_quant:False
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
--save_log_path:null
--benchmark:True
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
\ No newline at end of file
...@@ -264,6 +264,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -264,6 +264,11 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./train_data/ && tar xf XFUND.tar cd ./train_data/ && tar xf XFUND.tar
cd ../ cd ../
fi fi
if [ ${model_name} == "det_r18_ct" ]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet18_vd_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/ct_tipc/total_text_lite2.tar --no-check-certificate
cd ./train_data && tar xf total_text_lite2.tar && ln -s total_text_lite2 total_text && cd ../
fi
elif [ ${MODE} = "whole_train_whole_infer" ];then elif [ ${MODE} = "whole_train_whole_infer" ];then
wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
......
...@@ -127,6 +127,9 @@ class TextDetector(object): ...@@ -127,6 +127,9 @@ class TextDetector(object):
postprocess_params["beta"] = args.beta postprocess_params["beta"] = args.beta
postprocess_params["fourier_degree"] = args.fourier_degree postprocess_params["fourier_degree"] = args.fourier_degree
postprocess_params["box_type"] = args.det_fce_box_type postprocess_params["box_type"] = args.det_fce_box_type
elif self.det_algorithm == "CT":
pre_process_list[0] = {'ScaleAlignedShort': {'short_size': 640}}
postprocess_params['name'] = 'CTPostProcess'
else: else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm)) logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0) sys.exit(0)
...@@ -253,6 +256,9 @@ class TextDetector(object): ...@@ -253,6 +256,9 @@ class TextDetector(object):
elif self.det_algorithm == 'FCE': elif self.det_algorithm == 'FCE':
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
preds['level_{}'.format(i)] = output preds['level_{}'.format(i)] = output
elif self.det_algorithm == "CT":
preds['maps'] = outputs[0]
preds['score'] = outputs[1]
else: else:
raise NotImplementedError raise NotImplementedError
...@@ -260,7 +266,7 @@ class TextDetector(object): ...@@ -260,7 +266,7 @@ class TextDetector(object):
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points'] dt_boxes = post_result[0]['points']
if (self.det_algorithm == "SAST" and self.det_sast_polygon) or ( if (self.det_algorithm == "SAST" and self.det_sast_polygon) or (
self.det_algorithm in ["PSE", "FCE"] and self.det_algorithm in ["PSE", "FCE", "CT"] and
self.postprocess_op.box_type == 'poly'): self.postprocess_op.box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else: else:
......
...@@ -625,7 +625,7 @@ def preprocess(is_train=False): ...@@ -625,7 +625,7 @@ def preprocess(is_train=False):
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE', 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE', 'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN', 'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
'Gestalt', 'SLANet', 'RobustScanner' 'Gestalt', 'SLANet', 'RobustScanner', 'CT'
] ]
if use_xpu: if use_xpu:
......
...@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer): ...@@ -119,6 +119,7 @@ def main(config, device, logger, vdl_writer):
config['Loss']['ignore_index'] = char_num - 1 config['Loss']['ignore_index'] = char_num - 1
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_sync_bn = config["Global"].get("use_sync_bn", False) use_sync_bn = config["Global"].get("use_sync_bn", False)
if use_sync_bn: if use_sync_bn:
model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model) model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
...@@ -146,7 +147,7 @@ def main(config, device, logger, vdl_writer): ...@@ -146,7 +147,7 @@ def main(config, device, logger, vdl_writer):
use_amp = config["Global"].get("use_amp", False) use_amp = config["Global"].get("use_amp", False)
amp_level = config["Global"].get("amp_level", 'O2') amp_level = config["Global"].get("amp_level", 'O2')
amp_custom_black_list = config['Global'].get('amp_custom_black_list',[]) amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
if use_amp: if use_amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
...@@ -161,7 +162,10 @@ def main(config, device, logger, vdl_writer): ...@@ -161,7 +162,10 @@ def main(config, device, logger, vdl_writer):
use_dynamic_loss_scaling=use_dynamic_loss_scaling) use_dynamic_loss_scaling=use_dynamic_loss_scaling)
if amp_level == "O2": if amp_level == "O2":
model, optimizer = paddle.amp.decorate( model, optimizer = paddle.amp.decorate(
models=model, optimizers=optimizer, level=amp_level, master_weight=True) models=model,
optimizers=optimizer,
level=amp_level,
master_weight=True)
else: else:
scaler = None scaler = None
...@@ -174,7 +178,8 @@ def main(config, device, logger, vdl_writer): ...@@ -174,7 +178,8 @@ def main(config, device, logger, vdl_writer):
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
eval_class, pre_best_model_dict, logger, vdl_writer, scaler,amp_level, amp_custom_black_list) eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
amp_level, amp_custom_black_list)
def test_reader(config, device, logger): def test_reader(config, device, logger):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册