未验证 提交 e0a6e5bf 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix dist err (#1621)

* fix dist err

* fix init

* fix init
上级 aea712cc
...@@ -28,7 +28,6 @@ from ppcls.utils import logger ...@@ -28,7 +28,6 @@ from ppcls.utils import logger
from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.arch.slim import prune_model, quantize_model from ppcls.arch.slim import prune_model, quantize_model
__all__ = ["build_model", "RecModel", "DistillationModel"] __all__ = ["build_model", "RecModel", "DistillationModel"]
...@@ -82,13 +81,11 @@ class RecModel(TheseusLayer): ...@@ -82,13 +81,11 @@ class RecModel(TheseusLayer):
out["backbone"] = x out["backbone"] = x
if self.neck is not None: if self.neck is not None:
x = self.neck(x) x = self.neck(x)
out["neck"] = x
out["features"] = x out["features"] = x
if self.head is not None: if self.head is not None:
y = self.head(x, label) y = self.head(x, label)
out["neck"] = x out["logits"] = y
else:
y = None
out["logits"] = y
return out return out
......
# global configs # global configs
# global configs
Global: Global:
checkpoints: null checkpoints: null
pretrained_model: null pretrained_model: null
...@@ -85,11 +84,6 @@ Loss: ...@@ -85,11 +84,6 @@ Loss:
key: "logits" key: "logits"
model_name_pairs: model_name_pairs:
- ["Student", "Teacher"] - ["Student", "Teacher"]
- DistillationDMLLoss:
weight: 1.0
key: "logits"
model_name_pairs:
- ["Student", "Teacher"]
Eval: Eval:
- DistillationGTCELoss: - DistillationGTCELoss:
weight: 1.0 weight: 1.0
......
...@@ -57,7 +57,7 @@ Optimizer: ...@@ -57,7 +57,7 @@ Optimizer:
momentum: 0.9 momentum: 0.9
lr: lr:
name: Cosine name: Cosine
learning_rate: 1.3 learning_rate: 0.65
warmup_epoch: 5 warmup_epoch: 5
regularizer: regularizer:
name: 'L2' name: 'L2'
......
...@@ -69,7 +69,7 @@ class DistillationGTCELoss(CELoss): ...@@ -69,7 +69,7 @@ class DistillationGTCELoss(CELoss):
def forward(self, predicts, batch): def forward(self, predicts, batch):
loss_dict = dict() loss_dict = dict()
for _, name in enumerate(self.model_names): for name in self.model_names:
out = predicts[name] out = predicts[name]
if self.key is not None: if self.key is not None:
out = out[self.key] out = out[self.key]
......
...@@ -42,8 +42,8 @@ class DMLLoss(nn.Layer): ...@@ -42,8 +42,8 @@ class DMLLoss(nn.Layer):
def forward(self, x, target): def forward(self, x, target):
if self.act is not None: if self.act is not None:
x = F.softmax(x) x = self.act(x)
target = F.softmax(target) target = self.act(target)
loss = self._kldiv(x, target) + self._kldiv(target, x) loss = self._kldiv(x, target) + self._kldiv(target, x)
loss = loss / 2 loss = loss / 2
loss = paddle.mean(loss) loss = paddle.mean(loss)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册