diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 422e0b6d3cc18942a549627211e558f2779ea128..8190521628e5a5655eeb82aef06e14dbc4bab5dc 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -46,7 +46,7 @@ class TopkAcc(AvgMetrics): for k in self.topk: metric_dict["top{}".format(k)] = paddle.metric.accuracy( x, label, k=k) - self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)].numpy()[0], x.shape[0]) + self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0]) return metric_dict diff --git a/ppcls/utils/misc.py b/ppcls/utils/misc.py index ba1dd15916843adc254a7253efaa0570643bcb63..1ea8d2a44068f2dca10abdfd54b5c5ff72d29dc2 100644 --- a/ppcls/utils/misc.py +++ b/ppcls/utils/misc.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle + __all__ = ['AverageMeter'] @@ -44,6 +46,8 @@ class AverageMeter(object): @property def avg_info(self): + if isinstance(self.avg, paddle.Tensor): + self.avg = self.avg.numpy()[0] return "{}: {:.5f}".format(self.name, self.avg) @property