提交 bb19c1f7 编写于 作者: Z zhangbo9674 提交者: Tingquan Gao

fix eval bug

上级 b2956c1b
......@@ -56,15 +56,30 @@ def classification_eval(engine, epoch_id=0):
batch[0] = paddle.to_tensor(batch[0]).astype("float32")
if not engine.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input
out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
if engine.amp:
amp_level = 'O1'
if engine.config['AMP']['use_pure_fp16'] is True:
amp_level = 'O2'
with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
else:
out = engine.model(batch[0])
# calc loss
if engine.eval_loss_func is not None:
loss_dict = engine.eval_loss_func(out, batch[1])
for key in loss_dict:
if key not in output_info:
output_info[key] = AverageMeter(key, '7.5f')
output_info[key].update(loss_dict[key].numpy()[0], batch_size)
# just for DistributedBatchSampler issue: repeat sampling
current_samples = batch_size * paddle.distributed.get_world_size()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册