From 7732a69f1bdb8d534e4f764f6486b5f5edeff474 Mon Sep 17 00:00:00 2001 From: gaotingquan Date: Wed, 15 Dec 2021 07:53:09 +0000 Subject: [PATCH] fix: fix key error in distillation --- ppcls/engine/evaluation/classification.py | 8 +++++++- ppcls/metric/metrics.py | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index bdc8b2a4..2a71de9c 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -78,7 +78,13 @@ def classification_eval(engine, epoch_id=0): labels = paddle.concat(label_list, 0) if isinstance(out, dict): - out = out["logits"] + if "logits" in out: + out = out["logits"] + elif "Student" in out: + out = out["Student"] + else: + msg = "Error: Wrong key in out!" + raise Exception(msg) if isinstance(out, list): pred = [] for x in out: diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 37509eb1..7c6407e7 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -222,7 +222,8 @@ class DistillationTopkAcc(TopkAcc): self.feature_key = feature_key def forward(self, x, label): - x = x[self.model_key] + if isinstance(x, dict): + x = x[self.model_key] if self.feature_key is not None: x = x[self.feature_key] return super().forward(x, label) -- GitLab