提交 33a15cfd 编写于 作者: C cuicheng01

fix static training bugs

上级 7bf9b40b
......@@ -46,7 +46,7 @@ class TopkAcc(AvgMetrics):
for k in self.topk:
metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)].numpy()[0], x.shape[0])
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
return metric_dict
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
__all__ = ['AverageMeter']
......@@ -44,6 +46,8 @@ class AverageMeter(object):
@property
def avg_info(self):
if isinstance(self.avg, paddle.Tensor):
self.avg = self.avg.numpy()[0]
return "{}: {:.5f}".format(self.name, self.avg)
@property
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册