未验证 提交 abc4be00 编写于 作者: D Double_V 提交者: GitHub

add nrtr dml distill loss (#9968)

* support min_area_rect crop

* add check_install

* fix requirement.txt

* fix check_install

* add lanms-neo for drrg

* fix

* fix doc

* fix

* support set gpu_id when inference

* fix #8855

* fix #8855

* opt slim doc

* fix doc bug

* add v4_rec_distill config

* delete debug

* fix comment

* fix comment

* add dml nrtr distill loss
上级 1643f268
...@@ -96,6 +96,96 @@ class DistillationDMLLoss(DMLLoss): ...@@ -96,6 +96,96 @@ class DistillationDMLLoss(DMLLoss):
continue continue
return new_outs return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
loss = super().forward(out1[self.dis_head],
out2[self.dis_head])
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationKLDivLoss(KLDivLoss):
"""
"""
def __init__(self,
model_name_pairs=[],
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="kl_div"):
super().__init__()
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs): for idx, pair in enumerate(self.model_name_pairs):
...@@ -141,6 +231,149 @@ class DistillationDMLLoss(DMLLoss): ...@@ -141,6 +231,149 @@ class DistillationDMLLoss(DMLLoss):
return loss_dict return loss_dict
class DistillationDKDLoss(DKDLoss):
"""
"""
def __init__(self,
model_name_pairs=[],
key=None,
multi_head=False,
dis_head='ctc',
maps_name=None,
name="dkd",
temperature=1.0,
alpha=1.0,
beta=1.0):
super().__init__(temperature, alpha, beta)
assert isinstance(model_name_pairs, list)
self.key = key
self.multi_head = multi_head
self.dis_head = dis_head
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
elif isinstance(model_name_pairs[0], list) and isinstance(
model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
def _check_maps_name(self, maps_name):
if maps_name is None:
return None
elif type(maps_name) == str:
return [maps_name]
elif type(maps_name) == list:
return [maps_name]
else:
return None
def _slice_out(self, outs):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.maps_name is None:
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
tgt = batch[2][:, 1:2 +
max_len] # [batch_size, max_len + 1]
tgt = tgt.reshape([-1]) # batch_size * (max_len + 1)
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape,
dtype=tgt.dtype)) # batch_size * (max_len + 1)
loss = super().forward(
out1[self.dis_head], out2[self.dis_head], tgt,
non_pad_mask) # [batch_size, max_len + 1, num_char]
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], self.maps_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name[
_c], idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationNRTRDMLLoss(DistillationDMLLoss):
"""
"""
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if self.multi_head:
# for nrtr dml loss
max_len = batch[3].max()
tgt = batch[2][:, 1:2 + max_len]
tgt = tgt.reshape([-1])
non_pad_mask = paddle.not_equal(
tgt, paddle.zeros(
tgt.shape, dtype=tgt.dtype))
loss = super().forward(out1[self.dis_head], out2[self.dis_head],
non_pad_mask)
else:
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
idx)] = loss[key]
else:
loss_dict["{}_{}".format(self.name, idx)] = loss
loss_dict = _sum_loss(loss_dict)
return loss_dict
class DistillationKLDivLoss(KLDivLoss): class DistillationKLDivLoss(KLDivLoss):
""" """
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册