提交 0742f5c5 编写于 作者: L LDOUBLEV

fix metric etc.al

上级 a7b32ca8
......@@ -90,14 +90,14 @@ Loss:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
name: DBLoss
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
......@@ -119,8 +119,8 @@ Optimizer:
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
model_name: ["Student", "Student2", "Teacher"]
# key: maps
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
......
......@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
return loss
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2])
elif reduction=="none" or reduction is None:
return loss
else:
loss = paddle.sum(loss, axis=[1,2])
return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
......@@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
else:
self.act = None
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
if len(out1.shape) < 2:
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
else:
loss = self.jskl_loss(out1, out2)
return loss
......
......@@ -55,7 +55,5 @@ class CombinedLoss(nn.Layer):
loss_all += loss[key] * weight
else:
loss_dict["{}_{}".format(key, idx)] = loss[key]
# loss[f"{key}_{idx}"] = loss[key]
loss_dict.update(loss)
loss_dict["loss"] = loss_all
return loss_dict
......@@ -46,13 +46,13 @@ class DistillationDMLLoss(DMLLoss):
act=None,
key=None,
maps_name=None,
name="loss_dml"):
name="dml"):
super().__init__(act=act)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = maps_name
self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
......@@ -76,11 +76,11 @@ class DistillationDMLLoss(DMLLoss):
new_outs = {}
for k in self.maps_name:
if k == "thrink_maps":
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1])
new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps":
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2])
new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps":
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3])
new_outs[k] = outs[:, 2, :, :]
else:
continue
return new_outs
......@@ -105,14 +105,14 @@ class DistillationDMLLoss(DMLLoss):
else:
outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2)
for k in outs1.keys():
for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key]
else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name,
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
idx)] = loss
loss_dict = _sum_loss(loss_dict)
......@@ -152,7 +152,7 @@ class DistillationDBLoss(DBLoss):
beta=10,
ohem_ratio=3,
eps=1e-6,
name="db_loss",
name="db",
**kwargs):
super().__init__()
self.model_name_list = model_name_list
......
......@@ -55,6 +55,10 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result)
metircs = self.evaluator.combine_results(self.results)
self.reset()
return metircs
def get_metric(self):
"""
return metrics {
......
......@@ -200,21 +200,18 @@ class DistillationDBPostProcess(DBPostProcess):
use_dilation=False,
score_mode="fast",
**kwargs):
super(DistillationDBPostProcess, self).__init__(
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
score_mode)
super().__init__()
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def forward(self, predicts, shape_list):
def __call__(self, predicts, shape_list):
results = {}
for name in self.model_name:
pred = predicts[name]
if self.key is not None:
pred = pred[self.key]
results[name] = super().__call__(pred, shape_list=label)
results[name] = super().__call__(pred, shape_list=shape_list)
return results
......@@ -135,6 +135,7 @@ def load_pretrained_params(model, path):
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
)
model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}")
return True
def save_model(model,
......
......@@ -55,8 +55,10 @@ def main():
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
model_type = None
best_model_dict = init_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册