未验证 提交 85aeae71 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #3002 from littletomatodonkey/dyg/add_distillation

add distillation
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: false
distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.0005
warmup_epoch: 5
regularizer:
name: L2
factor: 1.0e-05
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride: [1, 2, 2, 2]
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
Head:
name: CTCHead
fc_decay: 0.00001
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: small
small_stride: [1, 2, 2, 2]
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 48
Head:
name: CTCHead
fc_decay: 0.00001
Loss:
name: CombinedLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_sections: 1
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 8
...@@ -13,28 +13,37 @@ ...@@ -13,28 +13,37 @@
# limitations under the License. # limitations under the License.
import copy import copy
import paddle
import paddle.nn as nn
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
def build_loss(config): # rec loss
# det loss from .rec_ctc_loss import CTCLoss
from .det_db_loss import DBLoss from .rec_att_loss import AttentionLoss
from .det_east_loss import EASTLoss from .rec_srn_loss import SRNLoss
from .det_sast_loss import SASTLoss
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
# rec loss # basic loss function
from .rec_ctc_loss import CTCLoss from .basic_loss import DistanceLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
# cls loss # combined loss function
from .cls_loss import ClsLoss from .combined_loss import CombinedLoss
# e2e loss
from .e2e_pg_loss import PGLoss def build_loss(config):
support_dict = [ support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss', 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss'] 'SRNLoss', 'PGLoss', 'CombinedLoss'
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format( assert module_name in support_dict, Exception('loss only support {}'.format(
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer):
def __init__(self, epsilon=None):
super().__init__()
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, x, label):
loss_dict = {}
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
def __init__(self, act=None):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
if act == "softmax":
self.act = nn.Softmax(axis=-1)
elif act == "sigmoid":
self.act = nn.Sigmoid()
else:
self.act = None
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
return loss
class DistanceLoss(nn.Layer):
"""
DistanceLoss:
mode: loss mode
"""
def __init__(self, mode="l2", **kargs):
super().__init__()
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l2":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
def forward(self, x, y):
return self.loss_func(x, y)
...@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer): ...@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
super(ClsLoss, self).__init__() super(ClsLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean') self.loss_func = nn.CrossEntropyLoss(reduction='mean')
def __call__(self, predicts, batch): def forward(self, predicts, batch):
label = batch[1] label = batch[1]
loss = self.loss_func(input=predicts, label=label) loss = self.loss_func(input=predicts, label=label)
return {'loss': loss} return {'loss': loss}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss
class CombinedLoss(nn.Layer):
"""
CombinedLoss:
a combionation of loss function
"""
def __init__(self, loss_config_list=None):
super().__init__()
self.loss_func = []
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
for config in loss_config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
param = config[name]
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs):
loss_dict = {}
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
loss = {
"{}_{}".format(key, idx): loss[key] * weight
for key in loss
}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
return loss_dict
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
class DistillationDMLLoss(DMLLoss):
"""
"""
def __init__(self, model_name_pairs=[], act=None, key=None,
name="loss_dml"):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict
class DistillationCTCLoss(CTCLoss):
def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
super().__init__()
self.model_name_list = model_name_list
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
class DistillationDistanceLoss(DistanceLoss):
"""
"""
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
key]
else:
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
...@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer): ...@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
super(CTCLoss, self).__init__() super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none') self.loss_func = nn.CTCLoss(blank=0, reduction='none')
def __call__(self, predicts, batch): def forward(self, predicts, batch):
predicts = predicts.transpose((1, 0, 2)) predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64') preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
......
...@@ -19,20 +19,23 @@ from __future__ import unicode_literals ...@@ -19,20 +19,23 @@ from __future__ import unicode_literals
import copy import copy
__all__ = ['build_metric'] __all__ = ["build_metric"]
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
from .distillation_metric import DistillationMetric
def build_metric(config):
from .det_metric import DetMetric
from .rec_metric import RecMetric
from .cls_metric import ClsMetric
from .e2e_metric import E2EMetric
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric'] def build_metric(config):
support_dict = [
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop("name")
assert module_name in support_dict, Exception( assert module_name in support_dict, Exception(
'metric only support {}'.format(support_dict)) "metric only support {}".format(support_dict))
module_class = eval(module_name)(**config) module_class = eval(module_name)(**config)
return module_class return module_class
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import copy
from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric
class DistillationMetric(object):
def __init__(self,
key=None,
base_metric_name="RecMetric",
main_indicator='acc',
**kwargs):
self.main_indicator = main_indicator
self.key = key
self.main_indicator = main_indicator
self.base_metric_name = base_metric_name
self.kwargs = kwargs
self.metrics = None
def _init_metrcis(self, preds):
self.metrics = dict()
mod = importlib.import_module(__name__)
for key in preds:
self.metrics[key] = getattr(mod, self.base_metric_name)(
main_indicator=self.main_indicator, **self.kwargs)
self.metrics[key].reset()
def __call__(self, preds, *args, **kwargs):
assert isinstance(preds, dict)
if self.metrics is None:
self._init_metrcis(preds)
output = dict()
for key in preds:
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def get_metric(self):
"""
return metrics {
'acc': 0,
'norm_edit_dis': 0,
}
"""
output = dict()
for key in self.metrics:
metric = self.metrics[key].get_metric()
# main indicator
if key == self.key:
output.update(metric)
else:
for sub_key in metric:
output["{}_{}".format(key, sub_key)] = metric[sub_key]
return output
def reset(self):
for key in self.metrics:
self.metrics[key].reset()
...@@ -13,12 +13,20 @@ ...@@ -13,12 +13,20 @@
# limitations under the License. # limitations under the License.
import copy import copy
import importlib
from .base_model import BaseModel
from .distillation_model import DistillationModel
__all__ = ['build_model'] __all__ = ['build_model']
def build_model(config): def build_model(config):
from .base_model import BaseModel
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_class = BaseModel(config) if not "name" in config:
return module_class arch = BaseModel(config)
\ No newline at end of file else:
name = config.pop("name")
mod = importlib.import_module(__name__)
arch = getattr(mod, name)(config)
return arch
...@@ -32,7 +32,6 @@ class BaseModel(nn.Layer): ...@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
config (dict): the super parameters for module. config (dict): the super parameters for module.
""" """
super(BaseModel, self).__init__() super(BaseModel, self).__init__()
in_channels = config.get('in_channels', 3) in_channels = config.get('in_channels', 3)
model_type = config['model_type'] model_type = config['model_type']
# build transfrom, # build transfrom,
...@@ -68,14 +67,23 @@ class BaseModel(nn.Layer): ...@@ -68,14 +67,23 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"]) self.head = build_head(config["Head"])
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None): def forward(self, x, data=None):
y = dict()
if self.use_transform: if self.use_transform:
x = self.transform(x) x = self.transform(x)
x = self.backbone(x) x = self.backbone(x)
y["backbone_out"] = x
if self.use_neck: if self.use_neck:
x = self.neck(x) x = self.neck(x)
y["neck_out"] = x
if data is None: if data is None:
x = self.head(x) x = self.head(x)
else: else:
x = self.head(x, data) x = self.head(x, data)
return x y["head_out"] = x
if self.return_all_feats:
return y
else:
return x
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model
__all__ = ['DistillationModel']
class DistillationModel(nn.Layer):
def __init__(self, config):
"""
the module for OCR distillation.
args:
config (dict): the super parameters for module.
"""
super().__init__()
self.model_list = []
self.model_name_list = []
for key in config["Models"]:
model_config = config["Models"][key]
freeze_params = False
pretrained = None
if "freeze_params" in model_config:
freeze_params = model_config.pop("freeze_params")
if "pretrained" in model_config:
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
init_model(model, path=pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
self.model_list.append(self.add_sublayer(key, model))
self.model_name_list.append(key)
def forward(self, x):
result_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
result_dict[model_name] = self.model_list[idx](x)
return result_dict
...@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer): ...@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
padding=1, padding=1,
groups=1, groups=1,
if_act=True, if_act=True,
act='hardswish', act='hardswish')
name='conv1')
self.stages = [] self.stages = []
self.out_channels = [] self.out_channels = []
...@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer): ...@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
kernel_size=k, kernel_size=k,
stride=s, stride=s,
use_se=se, use_se=se,
act=nl, act=nl))
name="conv" + str(i + 2)))
inplanes = make_divisible(scale * c) inplanes = make_divisible(scale * c)
i += 1 i += 1
block_list.append( block_list.append(
...@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer): ...@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
padding=0, padding=0,
groups=1, groups=1,
if_act=True, if_act=True,
act='hardswish', act='hardswish'))
name='conv_last'))
self.stages.append(nn.Sequential(*block_list)) self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze)) self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages): for i, stage in enumerate(self.stages):
...@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer): ...@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
padding, padding,
groups=1, groups=1,
if_act=True, if_act=True,
act=None, act=None):
name=None):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self.if_act = if_act self.if_act = if_act
self.act = act self.act = act
...@@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer): ...@@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
stride=stride, stride=stride,
padding=padding, padding=padding,
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False) bias_attr=False)
self.bn = nn.BatchNorm( self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
num_channels=out_channels,
act=None,
param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
...@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer): ...@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
kernel_size, kernel_size,
stride, stride,
use_se, use_se,
act=None, act=None):
name=''):
super(ResidualUnit, self).__init__() super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se self.if_se = use_se
...@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer): ...@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
stride=1, stride=1,
padding=0, padding=0,
if_act=True, if_act=True,
act=act, act=act)
name=name + "_expand")
self.bottleneck_conv = ConvBNLayer( self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels, in_channels=mid_channels,
out_channels=mid_channels, out_channels=mid_channels,
...@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer): ...@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
padding=int((kernel_size - 1) // 2), padding=int((kernel_size - 1) // 2),
groups=mid_channels, groups=mid_channels,
if_act=True, if_act=True,
act=act, act=act)
name=name + "_depthwise")
if self.if_se: if self.if_se:
self.mid_se = SEModule(mid_channels, name=name + "_se") self.mid_se = SEModule(mid_channels)
self.linear_conv = ConvBNLayer( self.linear_conv = ConvBNLayer(
in_channels=mid_channels, in_channels=mid_channels,
out_channels=out_channels, out_channels=out_channels,
...@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer): ...@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
stride=1, stride=1,
padding=0, padding=0,
if_act=False, if_act=False,
act=None, act=None)
name=name + "_linear")
def forward(self, inputs): def forward(self, inputs):
x = self.expand_conv(inputs) x = self.expand_conv(inputs)
...@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer): ...@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
class SEModule(nn.Layer): class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4, name=""): def __init__(self, in_channels, reduction=4):
super(SEModule, self).__init__() super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1) self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D( self.conv1 = nn.Conv2D(
...@@ -266,17 +251,13 @@ class SEModule(nn.Layer): ...@@ -266,17 +251,13 @@ class SEModule(nn.Layer):
out_channels=in_channels // reduction, out_channels=in_channels // reduction,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0)
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
self.conv2 = nn.Conv2D( self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction, in_channels=in_channels // reduction,
out_channels=in_channels, out_channels=in_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0)
weight_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
def forward(self, inputs): def forward(self, inputs):
outputs = self.avg_pool(inputs) outputs = self.avg_pool(inputs)
......
...@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer): ...@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
padding=1, padding=1,
groups=1, groups=1,
if_act=True, if_act=True,
act='hardswish', act='hardswish')
name='conv1')
i = 0 i = 0
block_list = [] block_list = []
inplanes = make_divisible(inplanes * scale) inplanes = make_divisible(inplanes * scale)
...@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer): ...@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
kernel_size=k, kernel_size=k,
stride=s, stride=s,
use_se=se, use_se=se,
act=nl, act=nl))
name='conv' + str(i + 2)))
inplanes = make_divisible(scale * c) inplanes = make_divisible(scale * c)
i += 1 i += 1
self.blocks = nn.Sequential(*block_list) self.blocks = nn.Sequential(*block_list)
...@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer): ...@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
padding=0, padding=0,
groups=1, groups=1,
if_act=True, if_act=True,
act='hardswish', act='hardswish')
name='conv_last')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze) self.out_channels = make_divisible(scale * cls_ch_squeeze)
......
...@@ -23,10 +23,10 @@ import paddle.nn.functional as F ...@@ -23,10 +23,10 @@ import paddle.nn.functional as F
from paddle import ParamAttr from paddle import ParamAttr
def get_bias_attr(k, name): def get_bias_attr(k):
stdv = 1.0 / math.sqrt(k * 1.0) stdv = 1.0 / math.sqrt(k * 1.0)
initializer = paddle.nn.initializer.Uniform(-stdv, stdv) initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr") bias_attr = ParamAttr(initializer=initializer)
return bias_attr return bias_attr
...@@ -38,18 +38,14 @@ class Head(nn.Layer): ...@@ -38,18 +38,14 @@ class Head(nn.Layer):
out_channels=in_channels // 4, out_channels=in_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr(name=name_list[0] + '.w_0'), weight_attr=ParamAttr(),
bias_attr=False) bias_attr=False)
self.conv_bn1 = nn.BatchNorm( self.conv_bn1 = nn.BatchNorm(
num_channels=in_channels // 4, num_channels=in_channels // 4,
param_attr=ParamAttr( param_attr=ParamAttr(
name=name_list[1] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)), initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr( bias_attr=ParamAttr(
name=name_list[1] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)), initializer=paddle.nn.initializer.Constant(value=1e-4)),
moving_mean_name=name_list[1] + '.w_1',
moving_variance_name=name_list[1] + '.w_2',
act='relu') act='relu')
self.conv2 = nn.Conv2DTranspose( self.conv2 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
...@@ -57,19 +53,14 @@ class Head(nn.Layer): ...@@ -57,19 +53,14 @@ class Head(nn.Layer):
kernel_size=2, kernel_size=2,
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
name=name_list[2] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2")) bias_attr=get_bias_attr(in_channels // 4))
self.conv_bn2 = nn.BatchNorm( self.conv_bn2 = nn.BatchNorm(
num_channels=in_channels // 4, num_channels=in_channels // 4,
param_attr=ParamAttr( param_attr=ParamAttr(
name=name_list[3] + '.w_0',
initializer=paddle.nn.initializer.Constant(value=1.0)), initializer=paddle.nn.initializer.Constant(value=1.0)),
bias_attr=ParamAttr( bias_attr=ParamAttr(
name=name_list[3] + '.b_0',
initializer=paddle.nn.initializer.Constant(value=1e-4)), initializer=paddle.nn.initializer.Constant(value=1e-4)),
moving_mean_name=name_list[3] + '.w_1',
moving_variance_name=name_list[3] + '.w_2',
act="relu") act="relu")
self.conv3 = nn.Conv2DTranspose( self.conv3 = nn.Conv2DTranspose(
in_channels=in_channels // 4, in_channels=in_channels // 4,
...@@ -77,10 +68,8 @@ class Head(nn.Layer): ...@@ -77,10 +68,8 @@ class Head(nn.Layer):
kernel_size=2, kernel_size=2,
stride=2, stride=2,
weight_attr=ParamAttr( weight_attr=ParamAttr(
name=name_list[4] + '.w_0',
initializer=paddle.nn.initializer.KaimingUniform()), initializer=paddle.nn.initializer.KaimingUniform()),
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), bias_attr=get_bias_attr(in_channels // 4), )
)
def forward(self, x): def forward(self, x):
x = self.conv1(x) x = self.conv1(x)
......
...@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn ...@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
from paddle.nn import functional as F from paddle.nn import functional as F
def get_para_bias_attr(l2_decay, k, name): def get_para_bias_attr(l2_decay, k):
regularizer = paddle.regularizer.L2Decay(l2_decay) regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0) stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv) initializer = nn.initializer.Uniform(-stdv, stdv)
weight_attr = ParamAttr( weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
regularizer=regularizer, initializer=initializer, name=name + "_w_attr") bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
bias_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
return [weight_attr, bias_attr] return [weight_attr, bias_attr]
...@@ -38,13 +36,12 @@ class CTCHead(nn.Layer): ...@@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs): def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
super(CTCHead, self).__init__() super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr( weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc') l2_decay=fc_decay, k=in_channels)
self.fc = nn.Linear( self.fc = nn.Linear(
in_channels, in_channels,
out_channels, out_channels,
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=bias_attr, bias_attr=bias_attr)
name='ctc_fc')
self.out_channels = out_channels self.out_channels = out_channels
def forward(self, x, labels=None): def forward(self, x, labels=None):
......
...@@ -32,61 +32,53 @@ class DBFPN(nn.Layer): ...@@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
in_channels=in_channels[0], in_channels=in_channels[0],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_51.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in3_conv = nn.Conv2D( self.in3_conv = nn.Conv2D(
in_channels=in_channels[1], in_channels=in_channels[1],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_50.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in4_conv = nn.Conv2D( self.in4_conv = nn.Conv2D(
in_channels=in_channels[2], in_channels=in_channels[2],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_49.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.in5_conv = nn.Conv2D( self.in5_conv = nn.Conv2D(
in_channels=in_channels[3], in_channels=in_channels[3],
out_channels=self.out_channels, out_channels=self.out_channels,
kernel_size=1, kernel_size=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_48.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p5_conv = nn.Conv2D( self.p5_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_52.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p4_conv = nn.Conv2D( self.p4_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_53.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p3_conv = nn.Conv2D( self.p3_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_54.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
self.p2_conv = nn.Conv2D( self.p2_conv = nn.Conv2D(
in_channels=self.out_channels, in_channels=self.out_channels,
out_channels=self.out_channels // 4, out_channels=self.out_channels // 4,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
weight_attr=ParamAttr( weight_attr=ParamAttr(initializer=weight_attr),
name='conv2d_55.w_0', initializer=weight_attr),
bias_attr=False) bias_attr=False)
def forward(self, x): def forward(self, x):
......
...@@ -21,18 +21,19 @@ import copy ...@@ -21,18 +21,19 @@ import copy
__all__ = ['build_post_process'] __all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
support_dict = [ support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess' 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
......
...@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode): ...@@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def __call__(self, preds, label=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
output[name] = super().__call__(pred, label=label, *args, **kwargs)
return output
class AttnLabelDecode(BaseRecLabelDecode): class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
...@@ -23,6 +23,8 @@ import six ...@@ -23,6 +23,8 @@ import six
import paddle import paddle
from ppocr.utils.logging import get_logger
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain'] __all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
...@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger): ...@@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path)) raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): def init_model(config, model, optimizer=None, lr_scheduler=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
if load_static_weights:
pre_state_dict = paddle.static.load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
weight_name = weight_name.replace('binarize', '').replace(
'thresh', '') # for DB
if weight_name in pre_state_dict.keys():
# logger.info('Load weight: {}, shape: {}'.format(
# weight_name, pre_state_dict[weight_name].shape))
if 'encoder_rnn' in key:
# delete axis which is 1
pre_state_dict[weight_name] = pre_state_dict[
weight_name].squeeze()
# change axis
if len(pre_state_dict[weight_name].shape) > 1:
pre_state_dict[weight_name] = pre_state_dict[
weight_name].transpose((1, 0))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_state_dict(param_state_dict)
return
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return
def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
""" """
load model from checkpoint or pretrained_model load model from checkpoint or pretrained_model
""" """
logger = get_logger()
global_config = config['Global'] global_config = config['Global']
checkpoints = global_config.get('checkpoints') checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model') pretrained_model = global_config.get('pretrained_model')
...@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): ...@@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {}) best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict: if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1 best_model_dict['start_epoch'] = states_dict['epoch'] + 1
logger.info("resume from {}".format(checkpoints)) logger.info("resume from {}".format(checkpoints))
elif pretrained_model: elif pretrained_model:
load_static_weights = global_config.get('load_static_weights', False)
if not isinstance(pretrained_model, list): if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model] pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list): for pretrained in pretrained_model:
load_static_weights = [load_static_weights] * len(pretrained_model) if not (os.path.isdir(pretrained) or
for idx, pretrained in enumerate(pretrained_model): os.path.exists(pretrained + '.pdparams')):
load_static = load_static_weights[idx] raise ValueError("Model pretrain path {} does not "
load_dygraph_pretrain( "exists.".format(pretrained))
model, logger, path=pretrained, load_static_weights=load_static) param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format( logger.info("load pretrained model from {}".format(
pretrained_model)) pretrained_model))
else: else:
......
...@@ -49,7 +49,7 @@ def main(): ...@@ -49,7 +49,7 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
best_model_dict = init_model(config, model, logger) best_model_dict = init_model(config, model)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items(): for k, v in best_model_dict.items():
......
...@@ -17,7 +17,7 @@ import sys ...@@ -17,7 +17,7 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, "..")))
import argparse import argparse
...@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger ...@@ -31,32 +31,12 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser from tools.program import load_config, merge_config, ArgsParser
def main(): def export_single_model(model, arch_config, save_path, logger):
FLAGS = ArgsParser().parse_args() if arch_config["algorithm"] == "SRN":
config = load_config(FLAGS.config) max_text_length = arch_config["Head"]["max_text_length"]
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config['PostProcess'],
config['Global'])
# build model
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
init_model(config, model, logger)
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
if config['Architecture']['algorithm'] == "SRN":
max_text_length = config['Architecture']['Head']['max_text_length']
other_shape = [ other_shape = [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 1, 64, 256], dtype='float32'), [ shape=[None, 1, 64, 256], dtype="float32"), [
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None, 256, 1], shape=[None, 256, 1],
dtype="int64"), paddle.static.InputSpec( dtype="int64"), paddle.static.InputSpec(
...@@ -71,24 +51,66 @@ def main(): ...@@ -71,24 +51,66 @@ def main():
model = to_static(model, input_spec=other_shape) model = to_static(model, input_spec=other_shape)
else: else:
infer_shape = [3, -1, -1] infer_shape = [3, -1, -1]
if config['Architecture']['model_type'] == "rec": if arch_config["model_type"] == "rec":
infer_shape = [3, 32, -1] # for rec model, H must be 32 infer_shape = [3, 32, -1] # for rec model, H must be 32
if 'Transform' in config['Architecture'] and config['Architecture'][ if "Transform" in arch_config and arch_config[
'Transform'] is not None and config['Architecture'][ "Transform"] is not None and arch_config["Transform"][
'Transform']['name'] == 'TPS': "name"] == "TPS":
logger.info( logger.info(
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training' "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
) )
infer_shape[-1] = 100 infer_shape[-1] = 100
model = to_static( model = to_static(
model, model,
input_spec=[ input_spec=[
paddle.static.InputSpec( paddle.static.InputSpec(
shape=[None] + infer_shape, dtype='float32') shape=[None] + infer_shape, dtype="float32")
]) ])
paddle.jit.save(model, save_path) paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path)) logger.info("inference model is saved to {}".format(save_path))
return
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
logger = get_logger()
# build post process
post_process_class = build_post_process(config["PostProcess"],
config["Global"])
# build model
# for rec algorithm
if hasattr(post_process_class, "character"):
char_num = len(getattr(post_process_class, "character"))
if config["Architecture"]["algorithm"] in ["Distillation",
]: # distillation model
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model)
model.eval()
save_path = config["Global"]["save_inference_dir"]
arch_config = config["Architecture"]
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list):
sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(model, arch_config, save_path, logger)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -47,7 +47,7 @@ def main(): ...@@ -47,7 +47,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
......
...@@ -61,7 +61,7 @@ def main(): ...@@ -61,7 +61,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess']) post_process_class = build_post_process(config['PostProcess'])
......
...@@ -68,7 +68,7 @@ def main(): ...@@ -68,7 +68,7 @@ def main():
# build model # build model
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import os import os
import sys import sys
import json
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
...@@ -46,12 +47,18 @@ def main(): ...@@ -46,12 +47,18 @@ def main():
# build model # build model
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
config['Architecture']["Head"]['out_channels'] = len( char_num = len(getattr(post_process_class, 'character'))
getattr(post_process_class, 'character')) if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
init_model(config, model, logger) init_model(config, model)
# create data ops # create data ops
transforms = [] transforms = []
...@@ -107,11 +114,23 @@ def main(): ...@@ -107,11 +114,23 @@ def main():
else: else:
preds = model(images) preds = model(images)
post_result = post_process_class(preds) post_result = post_process_class(preds)
for rec_reuslt in post_result: info = None
logger.info('\t result: {}'.format(rec_reuslt)) if isinstance(post_result, dict):
if len(rec_reuslt) >= 2: rec_info = dict()
fout.write(file + "\t" + rec_reuslt[0] + "\t" + str( for key in post_result:
rec_reuslt[1]) + "\n") if len(post_result[key][0]) >= 2:
rec_info[key] = {
"label": post_result[key][0][0],
"score": post_result[key][0][1],
}
info = json.dumps(rec_info)
else:
if len(post_result[0]) >= 2:
info = post_result[0][0] + "\t" + str(post_result[0][1])
if info is not None:
logger.info("\t result: {}".format(info))
fout.write(file + "\t" + info)
logger.info("success!") logger.info("success!")
......
...@@ -386,7 +386,7 @@ def preprocess(is_train=False): ...@@ -386,7 +386,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm'] alg = config['Architecture']['algorithm']
assert alg in [ assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet' 'CLS', 'PGNet', 'Distillation'
] ]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
...@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer): ...@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
if config['Global']['distributed']: if config['Global']['distributed']:
model = paddle.DataParallel(model) model = paddle.DataParallel(model)
...@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer): ...@@ -90,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, optimizer)
logger.info('train dataloader has {} iters'.format(len(train_dataloader))) logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None: if valid_dataloader is not None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册