未验证 提交 7dc56191 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix rec distillation (#3994)

* fix rec distillation

* add dist cfg

* fix yaml
上级 51f4a2c3
...@@ -4,7 +4,7 @@ Global: ...@@ -4,7 +4,7 @@ Global:
epoch_num: 800 epoch_num: 800
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1 save_model_dir: ./output/rec_mobile_pp-OCRv2
save_epoch_step: 3 save_epoch_step: 3
eval_batch_step: [0, 2000] eval_batch_step: [0, 2000]
cal_metric_during_train: true cal_metric_during_train: true
...@@ -19,7 +19,7 @@ Global: ...@@ -19,7 +19,7 @@ Global:
infer_mode: false infer_mode: false
use_space_char: true use_space_char: true
distributed: true distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt
Optimizer: Optimizer:
...@@ -35,16 +35,9 @@ Optimizer: ...@@ -35,16 +35,9 @@ Optimizer:
name: L2 name: L2
factor: 2.0e-05 factor: 2.0e-05
Architecture: Architecture:
model_type: &model_type "rec" model_type: rec
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN algorithm: CRNN
Transform: Transform:
Backbone: Backbone:
...@@ -58,56 +51,16 @@ Architecture: ...@@ -58,56 +51,16 @@ Architecture:
name: CTCHead name: CTCHead
mid_channels: 96 mid_channels: 96
fc_decay: 0.00002 fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Loss: Loss:
name: CombinedLoss name: CTCLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess: PostProcess:
name: DistillationCTCLabelDecode name: CTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
Metric: Metric:
name: DistillationMetric name: RecMetric
base_metric_name: RecMetric
main_indicator: acc main_indicator: acc
key: "Student"
Train: Train:
dataset: dataset:
...@@ -132,7 +85,6 @@ Train: ...@@ -132,7 +85,6 @@ Train:
shuffle: true shuffle: true
batch_size_per_card: 128 batch_size_per_card: 128
drop_last: true drop_last: true
num_sections: 1
num_workers: 8 num_workers: 8
Eval: Eval:
dataset: dataset:
......
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_pp-OCRv2_distillation
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [700, 800]
values : [0.001, 0.0001]
warmup_epoch: 5
regularizer:
name: L2
factor: 2.0e-05
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Loss:
name: CombinedLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_sections: 1
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 8
...@@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要 ...@@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
### 2.1 识别配置文件解析 ### 2.1 识别配置文件解析
配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml) 配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)
#### 2.1.1 模型结构 #### 2.1.1 模型结构
...@@ -246,6 +246,39 @@ Metric: ...@@ -246,6 +246,39 @@ Metric:
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24) 关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)
#### 2.1.5 蒸馏模型微调
对蒸馏得到的识别蒸馏进行微调有2种方式。
(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
* 首先下载预训练模型并解压。
```shell
# 下面预训练模型并解压
wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
tar -xf ch_PP-OCRv2_rec_train.tar
```
* 然后使用python,对其中的学生模型参数进行提取
```python
import paddle
# 加载预训练模型
all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
# 查看权重参数的keys
print(all_params.keys())
# 学生模型的权重提取
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# 查看学生模型权重参数的keys
print(s_params.keys())
# 保存
paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
```
转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
### 2.2 检测配置文件解析 ### 2.2 检测配置文件解析
* coming soon! * coming soon!
...@@ -56,31 +56,34 @@ class CELoss(nn.Layer): ...@@ -56,31 +56,34 @@ class CELoss(nn.Layer):
class KLJSLoss(object): class KLJSLoss(object):
def __init__(self, mode='kl'): def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']" assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode self.mode = mode
def __call__(self, p1, p2, reduction="mean"): def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5)) loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js": if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5)) loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5 loss *= 0.5
if reduction == "mean": if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2]) loss = paddle.mean(loss, axis=[1, 2])
elif reduction=="none" or reduction is None: elif reduction == "none" or reduction is None:
return loss return loss
else: else:
loss = paddle.sum(loss, axis=[1,2]) loss = paddle.sum(loss, axis=[1, 2])
return loss return loss
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
""" """
DMLLoss DMLLoss
""" """
def __init__(self, act=None): def __init__(self, act=None, use_log=False):
super().__init__() super().__init__()
if act is not None: if act is not None:
assert act in ["softmax", "sigmoid"] assert act in ["softmax", "sigmoid"]
...@@ -91,19 +94,23 @@ class DMLLoss(nn.Layer): ...@@ -91,19 +94,23 @@ class DMLLoss(nn.Layer):
else: else:
self.act = None self.act = None
self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js") self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2): def forward(self, out1, out2):
if self.act is not None: if self.act is not None:
out1 = self.act(out1) out1 = self.act(out1)
out2 = self.act(out2) out2 = self.act(out2)
if len(out1.shape) < 2: if self.use_log:
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1) log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2) log_out2 = paddle.log(out2)
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0 log_out2, out1, reduction='batchmean')) / 2.0
else: else:
# for detection distillation log is not needed
loss = self.jskl_loss(out1, out2) loss = self.jskl_loss(out1, out2)
return loss return loss
......
...@@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer): ...@@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
loss = loss_func(input, batch, **kargs) loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor): if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss} loss = {"loss_{}_{}".format(str(loss), idx): loss}
weight = self.loss_weight[idx] weight = self.loss_weight[idx]
for key in loss.keys():
if key == "loss": loss = {key: loss[key] * weight for key in loss}
loss_all += loss[key] * weight
if "loss" in loss:
loss_all += loss["loss"]
else: else:
loss_dict["{}_{}".format(key, idx)] = loss[key] loss_all += paddle.add_n(list(loss.values()))
loss_dict.update(loss)
loss_dict["loss"] = loss_all loss_dict["loss"] = loss_all
return loss_dict return loss_dict
...@@ -44,10 +44,11 @@ class DistillationDMLLoss(DMLLoss): ...@@ -44,10 +44,11 @@ class DistillationDMLLoss(DMLLoss):
def __init__(self, def __init__(self,
model_name_pairs=[], model_name_pairs=[],
act=None, act=None,
use_log=False,
key=None, key=None,
maps_name=None, maps_name=None,
name="dml"): name="dml"):
super().__init__(act=act) super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
...@@ -57,7 +58,8 @@ class DistillationDMLLoss(DMLLoss): ...@@ -57,7 +58,8 @@ class DistillationDMLLoss(DMLLoss):
def _check_model_name_pairs(self, model_name_pairs): def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list): if not isinstance(model_name_pairs, list):
return [] return []
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str): elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs return model_name_pairs
else: else:
return [model_name_pairs] return [model_name_pairs]
...@@ -112,8 +114,8 @@ class DistillationDMLLoss(DMLLoss): ...@@ -112,8 +114,8 @@ class DistillationDMLLoss(DMLLoss):
loss_dict["{}_{}_{}_{}_{}".format(key, pair[ loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key] 0], pair[1], map_name, idx)] = loss[key]
else: else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c], loss_dict["{}_{}_{}".format(self.name, self.maps_name[
idx)] = loss _c], idx)] = loss
loss_dict = _sum_loss(loss_dict) loss_dict = _sum_loss(loss_dict)
......
...@@ -116,6 +116,7 @@ def load_dygraph_params(config, model, logger, optimizer): ...@@ -116,6 +116,7 @@ def load_dygraph_params(config, model, logger, optimizer):
logger.info(f"loaded pretrained_model successful from {pm}") logger.info(f"loaded pretrained_model successful from {pm}")
return {} return {}
def load_pretrained_params(model, path): def load_pretrained_params(model, path):
if path is None: if path is None:
return False return False
...@@ -138,6 +139,7 @@ def load_pretrained_params(model, path): ...@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print(f"load pretrain successful from {path}") print(f"load pretrain successful from {path}")
return model return model
def save_model(model, def save_model(model,
optimizer, optimizer,
model_path, model_path,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册