提交 33a15cfd 编写于 作者: C cuicheng01

fix static training bugs

上级 7bf9b40b
...@@ -46,7 +46,7 @@ class TopkAcc(AvgMetrics): ...@@ -46,7 +46,7 @@ class TopkAcc(AvgMetrics):
for k in self.topk: for k in self.topk:
metric_dict["top{}".format(k)] = paddle.metric.accuracy( metric_dict["top{}".format(k)] = paddle.metric.accuracy(
x, label, k=k) x, label, k=k)
self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)].numpy()[0], x.shape[0]) self.avg_meters["top{}".format(k)].update(metric_dict["top{}".format(k)], x.shape[0])
return metric_dict return metric_dict
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import paddle
__all__ = ['AverageMeter'] __all__ = ['AverageMeter']
...@@ -44,6 +46,8 @@ class AverageMeter(object): ...@@ -44,6 +46,8 @@ class AverageMeter(object):
@property @property
def avg_info(self): def avg_info(self):
if isinstance(self.avg, paddle.Tensor):
self.avg = self.avg.numpy()[0]
return "{}: {:.5f}".format(self.name, self.avg) return "{}: {:.5f}".format(self.name, self.avg)
@property @property
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册