提交 7732a69f 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fix key error in distillation

上级 7848c4df
......@@ -78,7 +78,13 @@ def classification_eval(engine, epoch_id=0):
labels = paddle.concat(label_list, 0)
if isinstance(out, dict):
out = out["logits"]
if "logits" in out:
out = out["logits"]
elif "Student" in out:
out = out["Student"]
else:
msg = "Error: Wrong key in out!"
raise Exception(msg)
if isinstance(out, list):
pred = []
for x in out:
......
......@@ -222,7 +222,8 @@ class DistillationTopkAcc(TopkAcc):
self.feature_key = feature_key
def forward(self, x, label):
x = x[self.model_key]
if isinstance(x, dict):
x = x[self.model_key]
if self.feature_key is not None:
x = x[self.feature_key]
return super().forward(x, label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册