Unverified commit 1643f268 authored by Double_V, committed by GitHub

add V4 rec distill (#9921)

* support min_area_rect crop

* add check_install

* fix requirement.txt

* fix check_install

* add lanms-neo for drrg

* fix

* fix doc

* fix

* support set gpu_id when inference

* fix #8855

* fix #8855

* opt slim doc

* fix doc bug

* add v4_rec_distill config

* delete debug

* fix comment

* fix comment
Parent 42516643

New config file added by this commit (PP-OCRv4 recognition distillation):
Global:
debug: false
use_gpu: true
epoch_num: 200
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_dkd_400w_svtr_ctc_lcnet_blank_dkd0.1/
save_epoch_step: 40
eval_batch_step:
- 0
- 2000
cal_metric_during_train: true
pretrained_model: null
checkpoints: ./output/rec_dkd_400w_svtr_ctc_lcnet_blank_dkd0.1/latest
save_inference_dir: null
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
max_text_length: &max_text_length 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_ppocrv3.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: L2
factor: 3.0e-05
Architecture:
model_type: rec
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: true
return_all_feats: true
model_type: rec
algorithm: SVTR
Transform: null
Backbone:
name: SVTRNet
img_size:
- 48
- 320
out_char_num: 40
out_channels: 192
patch_merging: Conv
embed_dim:
- 64
- 128
- 256
depth:
- 3
- 6
- 3
num_heads:
- 2
- 4
- 8
mixer:
- Conv
- Conv
- Conv
- Conv
- Conv
- Conv
- Global
- Global
- Global
- Global
- Global
- Global
local_mixer:
- - 5
- 5
- - 5
- 5
- - 5
- 5
last_stage: false
prenorm: true
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: rec
algorithm: SVTR
Transform: null
Backbone:
name: PPLCNetV3
scale: 0.95
Head:
name: MultiHead
head_list:
- CTCHead:
Neck:
name: svtr
dims: 120
depth: 2
hidden_dims: 120
kernel_size: [1, 3]
use_guide: True
Head:
fc_decay: 0.00001
- NRTRHead:
nrtr_dim: 384
max_text_length: *max_text_length
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDKDLoss:
weight: 0.1
model_name_pairs:
- - Student
- Teacher
key: head_out
multi_head: true
alpha: 1.0
beta: 2.0
dis_head: gtc
name: dkd
- DistillationCTCLoss:
weight: 1.0
model_name_list:
- Student
key: head_out
multi_head: true
- DistillationNRTRLoss:
weight: 1.0
smoothing: false
model_name_list:
- Student
key: head_out
multi_head: true
- DistillCTCLogits:
weight: 1.0
reduction: mean
model_name_pairs:
- - Student
- Teacher
key: head_out
PostProcess:
name: DistillationCTCLabelDecode
model_name:
- Student
key: head_out
multi_head: true
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: Student
ignore_space: false
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
ratio_list:
- 1.0
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_workers: 8
use_shared_memory: true
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- MultiLabelEncode:
gtc_encode: NRTRLabelEncode
- RecResizeImg:
image_shape: [3, 48, 320]
- KeepKeys:
keep_keys:
- image
- label_ctc
- label_gtc
- length
- valid_ratio
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 4
profiler_options: null
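The config above wires a frozen SVTR teacher and a PPLCNetV3 student into a DistillationModel and combines four terms: DKD on the GTC head, CTC and NRTR losses on the student, and a KL term on the CTC logits between student and teacher. The sketch below is illustrative only and not part of the commit; it shows the intent of the per-term weights in loss_config_list (CombinedLoss in ppocr scales each term by its weight before summing, though the real implementation differs in detail), with the function name and toy values being assumptions.

# Illustrative only (not from the commit): weighted combination of the configured
# loss terms, mirroring loss_config_list above (dkd=0.1, ctc=1.0, nrtr=1.0, ctc_logits=1.0).
def combine_losses(term_losses, weights):
    # term_losses / weights: dicts keyed by loss name; each term is a scalar.
    total = 0.0
    for name, value in term_losses.items():
        total = total + weights.get(name, 1.0) * value
    return total

weights = {"dkd": 0.1, "ctc": 1.0, "nrtr": 1.0, "ctc_logits": 1.0}
print(combine_losses({"dkd": 2.0, "ctc": 0.5, "nrtr": 0.3, "ctc_logits": 0.4}, weights))  # about 1.4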
ppocr/losses/basic_loss.py
@@ -165,3 +165,79 @@ class LossFromOutput(nn.Layer):
elif self.reduction == 'sum':
loss = paddle.sum(loss)
return {'loss': loss}
class KLDivLoss(nn.Layer):
"""
KLDivLoss
"""
def __init__(self):
super().__init__()
def _kldiv(self, x, target, mask=None):
eps = 1.0e-10
loss = target * (paddle.log(target + eps) - x)
if mask is not None:
loss = loss.flatten(0, 1).sum(axis=1)
loss = loss.masked_select(mask).mean()
else:
# batch mean loss
loss = paddle.sum(loss) / loss.shape[0]
return loss
def forward(self, logits_s, logits_t, mask=None):
log_out_s = F.log_softmax(logits_s, axis=-1)
out_t = F.softmax(logits_t, axis=-1)
loss = self._kldiv(log_out_s, out_t, mask)
return loss
class DKDLoss(nn.Layer):
"""
Decoupled Knowledge Distillation (DKD) loss.
"""
def __init__(self, temperature=1.0, alpha=1.0, beta=1.0):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.beta = beta
def _cat_mask(self, t, mask1, mask2):
t1 = (t * mask1).sum(axis=1, keepdim=True)
t2 = (t * mask2).sum(axis=1, keepdim=True)
rt = paddle.concat([t1, t2], axis=1)
return rt
def _kl_div(self, x, label, mask=None):
y = (label * (paddle.log(label + 1e-10) - x)).sum(axis=1)
if mask is not None:
y = y.masked_select(mask).mean()
else:
y = y.mean()
return y
def forward(self, logits_student, logits_teacher, target, mask=None):
gt_mask = F.one_hot(
target.reshape([-1]), num_classes=logits_student.shape[-1])
other_mask = 1 - gt_mask
logits_student = logits_student.flatten(0, 1)
logits_teacher = logits_teacher.flatten(0, 1)
pred_student = F.softmax(logits_student / self.temperature, axis=1)
pred_teacher = F.softmax(logits_teacher / self.temperature, axis=1)
pred_student = self._cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask)
log_pred_student = paddle.log(pred_student)
tckd_loss = self._kl_div(log_pred_student,
pred_teacher) * (self.temperature**2)
pred_teacher_part2 = F.softmax(
logits_teacher / self.temperature - 1000.0 * gt_mask, axis=1)
log_pred_student_part2 = F.log_softmax(
logits_student / self.temperature - 1000.0 * gt_mask, axis=1)
nckd_loss = self._kl_div(log_pred_student_part2,
pred_teacher_part2) * (self.temperature**2)
loss = self.alpha * tckd_loss + self.beta * nckd_loss
return loss
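For reference, DKD splits the temperature-softened distributions into a two-bin target vs. non-target part (TCKD) and a distribution over non-target classes only (NCKD), weighted by alpha and beta. The snippet below is a hedged smoke test, not part of the commit; the toy shapes follow the flatten(0, 1) handling in DKDLoss.forward, and alpha/beta mirror the YAML above.

# Assumed usage of the DKDLoss defined above (illustrative smoke test, not from the commit).
import paddle

B, L, C = 2, 4, 10                                   # toy batch, sequence length, class count
logits_student = paddle.randn([B, L, C])
logits_teacher = paddle.randn([B, L, C])
target = paddle.randint(0, C, shape=[B, L])          # one class id per time step

dkd = DKDLoss(temperature=1.0, alpha=1.0, beta=2.0)  # beta=2.0 matches DistillationDKDLoss in the config
loss = dkd(logits_student, logits_teacher, target)   # scalar: alpha * TCKD + beta * NCKD
print(float(loss))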
ppocr/losses/__init__.py
@@ -20,9 +20,9 @@ from .center_loss import CenterLoss
from .ace_loss import ACELoss
from .rec_sar_loss import SARLoss
-from .distillation_loss import DistillationCTCLoss
-from .distillation_loss import DistillationSARLoss
-from .distillation_loss import DistillationDMLLoss
+from .distillation_loss import DistillationCTCLoss, DistillCTCLogits
+from .distillation_loss import DistillationSARLoss, DistillationNRTRLoss
+from .distillation_loss import DistillationDMLLoss, DistillationKLDivLoss, DistillationDKDLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
from .distillation_loss import DistillationLossFromOutput
ppocr/losses/distillation_loss.py
@@ -14,12 +14,14 @@
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
import cv2
from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .rec_ce_loss import CELoss
from .basic_loss import DMLLoss, KLDivLoss, DKDLoss
from .basic_loss import DistanceLoss
from .basic_loss import LossFromOutput
from .det_db_loss import DBLoss
@@ -102,11 +104,220 @@ class DistillationDMLLoss(DMLLoss):
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
tgt = batch[2][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype=tgt.dtype))
loss = super().forward(out1[self.dis_head],
out2[self.dis_head], non_pad_mask)
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationKLDivLoss(KLDivLoss):
"""
"""
def __init__(self,
model_name_pairs=[],
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="kl_div"):
super().__init__()
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
tgt = batch[2][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype=tgt.dtype))
loss = super().forward(out1[self.dis_head],
out2[self.dis_head], non_pad_mask)
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationDKDLoss(DKDLoss):
"""
"""
def __init__(self,
model_name_pairs=[],
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="dkd",
temperature=1.0,
alpha=1.0,
beta=1.0):
super().__init__(temperature, alpha, beta)
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
tgt = batch[2][:, 1:2 +
max_len] # [batch_size, max_len + 1]
tgt = tgt.reshape([-1]) # batch_size * (max_len + 1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape,
dtype=tgt.dtype)) # batch_size * (max_len + 1)
loss = super().forward(
out1[self.dis_head], out2[self.dis_head], tgt,
non_pad_mask) # [batch_size, max_len + 1, num_char]
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
@@ -199,6 +410,40 @@ class DistillationSARLoss(SARLoss):
return loss_dict
class DistillationNRTRLoss(CELoss):
def __init__(self,
model_name_list=[],
key=None,
multi_head=False,
smoothing=True,
name="loss_nrtr",
**kwargs):
super().__init__(smoothing=smoothing)
self.model_name_list = model_name_list
self.key = key
self.name = name
self.multi_head = multi_head
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
if self.multi_head:
assert 'gtc' in out, 'multi head has multi out'
loss = super().forward(out['gtc'], batch[:1] + batch[2:])
else:
loss = super().forward(out, batch)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, model_name)] = loss
return loss_dict
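An aside on the batch re-indexing above: per the MultiLabelEncode/KeepKeys transforms in the config, a training batch is ordered [image, label_ctc, label_gtc, length, valid_ratio], so batch[:1] + batch[2:] drops label_ctc and the underlying CELoss sees the GTC (NRTR) label at index 1. A toy illustration, not from the commit:

# Toy illustration of the batch re-indexing in DistillationNRTRLoss.forward.
batch = ["image", "label_ctc", "label_gtc", "length", "valid_ratio"]  # order set by KeepKeys in the config
print(batch[:1] + batch[2:])  # ['image', 'label_gtc', 'length', 'valid_ratio']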
class DistillationDBLoss(DBLoss):
def __init__(self,
model_name_list=[],
@@ -459,3 +704,212 @@ class DistillationVQADistanceLoss(DistanceLoss):
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
class CTCDKDLoss(nn.Layer):
"""
Decoupled Knowledge Distillation adapted to CTC outputs (blank-aware, time-averaged).
"""
def __init__(self, temperature=0.5, alpha=1.0, beta=1.0):
super().__init__()
self.temperature = temperature
self.alpha = alpha
self.beta = beta
self.eps = 1e-6
self.t = temperature
self.act = nn.Softmax(axis=-1)
self.use_log = True
def kl_loss(self, p1, p2): # predict, label
loss = paddle.multiply(
p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps))
bs = loss.shape[0]
loss = paddle.sum(loss) / bs
return loss
def _cat_mask(self, t, mask1, mask2):
t1 = (t * mask1).sum(axis=1, keepdim=True)
t2 = (t * mask2).sum(axis=1, keepdim=True)
rt = paddle.concat([t1, t2], axis=1)
return rt
def multi_label_mask(self, targets):
targets = targets.astype("int32")
res = F.one_hot(targets, num_classes=11465)
mask = paddle.clip(paddle.sum(res, axis=1), 0, 1)
mask[:, 0] = 0  # ignore ctc blank label
return mask
def forward(self, logits_student, logits_teacher, targets, mask=None):
gt_mask = self.multi_label_mask(targets)
other_mask = paddle.ones_like(gt_mask) - gt_mask
pred_student = F.softmax(logits_student / self.temperature, axis=-1)
pred_teacher = F.softmax(logits_teacher / self.temperature, axis=-1)
# differs from DKD: average probabilities over the time axis
pred_student = paddle.mean(pred_student, axis=1)
pred_teacher = paddle.mean(pred_teacher, axis=1)
pred_student = self._cat_mask(pred_student, gt_mask, other_mask)
pred_teacher = self._cat_mask(pred_teacher, gt_mask, other_mask)
# differs from DKD: KL computed directly on the averaged probabilities
tckd_loss = self.kl_loss(pred_student, pred_teacher)
gt_mask_ex = paddle.expand_as(gt_mask.unsqueeze(axis=1), logits_teacher)
pred_teacher_part2 = F.softmax(
logits_teacher / self.temperature - 1000.0 * gt_mask_ex, axis=-1)
pred_student_part2 = F.softmax(
logits_student / self.temperature - 1000.0 * gt_mask_ex, axis=-1)
# differs from DKD: average probabilities over the time axis
pred_teacher_part2 = paddle.mean(pred_teacher_part2, axis=1)
pred_student_part2 = paddle.mean(pred_student_part2, axis=1)
# differs from DKD: KL computed directly on the averaged probabilities
nckd_loss = self.kl_loss(pred_student_part2, pred_teacher_part2)
loss = self.alpha * tckd_loss + self.beta * nckd_loss
return loss
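CTCDKDLoss adapts DKD to CTC outputs: since CTC predictions are unaligned, probabilities are averaged over the time axis before the target/non-target split, and the "target" set is every character id occurring anywhere in the label (a multi-label mask) with the blank class always excluded. The helper below re-implements that mask with a small vocabulary purely for illustration; the real code hard-codes num_classes=11465.

# Illustration only (not from the commit) of the multi-label gt mask built by CTCDKDLoss.
import paddle
import paddle.nn.functional as F

def multi_label_mask(targets, num_classes):
    targets = targets.astype("int32")
    res = F.one_hot(targets, num_classes=num_classes)    # [B, label_len, num_classes]
    mask = paddle.clip(paddle.sum(res, axis=1), 0, 1)    # [B, num_classes]: 1 where the class occurs
    mask[:, 0] = 0                                        # ignore the CTC blank class
    return mask

labels = paddle.to_tensor([[3, 5, 5, 0, 0]])              # padded CTC label, blank/pad = 0
print(multi_label_mask(labels, num_classes=8).numpy())    # [[0. 0. 0. 1. 0. 1. 0. 0.]]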
class KLCTCLogits(nn.Layer):
def __init__(self, weight=1.0, reduction='mean', mode="mean"):
super().__init__()
self.weight = weight
self.reduction = reduction
self.eps = 1e-6
self.t = 0.5
self.act = nn.Softmax(axis=-1)
self.use_log = True
self.mode = mode
self.ctc_dkd_loss = CTCDKDLoss()
def kl_loss(self, p1, p2): # predict, label
loss = paddle.multiply(
p2, paddle.log((p2 + self.eps) / (p1 + self.eps) + self.eps))
bs = loss.shape[0]
loss = paddle.sum(loss) / bs
return loss
def forward_meanmax(self, stu_out, tea_out):
stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1)
tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1)
loss = self.kl_loss(stu_out, tea_out)
return loss
def forward_meanlog(self, stu_out, tea_out):
stu_out = paddle.mean(F.softmax(stu_out / self.t, axis=-1), axis=1)
tea_out = paddle.mean(F.softmax(tea_out / self.t, axis=-1), axis=1)
if self.use_log is True:
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(stu_out)
log_out2 = paddle.log(tea_out)
loss = (
self._kldiv(log_out1, tea_out) + self._kldiv(log_out2, stu_out)
) / 2.0
return loss
def forward_sum(self, stu_out, tea_out):
stu_out = paddle.sum(F.softmax(stu_out / self.t, axis=-1), axis=1)
tea_out = paddle.sum(F.softmax(tea_out / self.t, axis=-1), axis=1)
stu_out = paddle.log(stu_out)
bs = stu_out.shape[0]
loss = tea_out * (paddle.log(tea_out + self.eps) - stu_out)
loss = paddle.sum(loss, axis=1) / loss.shape[0]
return loss
def _kldiv(self, x, target):
eps = 1.0e-10
loss = target * (paddle.log(target + eps) - x)
loss = paddle.sum(paddle.mean(loss, axis=1)) / loss.shape[0]
return loss
def forward(self, stu_out, tea_out, targets=None):
if self.mode == "log":
return self.forward_log(stu_out, tea_out)
elif self.mode == "mean":
blank_mask = paddle.ones_like(stu_out)
blank_mask.stop_gradient = True
blank_mask[:, :, 0] = -1
stu_out *= blank_mask
tea_out *= blank_mask
return self.forward_meanmax(stu_out, tea_out)
elif self.mode == "sum":
return self.forward_sum(stu_out, tea_out)
elif self.mode == "meanlog":
blank_mask = paddle.ones_like(stu_out)
blank_mask.stop_gradient = True
blank_mask[:, :, 0] = -1
stu_out *= blank_mask
tea_out *= blank_mask
return self.forward_meanlog(stu_out, tea_out)
elif self.mode == "ctcdkd":
# ignore ctc blank logits
blank_mask = paddle.ones_like(stu_out)
blank_mask.stop_gradient = True
blank_mask[:, :, 0] = -1
stu_out *= blank_mask
tea_out *= blank_mask
return self.ctc_dkd_loss(stu_out, tea_out, targets)
else:
raise ValueError("unsupported distillation mode: {}".format(self.mode))
def forward_log(self, out1, out2):
if self.act is not None:
out1 = self.act(out1) + 1e-10
out2 = self.act(out2) + 1e-10
if self.use_log is True:
# for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (
self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
return loss
class DistillCTCLogits(KLCTCLogits):
def __init__(self,
model_name_pairs=[],
key=None,
name="ctc_logits",
reduction="mean"):
super().__init__(reduction=reduction)
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.key = key
self.name = name
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]['ctc']
out2 = out2[self.key]['ctc']
ctc_label = batch[1]
loss = super().forward(out1, out2, ctc_label)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}".format(self.name, model_name,
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
return loss_dict
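Finally, a hedged end-to-end sketch of how DistillCTCLogits (configured above with key: head_out) consumes the dict-style outputs: predicts maps model names to their head_out dicts, whose 'ctc' entry holds [batch, time, classes] logits, and batch[1] carries label_ctc (unused in the default "mean" mode). Shapes and the class count here are toy values, not the real dictionary size; this is not code from the commit.

# Illustrative call (not from the commit) of DistillCTCLogits as configured in the YAML.
import paddle

B, T, C = 2, 40, 100                                   # toy batch size, time steps, class count
predicts = {
    "Student": {"head_out": {"ctc": paddle.randn([B, T, C])}},
    "Teacher": {"head_out": {"ctc": paddle.randn([B, T, C])}},
}
batch = [None, paddle.randint(0, C, shape=[B, 25])]    # batch[1] = label_ctc (ignored in "mean" mode)

loss_fn = DistillCTCLogits(model_name_pairs=[["Student", "Teacher"]], key="head_out")
print(loss_fn(predicts, batch))                        # {'ctc_logits_0': Tensor(...)}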