未验证 提交 b87f08f5 编写于 作者: H haoyuying 提交者: GitHub

Fix the OOM problem during evaluation (#1580)

上级 4a0cfef6
...@@ -301,24 +301,26 @@ class Trainer(object): ...@@ -301,24 +301,26 @@ class Trainer(object):
collate_fn=collate_fn) collate_fn=collate_fn)
self.model.eval() self.model.eval()
avg_loss = num_samples = 0 avg_loss = num_samples = 0
sum_metrics = defaultdict(int) sum_metrics = defaultdict(int)
avg_metrics = defaultdict(int) avg_metrics = defaultdict(int)
with logger.processing('Evaluation on validation dataset'): with logger.processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(loader): with paddle.no_grad():
result = self.validation_step(batch, batch_idx) for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None) loss = result.get('loss', None)
metrics = result.get('metrics', {}) metrics = result.get('metrics', {})
bs = batch[0].shape[0] bs = batch[0].shape[0]
num_samples += bs num_samples += bs
if loss: if loss:
avg_loss += loss.numpy()[0] * bs avg_loss += loss.numpy()[0] * bs
for metric, value in metrics.items(): for metric, value in metrics.items():
sum_metrics[metric] += value * bs sum_metrics[metric] += value * bs
# print avg metrics and loss # print avg metrics and loss
print_msg = '[Evaluation result]' print_msg = '[Evaluation result]'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册