diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index bdc8b2a48fdb8c8656fbf8a84bb94bd4162e258e..2a71de9c8a0a424915ae0eaa2f446a56e3464186 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -78,7 +78,13 @@ def classification_eval(engine, epoch_id=0): labels = paddle.concat(label_list, 0) if isinstance(out, dict): - out = out["logits"] + if "logits" in out: + out = out["logits"] + elif "Student" in out: + out = out["Student"] + else: + msg = "Error: Wrong key in out!" + raise Exception(msg) if isinstance(out, list): pred = [] for x in out: diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 37509eb14ea98b96f7a1fc96ee3e63f9fba18e7c..7c6407e7a4c74fa7d4330d72c6be52f6a843cdf0 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -222,7 +222,8 @@ class DistillationTopkAcc(TopkAcc): self.feature_key = feature_key def forward(self, x, label): - x = x[self.model_key] + if isinstance(x, dict): + x = x[self.model_key] if self.feature_key is not None: x = x[self.feature_key] return super().forward(x, label)