From bb19c1f7a62f3812c52ee435e3dae7062041206f Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Fri, 7 Jan 2022 06:41:46 +0000
Subject: [PATCH] fix eval bug

---
 ppcls/engine/evaluation/classification.py | 31 +++++++++++++++++------
 1 file changed, 23 insertions(+), 8 deletions(-)

diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 98bad639..758502ae 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -56,15 +56,30 @@ def classification_eval(engine, epoch_id=0):
         batch[0] = paddle.to_tensor(batch[0]).astype("float32")
         if not engine.config["Global"].get("use_multilabel", False):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+
         # image input
-        out = engine.model(batch[0])
-        # calc loss
-        if engine.eval_loss_func is not None:
-            loss_dict = engine.eval_loss_func(out, batch[1])
-            for key in loss_dict:
-                if key not in output_info:
-                    output_info[key] = AverageMeter(key, '7.5f')
-                output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+        if engine.amp:
+            amp_level = 'O1'
+            if engine.config['AMP']['use_pure_fp16'] is True:
+                amp_level = 'O2'
+            with paddle.amp.auto_cast(custom_black_list={"flatten_contiguous_range", "greater_than"}, level=amp_level):
+                out = engine.model(batch[0])
+                # calc loss
+                if engine.eval_loss_func is not None:
+                    loss_dict = engine.eval_loss_func(out, batch[1])
+                    for key in loss_dict:
+                        if key not in output_info:
+                            output_info[key] = AverageMeter(key, '7.5f')
+                        output_info[key].update(loss_dict[key].numpy()[0], batch_size)
+        else:
+            out = engine.model(batch[0])
+            # calc loss
+            if engine.eval_loss_func is not None:
+                loss_dict = engine.eval_loss_func(out, batch[1])
+                for key in loss_dict:
+                    if key not in output_info:
+                        output_info[key] = AverageMeter(key, '7.5f')
+                    output_info[key].update(loss_dict[key].numpy()[0], batch_size)
 
         # just for DistributedBatchSampler issue: repeat sampling
         current_samples = batch_size * paddle.distributed.get_world_size()
--
GitLab
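
Note: a minimal standalone sketch of the paddle.amp.auto_cast pattern the hunk introduces, i.e. running the eval forward pass under the same AMP level as training (O1 mixed precision, or O2 when use_pure_fp16 is set). This is not part of the patch or of PaddleClas; the toy model, input tensor, and use_pure_fp16 flag are illustrative assumptions.

    # Sketch only: toy model and batch stand in for engine.model and batch[0].
    import paddle

    model = paddle.nn.Linear(16, 4)             # hypothetical toy model
    x = paddle.rand([8, 16], dtype="float32")   # hypothetical toy batch

    use_pure_fp16 = False                       # mirrors engine.config['AMP']['use_pure_fp16']
    amp_level = 'O2' if use_pure_fp16 else 'O1'

    # Same black list and level selection as the patched eval loop:
    # black-listed ops stay in fp32, other ops may run in fp16.
    with paddle.amp.auto_cast(
            custom_black_list={"flatten_contiguous_range", "greater_than"},
            level=amp_level):
        out = model(x)

    print(out.dtype)  # typically paddle.float16 under O1 for a matmul-based layer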