Unverified commit b87f08f5, authored by haoyuying, committed by GitHub

Fix the OOM problem during evaluation (#1580)

Parent 4a0cfef6
@@ -301,24 +301,26 @@ class Trainer(object):
             collate_fn=collate_fn)
 
         self.model.eval()
         avg_loss = num_samples = 0
         sum_metrics = defaultdict(int)
         avg_metrics = defaultdict(int)
 
         with logger.processing('Evaluation on validation dataset'):
-            for batch_idx, batch in enumerate(loader):
-                result = self.validation_step(batch, batch_idx)
-                loss = result.get('loss', None)
-                metrics = result.get('metrics', {})
-                bs = batch[0].shape[0]
-                num_samples += bs
+            with paddle.no_grad():
+                for batch_idx, batch in enumerate(loader):
+                    result = self.validation_step(batch, batch_idx)
+                    loss = result.get('loss', None)
+                    metrics = result.get('metrics', {})
+                    bs = batch[0].shape[0]
+                    num_samples += bs
 
-                if loss:
-                    avg_loss += loss.numpy()[0] * bs
+                    if loss:
+                        avg_loss += loss.numpy()[0] * bs
 
-                for metric, value in metrics.items():
-                    sum_metrics[metric] += value * bs
+                    for metric, value in metrics.items():
+                        sum_metrics[metric] += value * bs
 
         # print avg metrics and loss
         print_msg = '[Evaluation result]'
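The change is a scoping fix: the validation loop now runs inside paddle.no_grad(), so Paddle stops recording the autograd graph during evaluation and intermediate activations can be freed after every batch instead of piling up until memory runs out. Below is a minimal, self-contained sketch of the same pattern; the toy model, dataset, and loader are hypothetical stand-ins for the Trainer's real members, not part of this commit.

import paddle

# Toy stand-ins (hypothetical) for the Trainer's model and data loader.
model = paddle.nn.Linear(16, 2)
features = paddle.randn([64, 16])
labels = paddle.randint(0, 2, [64])
loader = paddle.io.DataLoader(
    paddle.io.TensorDataset([features, labels]), batch_size=8)

model.eval()
avg_loss = num_samples = 0
with paddle.no_grad():  # nothing inside this block records gradients
    for batch_idx, (x, y) in enumerate(loader):
        loss = paddle.nn.functional.cross_entropy(model(x), y)
        bs = x.shape[0]
        # float() extracts a plain Python number, so no tensor (and no
        # attached graph) survives across iterations.
        avg_loss += float(loss) * bs
        num_samples += bs
print('[Evaluation result] avg_loss={:.4f}'.format(avg_loss / num_samples))

paddle.no_grad can also be used as a decorator, so decorating the whole evaluate method would be an equivalent alternative to the nested with block chosen here.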