提交 ed02b91d 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add distillation function

上级 551a6827
......@@ -13,28 +13,37 @@
# limitations under the License.
import copy
import paddle
import paddle.nn as nn
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
def build_loss(config):
# det loss
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
# cls loss
from .cls_loss import ClsLoss
# e2e loss
from .e2e_pg_loss import PGLoss
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
# basic loss function
from .basic_loss import DistanceLoss
# cls loss
from .cls_loss import ClsLoss
# combined loss function
from .combined_loss import CombinedLoss
# e2e loss
from .e2e_pg_loss import PGLoss
def build_loss(config):
support_dict = [
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
'SRNLoss', 'PGLoss']
'SRNLoss', 'PGLoss', 'CombinedLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
assert module_name in support_dict, Exception('loss only support {}'.format(
......
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import L1Loss
from paddle.nn import MSELoss as L2Loss
from paddle.nn import SmoothL1Loss
class CELoss(nn.Layer):
def __init__(self, name="loss_ce", epsilon=None):
super().__init__()
self.name = name
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
epsilon = None
self.epsilon = epsilon
def _labelsmoothing(self, target, class_num):
if target.shape[-1] != class_num:
one_hot_target = F.one_hot(target, class_num)
else:
one_hot_target = target
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
return soft_target
def forward(self, x, label):
loss_dict = {}
if self.epsilon is not None:
class_num = x.shape[-1]
label = self._labelsmoothing(label, class_num)
x = -F.log_softmax(x, axis=-1)
loss = paddle.sum(x * label, axis=-1)
else:
if label.shape[-1] == x.shape[-1]:
label = F.softmax(label, axis=-1)
soft_label = True
else:
soft_label = False
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
loss_dict[self.name] = paddle.mean(loss)
return loss_dict
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
def __init__(self, name="loss_dml"):
super().__init__()
self.name = name
def forward(self, out1, out2):
loss_dict = {}
soft_out1 = F.softmax(out1, axis=-1)
log_soft_out1 = paddle.log(soft_out1)
soft_out2 = F.softmax(out2, axis=-1)
log_soft_out2 = paddle.log(soft_out2)
loss = (F.kl_div(
log_soft_out1, soft_out2, reduction='batchmean') + F.kl_div(
log_soft_out2, soft_out1, reduction='batchmean')) / 2.0
loss_dict[self.name] = loss
return loss_dict
class DistanceLoss(nn.Layer):
"""
DistanceLoss:
mode: loss mode
name: loss key in the output dict
"""
def __init__(self, mode="l2", name="loss_dist", **kargs):
assert mode in ["l1", "l2", "smooth_l1"]
if mode == "l1":
self.loss_func = nn.L1Loss(**kargs)
elif mode == "l1":
self.loss_func = nn.MSELoss(**kargs)
elif mode == "smooth_l1":
self.loss_func = nn.SmoothL1Loss(**kargs)
self.name = "{}_{}".format(name, mode)
def forward(self, x, y):
return {self.name: self.loss_func(x, y)}
......@@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
super(ClsLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
def __call__(self, predicts, batch):
def forward(self, predicts, batch):
label = batch[1]
loss = self.loss_func(input=predicts, label=label)
return {'loss': loss}
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
class CombinedLoss(nn.Layer):
"""
CombinedLoss:
a combionation of loss function
"""
def __init__(self, loss_config_list=None):
super().__init__()
self.loss_func = []
self.loss_weight = []
assert isinstance(loss_config_list, list), (
'operator config should be a list')
for config in loss_config_list:
assert isinstance(config,
dict) and len(config) == 1, "yaml format error"
name = list(config)[0]
param = config[name]
assert "weight" in param, "weight must be in param, but param just contains {}".format(
param.keys())
self.loss_weight.append(param.pop("weight"))
self.loss_func.append(eval(name)(**param))
def forward(self, input, batch, **kargs):
loss_dict = {}
for idx, loss_func in enumerate(self.loss_func):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx]
loss = {
"{}_{}".format(key, idx): loss[key] * weight
for key in loss
}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
return loss_dict
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import paddle
import paddle.nn as nn
from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
class DistillationDMLLoss(DMLLoss):
"""
"""
def __init__(self,
model_name_list1=[],
model_name_list2=[],
key=None,
name="loss_dml"):
super().__init__(name=name)
if not isinstance(model_name_list1, (list, )):
model_name_list1 = [model_name_list1]
if not isinstance(model_name_list2, (list, )):
model_name_list2 = [model_name_list2]
assert len(model_name_list1) == len(model_name_list2)
self.model_name_list1 = model_name_list1
self.model_name_list2 = model_name_list2
self.key = key
def forward(self, predicts, batch):
loss_dict = dict()
for idx in range(len(self.model_name_list1)):
out1 = predicts[self.model_name_list1[idx]]
out2 = predicts[self.model_name_list2[idx]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
assert len(loss) == 1
loss = list(loss.values())[0]
loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict
class DistillationCTCLoss(CTCLoss):
def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
super().__init__()
self.model_name_list = model_name_list
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for model_name in self.model_name_list:
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
if isinstance(loss, dict):
assert len(loss) == 1
loss = list(loss.values())[0]
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
......@@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
def __call__(self, predicts, batch):
def forward(self, predicts, batch):
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
......
......@@ -13,12 +13,20 @@
# limitations under the License.
import copy
import importlib
from .base_model import BaseModel
from .distillation_model import DistillationModel
__all__ = ['build_model']
def build_model(config):
from .base_model import BaseModel
config = copy.deepcopy(config)
module_class = BaseModel(config)
return module_class
\ No newline at end of file
if not "name" in config:
arch = BaseModel(config)
else:
name = config.pop("name")
mod = importlib.import_module(__name__)
arch = getattr(mod, name)(config)
return arch
......@@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
config (dict): the super parameters for module.
"""
super(BaseModel, self).__init__()
in_channels = config.get('in_channels', 3)
model_type = config['model_type']
# build transfrom,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import load_dygraph_pretrain
__all__ = ['DistillationModel']
class DistillationModel(nn.Layer):
def __init__(self, config):
"""
the module for OCR distillation.
args:
config (dict): the super parameters for module.
"""
super().__init__()
freeze_params = config["freeze_params"]
pretrained = config["pretrained"]
if not isinstance(freeze_params, list):
freeze_params = [freeze_params]
assert len(config["Models"]) == len(freeze_params)
if not isinstance(pretrained, list):
pretrained = [pretrained] * len(config["Models"])
assert len(config["Models"]) == len(pretrained)
self.model_dict = dict()
index = 0
for key in config["Models"]:
model_config = config["Models"][key]
model = BaseModel(model_config)
if pretrained[index] is not None:
load_dygraph_pretrain(model, path=pretrained[index])
if freeze_params[index]:
for param in model.parameters():
param.trainable = False
self.model_dict[key] = self.add_sublayer(key, model)
index += 1
def forward(self, x):
result_dict = dict()
for key in self.model_dict:
result_dict[key] = self.model_dict[key](x)
return result_dict
......@@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
padding=1,
groups=1,
if_act=True,
act='hardswish',
name='conv1')
act='hardswish')
self.stages = []
self.out_channels = []
......@@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name="conv" + str(i + 2)))
act=nl))
inplanes = make_divisible(scale * c)
i += 1
block_list.append(
......@@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
padding=0,
groups=1,
if_act=True,
act='hardswish',
name='conv_last'))
act='hardswish'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
for i, stage in enumerate(self.stages):
......@@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
padding,
groups=1,
if_act=True,
act=None,
name=None):
act=None):
super(ConvBNLayer, self).__init__()
self.if_act = if_act
self.act = act
......@@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(name=name + '_weights'),
bias_attr=False)
self.bn = nn.BatchNorm(
num_channels=out_channels,
act=None,
param_attr=ParamAttr(name=name + "_bn_scale"),
bias_attr=ParamAttr(name=name + "_bn_offset"),
moving_mean_name=name + "_bn_mean",
moving_variance_name=name + "_bn_variance")
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
def forward(self, x):
x = self.conv(x)
......@@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
kernel_size,
stride,
use_se,
act=None,
name=''):
act=None):
super(ResidualUnit, self).__init__()
self.if_shortcut = stride == 1 and in_channels == out_channels
self.if_se = use_se
......@@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
stride=1,
padding=0,
if_act=True,
act=act,
name=name + "_expand")
act=act)
self.bottleneck_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=mid_channels,
......@@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
padding=int((kernel_size - 1) // 2),
groups=mid_channels,
if_act=True,
act=act,
name=name + "_depthwise")
act=act)
if self.if_se:
self.mid_se = SEModule(mid_channels, name=name + "_se")
self.mid_se = SEModule(mid_channels)
self.linear_conv = ConvBNLayer(
in_channels=mid_channels,
out_channels=out_channels,
......@@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
stride=1,
padding=0,
if_act=False,
act=None,
name=name + "_linear")
act=None)
def forward(self, inputs):
x = self.expand_conv(inputs)
......@@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
class SEModule(nn.Layer):
def __init__(self, in_channels, reduction=4, name=""):
def __init__(self, in_channels, reduction=4):
super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2D(1)
self.conv1 = nn.Conv2D(
......@@ -266,17 +251,13 @@ class SEModule(nn.Layer):
out_channels=in_channels // reduction,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name=name + "_1_weights"),
bias_attr=ParamAttr(name=name + "_1_offset"))
padding=0)
self.conv2 = nn.Conv2D(
in_channels=in_channels // reduction,
out_channels=in_channels,
kernel_size=1,
stride=1,
padding=0,
weight_attr=ParamAttr(name + "_2_weights"),
bias_attr=ParamAttr(name=name + "_2_offset"))
padding=0)
def forward(self, inputs):
outputs = self.avg_pool(inputs)
......
......@@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
padding=1,
groups=1,
if_act=True,
act='hardswish',
name='conv1')
act='hardswish')
i = 0
block_list = []
inplanes = make_divisible(inplanes * scale)
......@@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
kernel_size=k,
stride=s,
use_se=se,
act=nl,
name='conv' + str(i + 2)))
act=nl))
inplanes = make_divisible(scale * c)
i += 1
self.blocks = nn.Sequential(*block_list)
......@@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
padding=0,
groups=1,
if_act=True,
act='hardswish',
name='conv_last')
act='hardswish')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
self.out_channels = make_divisible(scale * cls_ch_squeeze)
......
......@@ -23,14 +23,12 @@ from paddle import ParamAttr, nn
from paddle.nn import functional as F
def get_para_bias_attr(l2_decay, k, name):
def get_para_bias_attr(l2_decay, k):
regularizer = paddle.regularizer.L2Decay(l2_decay)
stdv = 1.0 / math.sqrt(k * 1.0)
initializer = nn.initializer.Uniform(-stdv, stdv)
weight_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
bias_attr = ParamAttr(
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
return [weight_attr, bias_attr]
......@@ -38,13 +36,12 @@ class CTCHead(nn.Layer):
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
super(CTCHead, self).__init__()
weight_attr, bias_attr = get_para_bias_attr(
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
l2_decay=fc_decay, k=in_channels)
self.fc = nn.Linear(
in_channels,
out_channels,
weight_attr=weight_attr,
bias_attr=bias_attr,
name='ctc_fc')
bias_attr=bias_attr)
self.out_channels = out_channels
def forward(self, x, labels=None):
......
......@@ -21,18 +21,19 @@ import copy
__all__ = ['build_post_process']
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
'DistillationCTCLabelDecode'
]
config = copy.deepcopy(config)
......
......@@ -125,6 +125,31 @@ class CTCLabelDecode(BaseRecLabelDecode):
return dict_character
class DistillationCTCLabelDecode(CTCLabelDecode):
"""
Convert
Convert between text-label and text-index
"""
def __init__(self,
character_dict_path=None,
character_type='ch',
use_space_char=False,
model_name="student",
key_out=None,
**kwargs):
super(DistillationCTCLabelDecode, self).__init__(
character_dict_path, character_type, use_space_char)
self.model_name = model_name
self.key_out = key_out
def __call__(self, preds, label=None, *args, **kwargs):
pred = preds[self.model_name]
if self.key_out is not None:
pred = pred[self.key_out]
return super().__call__(pred, label=label, *args, **kwargs)
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......
......@@ -42,7 +42,10 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
def load_dygraph_pretrain(model,
logger=None,
path=None,
load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
......
......@@ -386,7 +386,7 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet'
'CLS', 'PGNet', 'Distillation'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
......
......@@ -72,7 +72,14 @@ def main(config, device, logger, vdl_writer):
# for rec algorithm
if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num
if config['Architecture']["algorithm"] in ["Distillation",
]: # distillation model
for key in config['Architecture']["Models"]:
config['Architecture']["Models"][key]["Head"][
'out_channels'] = char_num
else: # base rec model
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
if config['Global']['distributed']:
model = paddle.DataParallel(model)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册