diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 758502ae968fb126764cd2218aad3d60c20fc611..3c75b59a8f9bbc27d36d4fe1fa7017ffe1e87483 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -40,6 +40,17 @@ def classification_eval(engine, epoch_id=0): dataset) if not engine.use_dali else engine.eval_dataloader.size max_iter = len(engine.eval_dataloader) - 1 if platform.system( ) == "Windows" else len(engine.eval_dataloader) + + # print("========================fp16 layer") + # for layer in engine.model.sublayers(include_self=True): + # print(type(layer), layer._dtype) + + # 用fp32做eval + engine.model.to(dtype='float32') + # print("========================to fp32 layer") + # for layer in engine.model.sublayers(include_self=True): + # print(type(layer), layer._dtype) + for iter_id, batch in enumerate(engine.eval_dataloader): if iter_id >= max_iter: break @@ -56,7 +67,8 @@ def classification_eval(engine, epoch_id=0): batch[0] = paddle.to_tensor(batch[0]).astype("float32") if not engine.config["Global"].get("use_multilabel", False): batch[1] = batch[1].reshape([-1, 1]).astype("int64") - + + ''' # image input if engine.amp: amp_level = 'O1' @@ -80,6 +92,19 @@ def classification_eval(engine, epoch_id=0): if key not in output_info: output_info[key] = AverageMeter(key, '7.5f') output_info[key].update(loss_dict[key].numpy()[0], batch_size) + ''' + + #======================================================== + out = engine.model(batch[0]) + + # calc loss + if engine.eval_loss_func is not None: + loss_dict = engine.eval_loss_func(out, batch[1]) + for key in loss_dict: + if key not in output_info: + output_info[key] = AverageMeter(key, '7.5f') + output_info[key].update(loss_dict[key].numpy()[0], batch_size) + #======================================================== # just for DistributedBatchSampler issue: repeat sampling current_samples = batch_size * paddle.distributed.get_world_size() @@ -151,6 +176,16 @@ def classification_eval(engine, epoch_id=0): len(engine.eval_dataloader), metric_msg, time_msg, ips_msg)) tic = time.time() + + #如果是amp-o2做eval后再将模型转回amp-o2的模式 + if engine.amp: + if engine.config['AMP']['use_pure_fp16'] is True: + paddle.fluid.dygraph.amp.auto_cast.pure_fp16_initialize([engine.model]) + # print("========================to fp16 layer") + # for layer in engine.model.sublayers(include_self=True): + # print(type(layer), layer._dtype) + # import sys + # sys.exit() if engine.use_dali: engine.eval_dataloader.reset() metric_msg = ", ".join([