提交 0742f5c5 编写于 作者: L LDOUBLEV

fix metric etc.al

上级 a7b32ca8
...@@ -90,14 +90,14 @@ Loss: ...@@ -90,14 +90,14 @@ Loss:
- ["Student", "Student2"] - ["Student", "Student2"]
maps_name: "thrink_maps" maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
act: "softmax" # act: None
model_name_pairs: ["Student", "Student2"] model_name_pairs: ["Student", "Student2"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
weight: 1.0 weight: 1.0
model_name_list: ["Student", "Student2"] model_name_list: ["Student", "Student2"]
# key: maps # key: maps
name: DBLoss # name: DBLoss
balance_loss: true balance_loss: true
main_loss_type: DiceLoss main_loss_type: DiceLoss
alpha: 5 alpha: 5
...@@ -119,8 +119,8 @@ Optimizer: ...@@ -119,8 +119,8 @@ Optimizer:
PostProcess: PostProcess:
name: DistillationDBPostProcess name: DistillationDBPostProcess
model_name: ["Student", "Student2"] model_name: ["Student", "Student2", "Teacher"]
key: head_out # key: maps
thresh: 0.3 thresh: 0.3
box_thresh: 0.6 box_thresh: 0.6
max_candidates: 1000 max_candidates: 1000
......
...@@ -54,6 +54,27 @@ class CELoss(nn.Layer): ...@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
return loss return loss
class KLJSLoss(object):
def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
if self.mode.lower() == "js":
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
loss = paddle.mean(loss, axis=[1,2])
elif reduction=="none" or reduction is None:
return loss
else:
loss = paddle.sum(loss, axis=[1,2])
return loss
class DMLLoss(nn.Layer): class DMLLoss(nn.Layer):
""" """
DMLLoss DMLLoss
...@@ -70,16 +91,20 @@ class DMLLoss(nn.Layer): ...@@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
else: else:
self.act = None self.act = None
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2): def forward(self, out1, out2):
if self.act is not None: if self.act is not None:
out1 = self.act(out1) out1 = self.act(out1)
out2 = self.act(out2) out2 = self.act(out2)
if len(out1.shape) < 2:
log_out1 = paddle.log(out1) log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2) log_out2 = paddle.log(out2)
loss = (F.kl_div( loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div( log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0 log_out2, out1, reduction='batchmean')) / 2.0
else:
loss = self.jskl_loss(out1, out2)
return loss return loss
......
...@@ -55,7 +55,5 @@ class CombinedLoss(nn.Layer): ...@@ -55,7 +55,5 @@ class CombinedLoss(nn.Layer):
loss_all += loss[key] * weight loss_all += loss[key] * weight
else: else:
loss_dict["{}_{}".format(key, idx)] = loss[key] loss_dict["{}_{}".format(key, idx)] = loss[key]
# loss[f"{key}_{idx}"] = loss[key]
loss_dict.update(loss)
loss_dict["loss"] = loss_all loss_dict["loss"] = loss_all
return loss_dict return loss_dict
...@@ -46,13 +46,13 @@ class DistillationDMLLoss(DMLLoss): ...@@ -46,13 +46,13 @@ class DistillationDMLLoss(DMLLoss):
act=None, act=None,
key=None, key=None,
maps_name=None, maps_name=None,
name="loss_dml"): name="dml"):
super().__init__(act=act) super().__init__(act=act)
assert isinstance(model_name_pairs, list) assert isinstance(model_name_pairs, list)
self.key = key self.key = key
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs) self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name self.name = name
self.maps_name = maps_name self.maps_name = self._check_maps_name(maps_name)
def _check_model_name_pairs(self, model_name_pairs): def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list): if not isinstance(model_name_pairs, list):
...@@ -76,11 +76,11 @@ class DistillationDMLLoss(DMLLoss): ...@@ -76,11 +76,11 @@ class DistillationDMLLoss(DMLLoss):
new_outs = {} new_outs = {}
for k in self.maps_name: for k in self.maps_name:
if k == "thrink_maps": if k == "thrink_maps":
new_outs[k] = paddle.slice(outs, axes=[1], starts=[0], ends=[1]) new_outs[k] = outs[:, 0, :, :]
elif k == "threshold_maps": elif k == "threshold_maps":
new_outs[k] = paddle.slice(outs, axes=[1], starts=[1], ends=[2]) new_outs[k] = outs[:, 1, :, :]
elif k == "binary_maps": elif k == "binary_maps":
new_outs[k] = paddle.slice(outs, axes=[1], starts=[2], ends=[3]) new_outs[k] = outs[:, 2, :, :]
else: else:
continue continue
return new_outs return new_outs
...@@ -105,14 +105,14 @@ class DistillationDMLLoss(DMLLoss): ...@@ -105,14 +105,14 @@ class DistillationDMLLoss(DMLLoss):
else: else:
outs1 = self._slice_out(out1) outs1 = self._slice_out(out1)
outs2 = self._slice_out(out2) outs2 = self._slice_out(out2)
for k in outs1.keys(): for _c, k in enumerate(outs1.keys()):
loss = super().forward(outs1[k], outs2[k]) loss = super().forward(outs1[k], outs2[k])
if isinstance(loss, dict): if isinstance(loss, dict):
for key in loss: for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[ loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key] 0], pair[1], map_name, idx)] = loss[key]
else: else:
loss_dict["{}_{}_{}".format(self.name, self.maps_name, loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
idx)] = loss idx)] = loss
loss_dict = _sum_loss(loss_dict) loss_dict = _sum_loss(loss_dict)
...@@ -152,7 +152,7 @@ class DistillationDBLoss(DBLoss): ...@@ -152,7 +152,7 @@ class DistillationDBLoss(DBLoss):
beta=10, beta=10,
ohem_ratio=3, ohem_ratio=3,
eps=1e-6, eps=1e-6,
name="db_loss", name="db",
**kwargs): **kwargs):
super().__init__() super().__init__()
self.model_name_list = model_name_list self.model_name_list = model_name_list
......
...@@ -55,6 +55,10 @@ class DetMetric(object): ...@@ -55,6 +55,10 @@ class DetMetric(object):
result = self.evaluator.evaluate_image(gt_info_list, det_info_list) result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
self.results.append(result) self.results.append(result)
metircs = self.evaluator.combine_results(self.results)
self.reset()
return metircs
def get_metric(self): def get_metric(self):
""" """
return metrics { return metrics {
......
...@@ -200,21 +200,18 @@ class DistillationDBPostProcess(DBPostProcess): ...@@ -200,21 +200,18 @@ class DistillationDBPostProcess(DBPostProcess):
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
**kwargs): **kwargs):
super(DistillationDBPostProcess, self).__init__( super().__init__()
thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
score_mode)
if not isinstance(model_name, list): if not isinstance(model_name, list):
model_name = [model_name] model_name = [model_name]
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
def forward(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
for name in self.model_name: for name in self.model_name:
pred = predicts[name] pred = predicts[name]
if self.key is not None: if self.key is not None:
pred = pred[self.key] pred = pred[self.key]
results[name] = super().__call__(pred, shape_list=label) results[name] = super().__call__(pred, shape_list=shape_list)
return results return results
...@@ -135,6 +135,7 @@ def load_pretrained_params(model, path): ...@@ -135,6 +135,7 @@ def load_pretrained_params(model, path):
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !" f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
) )
model.set_state_dict(new_state_dict) model.set_state_dict(new_state_dict)
print(f"load pretrain successful from {path}")
return True return True
def save_model(model, def save_model(model,
......
...@@ -55,8 +55,10 @@ def main(): ...@@ -55,8 +55,10 @@ def main():
model = build_model(config['Architecture']) model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
else:
model_type = None
best_model_dict = init_model(config, model) best_model_dict = init_model(config, model)
if len(best_model_dict): if len(best_model_dict):
logger.info('metric in ckpt ***************') logger.info('metric in ckpt ***************')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册