提交 f8755565 编写于 作者: D dorren

update can transform method and add copyright info for new file

上级 c57effb8
......@@ -42,7 +42,6 @@ Architecture:
bottleneck: True
use_dropout: True
input_channel: 1
Head:
name: CANHead
in_channel: 684
......@@ -66,8 +65,8 @@ Loss:
name: CANLoss
PostProcess:
name: SeqLabelDecode
character: 111
name: CANLabelDecode
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
Metric:
name: CANMetric
......@@ -75,15 +74,18 @@ Metric:
Train:
dataset:
name: PGDataSet
name: SimpleDataSet
data_dir: ./train_data/CROHME/training/images/
transforms:
- DecodeImage:
channel_first: False
- NormalizeImage:
mean: [0,0,0]
std: [1,1,1]
order: 'hwc'
- GrayImageChannelFormat:
normalize: True
inverse: True
- SeqLabelEncode:
- CANLabelEncode:
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
lower: False
- KeepKeys:
......@@ -98,15 +100,18 @@ Train:
Eval:
dataset:
name: PGDataSet
name: SimpleDataSet
data_dir: ./train_data/CROHME/evaluation/images/
transforms:
- DecodeImage:
channel_first: False
- NormalizeImage:
mean: [0,0,0]
std: [1,1,1]
order: 'hwc'
- GrayImageChannelFormat:
normalize: True
inverse: True
- SeqLabelEncode:
- CANLabelEncode:
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
lower: False
- KeepKeys:
......
......@@ -27,7 +27,7 @@ from .make_pse_gt import MakePseGt
from .rec_img_aug import BaseDataAugmentation, RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
SRNRecResizeImg, GrayRecResizeImg, SARRecResizeImg, PRENResizeImg, \
ABINetRecResizeImg, SVTRRecResizeImg, ABINetRecAug, VLRecResizeImg, SPINRecResizeImg, RobustScannerRecResizeImg, \
RFLRecResizeImg, GrayImageChannelFormat
RFLRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
......
......@@ -1479,14 +1479,14 @@ class CTLabelEncode(object):
return data
class SeqLabelEncode(BaseRecLabelEncode):
class CANLabelEncode(BaseRecLabelEncode):
def __init__(self,
character_dict_path,
max_text_length=100,
use_space_char=False,
lower=True,
**kwargs):
super(SeqLabelEncode, self).__init__(
super(CANLabelEncode, self).__init__(
max_text_length, character_dict_path, use_space_char, lower)
def encode(self, text_seq):
......
......@@ -498,3 +498,27 @@ class ResizeNormalize(object):
img_numpy = np.array(img).astype("float32")
img_numpy = img_numpy.transpose((2, 0, 1)) / 255
return img_numpy
class GrayImageChannelFormat(object):
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
inverse: inverse gray image
"""
def __init__(self, inverse=False, **kwargs):
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_expanded = np.expand_dims(img_single_channel, 0)
if self.inverse:
data['image'] = np.abs(img_expanded - 1)
else:
data['image'] = img_expanded
data['src_image'] = img
return data
\ No newline at end of file
......@@ -465,36 +465,6 @@ class RobustScannerRecResizeImg(object):
return data
class GrayImageChannelFormat(object):
"""
format gray scale image's channel: (3,h,w) -> (1,h,w)
Args:
normalize: True/False
when True convert image dynamic range [0,255]->[0,1]
inverse: inverse gray image
"""
def __init__(self, normalize=True, inverse=False, **kwargs):
self.normalize = normalize
self.inverse = inverse
def __call__(self, data):
img = data['image']
img_single_channel = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_single_channel = np.expand_dims(img_single_channel, 0)
if self.normalize:
img_single_channel = img_single_channel / 255.0
if self.inverse:
data['image'] = np.abs(img_single_channel - 1).astype('float32')
else:
data['image'] = img_single_channel.astype('float32')
data['src_image'] = img
return data
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
......
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/LBH1024/CAN/models/can.py
"""
import paddle
import paddle.nn as nn
import numpy as np
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import paddle.nn as nn
......@@ -5,14 +23,6 @@ import paddle.nn.functional as F
class Bottleneck(nn.Layer):
'''
ratio: 16
growthRate: 24
reduction: 0.5
bottleneck: True
use_dropout: True
'''
def __init__(self, nChannels, growthRate, use_dropout):
super(Bottleneck, self).__init__()
interChannels = 4 * growthRate
......@@ -78,11 +88,7 @@ class DenseNet(nn.Layer):
def __init__(self, growthRate, reduction, bottleneck, use_dropout,
input_channel, **kwargs):
super(DenseNet, self).__init__()
'''
ratio: 16
growthRate: 24
reduction: 0.5
'''
nDenseBlocks = 16
nChannels = 2 * growthRate
......
from turtle import forward
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/LBH1024/CAN/models/can.py
https://github.com/LBH1024/CAN/models/counting.py
https://github.com/LBH1024/CAN/models/decoder.py
https://github.com/LBH1024/CAN/models/attention.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle.nn as nn
import paddle
import math
......
......@@ -37,7 +37,7 @@ from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
from .picodet_postprocess import PicoDetPostProcess
from .ct_postprocess import CTPostProcess
from .drrg_postprocess import DRRGPostprocess
from .rec_postprocess import SeqLabelDecode
from .rec_postprocess import CANLabelDecode
def build_post_process(config, global_config=None):
......@@ -52,7 +52,7 @@ def build_post_process(config, global_config=None):
'TableMasterLabelDecode', 'SPINLabelDecode',
'DistillationSerPostProcess', 'DistillationRePostProcess',
'VLLabelDecode', 'PicoDetPostProcess', 'CTPostProcess',
'RFLLabelDecode', 'DRRGPostprocess', 'SeqLabelDecode'
'RFLLabelDecode', 'DRRGPostprocess', 'CANLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -899,12 +899,12 @@ class VLLabelDecode(BaseRecLabelDecode):
return text, label
class SeqLabelDecode(BaseRecLabelDecode):
class CANLabelDecode(BaseRecLabelDecode):
""" Convert between latex-symbol and symbol-index """
def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SeqLabelDecode, self).__init__(character_dict_path,
super(CANLabelDecode, self).__init__(character_dict_path,
use_space_char)
def decode(self, text_index, preds_prob=None):
......
......@@ -42,7 +42,6 @@ Architecture:
bottleneck: True
use_dropout: True
input_channel: 1
Head:
name: CANHead
in_channel: 684
......@@ -66,8 +65,8 @@ Loss:
name: CANLoss
PostProcess:
name: SeqLabelDecode
character: 111
name: CANLabelDecode
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
Metric:
name: CANMetric
......@@ -75,20 +74,23 @@ Metric:
Train:
dataset:
name: PGDataSet
data_dir: ./train_data/CROHME_lite/training/images/
name: SimpleDataSet
data_dir: ./train_data/CROHME/training/images/
transforms:
- DecodeImage:
channel_first: False
- NormalizeImage:
mean: [0,0,0]
std: [1,1,1]
order: 'hwc'
- GrayImageChannelFormat:
normalize: True
inverse: True
- SeqLabelEncode:
- CANLabelEncode:
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
lower: False
- KeepKeys:
keep_keys: ['image', 'label']
label_file_list: ["./train_data/CROHME_lite/training/labels.txt"]
label_file_list: ["./train_data/CROHME/training/labels.txt"]
loader:
shuffle: True
batch_size_per_card: 8
......@@ -98,20 +100,23 @@ Train:
Eval:
dataset:
name: PGDataSet
data_dir: ./train_data/CROHME_lite/evaluation/images/
name: SimpleDataSet
data_dir: ./train_data/CROHME/evaluation/images/
transforms:
- DecodeImage:
channel_first: False
- NormalizeImage:
mean: [0,0,0]
std: [1,1,1]
order: 'hwc'
- GrayImageChannelFormat:
normalize: True
inverse: True
- SeqLabelEncode:
- CANLabelEncode:
character_dict_path: ppocr/utils/dict/latex_symbol_dict.txt
lower: False
- KeepKeys:
keep_keys: ['image', 'label']
label_file_list: ["./train_data/CROHME_lite/evaluation/labels.txt"]
label_file_list: ["./train_data/CROHME/evaluation/labels.txt"]
loader:
shuffle: False
drop_last: False
......
===========================train_params===========================
model_name:rec_d28_can
python:python3.7
gpu_list:0|0,1
python:python
gpu_list:0|0
Global.use_gpu:True|True
Global.auto_cast:null
Global.epoch_num:lite_train_lite_infer=2|whole_train_whole_infer=240
......
......@@ -262,7 +262,6 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
cd ./pretrain_models/ && tar xf can_train.tar && cd ../
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/CROHME_lite.tar --no-check-certificate
cd ./train_data/ && tar xf CROHME_lite.tar && cd ../
fi
if [ ${model_name} == "layoutxlm_ser" ]; then
${python_name} -m pip install -r ppstructure/kie/requirements.txt
......
......@@ -111,7 +111,7 @@ class TextRecognizer(object):
elif self.rec_algorithm == "CAN":
self.inverse = args.rec_image_inverse
postprocess_params = {
'name': 'SeqLabelDecode',
'name': 'CANLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册