未验证 提交 f8fb8a7b 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #7406 from LDOUBLEV/dygraph

fix cml
......@@ -191,7 +191,6 @@ Eval:
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
......
......@@ -24,6 +24,7 @@ Architecture:
model_type: det
Models:
Student:
pretrained:
model_type: det
algorithm: DB
Transform: null
......@@ -40,6 +41,7 @@ Architecture:
name: DBHead
k: 50
Student2:
pretrained:
model_type: det
algorithm: DB
Transform: null
......@@ -91,14 +93,11 @@ Loss:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
......@@ -197,6 +196,7 @@ Train:
drop_last: false
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
......@@ -204,31 +204,21 @@ Eval:
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- DetLabelEncode: null
- DetResizeForTest: null
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- shape
- polys
- ignore_tags
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: false
drop_last: false
batch_size_per_card: 1
num_workers: 2
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
\ No newline at end of file
......@@ -60,19 +60,19 @@ class KLJSLoss(object):
], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
def __call__(self, p1, p2, reduction="mean", eps=1e-5):
if self.mode.lower() == 'kl':
loss = paddle.multiply(p2,
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
paddle.log((p2 + eps) / (p1 + eps) + eps))
loss += paddle.multiply(p1,
paddle.log((p1 + eps) / (p2 + eps) + eps))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
p2, paddle.log((2 * p2 + eps) / (p1 + p2 + eps) + eps))
loss += paddle.multiply(
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
p1, paddle.log((2 * p1 + eps) / (p1 + p2 + eps) + eps))
loss *= 0.5
else:
raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册