提交 b761325f 编写于 作者: G gaotingquan 提交者: Tingquan Gao

fix: fp32 eval by default when enable amp

If you want to eval by fp16 when enable amp, please set Amp.use_fp16_test=True, False by default.
上级 f3af5819
...@@ -58,7 +58,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -58,7 +58,7 @@ def classification_eval(engine, epoch_id=0):
batch[1] = batch[1].reshape([-1, 1]).astype("int64") batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input # image input
if engine.amp: if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
amp_level = engine.config['AMP'].get("level", "O1").upper() amp_level = engine.config['AMP'].get("level", "O1").upper()
with paddle.amp.auto_cast( with paddle.amp.auto_cast(
custom_black_list={ custom_black_list={
......
...@@ -161,12 +161,13 @@ def main(args): ...@@ -161,12 +161,13 @@ def main(args):
# load pretrained models or checkpoints # load pretrained models or checkpoints
init_model(global_config, train_prog, exe) init_model(global_config, train_prog, exe)
if 'AMP' in config and config.AMP.get("level", "O1") == "O2": if 'AMP' in config:
optimizer.amp_init( optimizer.amp_init(
device, device,
scope=paddle.static.global_scope(), scope=paddle.static.global_scope(),
test_program=eval_prog test_program=eval_prog
if global_config["eval_during_train"] else None) if global_config["eval_during_train"] else None,
use_fp16_test=config["AMP"].get("use_fp16_test", False))
if not global_config.get("is_distributed", True): if not global_config.get("is_distributed", True):
compiled_train_prog = program.compile( compiled_train_prog = program.compile(
...@@ -182,7 +183,7 @@ def main(args): ...@@ -182,7 +183,7 @@ def main(args):
program.run(train_dataloader, exe, compiled_train_prog, train_feeds, program.run(train_dataloader, exe, compiled_train_prog, train_feeds,
train_fetchs, epoch_id, 'train', config, vdl_writer, train_fetchs, epoch_id, 'train', config, vdl_writer,
lr_scheduler, args.profiler_options) lr_scheduler, args.profiler_options)
# 2. evaate with eval dataset # 2. evaluate with eval dataset
if global_config["eval_during_train"] and epoch_id % global_config[ if global_config["eval_during_train"] and epoch_id % global_config[
"eval_interval"] == 0: "eval_interval"] == 0:
top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog, top1_acc = program.run(eval_dataloader, exe, compiled_eval_prog,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册