提交 97cb2d37 编写于 作者: L LDOUBLEV

fix doc

上级 90e2f2d6
...@@ -24,6 +24,7 @@ Architecture: ...@@ -24,6 +24,7 @@ Architecture:
model_type: det model_type: det
Models: Models:
Student: Student:
pretrained:
model_type: det model_type: det
algorithm: DB algorithm: DB
Transform: null Transform: null
...@@ -40,6 +41,7 @@ Architecture: ...@@ -40,6 +41,7 @@ Architecture:
name: DBHead name: DBHead
k: 50 k: 50
Student2: Student2:
pretrained:
model_type: det model_type: det
algorithm: DB algorithm: DB
Transform: null Transform: null
...@@ -56,6 +58,7 @@ Architecture: ...@@ -56,6 +58,7 @@ Architecture:
name: DBHead name: DBHead
k: 50 k: 50
Teacher: Teacher:
pretrained:
freeze_params: true freeze_params: true
return_all_feats: false return_all_feats: false
model_type: det model_type: det
...@@ -91,14 +94,11 @@ Loss: ...@@ -91,14 +94,11 @@ Loss:
- ["Student", "Student2"] - ["Student", "Student2"]
maps_name: "thrink_maps" maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"] model_name_pairs: ["Student", "Student2"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
weight: 1.0 weight: 1.0
model_name_list: ["Student", "Student2"] model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true balance_loss: true
main_loss_type: DiceLoss main_loss_type: DiceLoss
alpha: 5 alpha: 5
...@@ -204,31 +204,16 @@ Eval: ...@@ -204,31 +204,16 @@ Eval:
label_file_list: label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms: transforms:
- DecodeImage: - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: false channel_first: False
- DetLabelEncode: null - DetLabelEncode: # Class handling label
- DetResizeForTest: null - DetResizeForTest:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: mean: [0.485, 0.456, 0.406]
- 0.485 std: [0.229, 0.224, 0.225]
- 0.456 order: 'hwc'
- 0.406 - ToCHWImage:
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys: - KeepKeys:
keep_keys: keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
- image
- shape
- polys
- ignore_tags
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
...@@ -60,19 +60,19 @@ class KLJSLoss(object): ...@@ -60,19 +60,19 @@ class KLJSLoss(object):
], "mode can only be one of ['kl', 'KL', 'js', 'JS']" ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode self.mode = mode
def __call__(self, p1, p2, reduction="mean"): def __call__(self, p1, p2, reduction="mean", eps=1e-5):
if self.mode.lower() == 'kl': if self.mode.lower() == 'kl':
loss = paddle.multiply(p2, loss = paddle.multiply(p2,
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) paddle.log((p2 + eps) / (p1 + eps) + eps))
loss += paddle.multiply( loss += paddle.multiply(p1,
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) paddle.log((p1 + eps) / (p2 + eps) + eps))
loss *= 0.5 loss *= 0.5
elif self.mode.lower() == "js": elif self.mode.lower() == "js":
loss = paddle.multiply( loss = paddle.multiply(
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps))
loss += paddle.multiply( loss += paddle.multiply(
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5)) p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps))
loss *= 0.5 loss *= 0.5
else: else:
raise ValueError( raise ValueError(
...@@ -125,7 +125,7 @@ class DMLLoss(nn.Layer): ...@@ -125,7 +125,7 @@ class DMLLoss(nn.Layer):
loss = ( loss = (
self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0 self._kldiv(log_out1, out2) + self._kldiv(log_out2, out1)) / 2.0
else: else:
# for detection distillation log is not needed # log is not needed for detection
loss = self.jskl_loss(out1, out2) loss = self.jskl_loss(out1, out2)
return loss return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册