提交 ebd31950 编写于 作者: L LDOUBLEV

fix cml

上级 1e0ec88f
......@@ -94,14 +94,11 @@ Loss:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
......@@ -191,7 +188,6 @@ Eval:
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
......
......@@ -24,6 +24,7 @@ Architecture:
model_type: det
Models:
Student:
pretrained:
model_type: det
algorithm: DB
Transform: null
......@@ -40,6 +41,7 @@ Architecture:
name: DBHead
k: 50
Student2:
pretrained:
model_type: det
algorithm: DB
Transform: null
......@@ -56,6 +58,7 @@ Architecture:
name: DBHead
k: 50
Teacher:
pretrained:
freeze_params: true
return_all_feats: false
model_type: det
......@@ -91,14 +94,11 @@ Loss:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
......@@ -204,31 +204,21 @@ Eval:
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
- DecodeImage: # load image
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest: null
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
......@@ -60,19 +60,19 @@ class KLJSLoss(object):
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
def __call__(self, p1, p2, reduction="mean", eps=1e-5):
if self.mode.lower() == 'kl':
loss = paddle.multiply(p2,
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
paddle.log((p2 + eps) / (p1 + eps) + eps))
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
p1, paddle.log((p1 + eps) / (p2 + eps) + eps))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps))
loss += paddle.multiply(
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps))
loss *= 0.5
else:
raise ValueError(
......@@ -125,7 +125,7 @@ class DMLLoss(nn.Layer):
loss = (
self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
else:
# for detection distillation log is not needed
# distillation log is not needed for detection
loss = self.jskl_loss(out1, out2)
return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册