Unverified · Commit e6feb68b · Authored by: Wei Shengyu · Committed by: GitHub

Merge pull request #1824 from TingquanGao/dev/spt_amp_eval

fix: fp32 eval by default when enable amp
@@ -34,10 +34,10 @@ Loss:
 # mixed precision training
 AMP:
   scale_loss: 128.0
   use_dynamic_loss_scaling: True
   # O2: pure fp16
   level: O2
 Optimizer:
   name: Momentum
...
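The AMP block above sets a fixed initial loss scale of 128.0, enables dynamic loss scaling, and selects level O2 (pure FP16 parameters). As a rough illustration only, the sketch below shows how such keys are typically wired into Paddle's dynamic-graph AMP APIs; the toy model, optimizer, and data are placeholders and are not code from this commit.

```python
# Minimal sketch (not part of this commit): how the AMP keys above typically
# map onto Paddle's dynamic-graph AMP APIs. The toy model, optimizer, and data
# are placeholders; fp16 needs a GPU/XPU device.
import paddle
import paddle.nn as nn

model = nn.Linear(16, 2)
optimizer = paddle.optimizer.Momentum(parameters=model.parameters())

scaler = paddle.amp.GradScaler(
    init_loss_scaling=128.0,          # AMP.scale_loss
    use_dynamic_loss_scaling=True)    # AMP.use_dynamic_loss_scaling
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level="O2")  # AMP.level: O2, pure fp16

images = paddle.randn([4, 16])
labels = paddle.randint(0, 2, [4])
with paddle.amp.auto_cast(level="O2"):
    loss = nn.functional.cross_entropy(model(images), labels)
scaled = scaler.scale(loss)           # multiply the loss by the loss scale
scaled.backward()
scaler.minimize(optimizer, scaled)    # unscale gradients and apply the update
optimizer.clear_grad()
```

Because O2 keeps the parameters themselves in FP16, evaluation also has to run in FP16, which is what the hunks below enforce.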
@@ -53,13 +53,20 @@ def classification_eval(engine, epoch_id=0):
         ]
         time_info["reader_cost"].update(time.time() - tic)
         batch_size = batch[0].shape[0]
-        batch[0] = paddle.to_tensor(batch[0]).astype("float32")
+        batch[0] = paddle.to_tensor(batch[0])
         if not engine.config["Global"].get("use_multilabel", False):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
         # image input
-        if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
+        if engine.amp and (
+                engine.config['AMP'].get("level", "O1").upper() == "O2" or
+                engine.config["AMP"].get("use_fp16_test", False)):
             amp_level = engine.config['AMP'].get("level", "O1").upper()
+            if amp_level == "O2":
+                msg = "Only support FP16 evaluation when AMP O2 is enabled."
+                logger.warning(msg)
             with paddle.amp.auto_cast(
                     custom_black_list={
                         "flatten_contiguous_range", "greater_than"
...
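The hunk above stops force-casting the eval batch to float32 and widens the condition for entering auto_cast: FP16 evaluation is used either when the user sets use_fp16_test or, unconditionally, when the AMP level is O2; with O1 the default remains FP32 evaluation. A condensed sketch of that decision follows; engine.amp and engine.config are assumed to behave as in the diff, everything else is illustrative.

```python
# Condensed sketch (not the repository code) of the decision the hunk above
# implements.
def eval_uses_fp16(engine) -> bool:
    if not engine.amp:
        return False
    amp_cfg = engine.config["AMP"]
    # O2 keeps pure fp16 parameters, so fp16 evaluation is forced; otherwise
    # evaluation stays fp32 unless the user opts in via use_fp16_test.
    return bool(amp_cfg.get("level", "O1").upper() == "O2" or
                amp_cfg.get("use_fp16_test", False))
```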
@@ -162,12 +162,21 @@ def main(args):
     init_model(global_config, train_prog, exe)
     if 'AMP' in config:
+        if config["AMP"].get("level", "O1").upper() == "O2":
+            use_fp16_test = True
+            msg = "Only support FP16 evaluation when AMP O2 is enabled."
+            logger.warning(msg)
+        elif "use_fp16_test" in config["AMP"]:
+            use_fp16_test = config["AMP"].get("use_fp16_test")
+        else:
+            use_fp16_test = False
         optimizer.amp_init(
             device,
             scope=paddle.static.global_scope(),
             test_program=eval_prog
             if global_config["eval_during_train"] else None,
-            use_fp16_test=config["AMP"].get("use_fp16_test", False))
+            use_fp16_test=use_fp16_test)
     if not global_config.get("is_distributed", True):
         compiled_train_prog = program.compile(
...
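The static-graph path applies the same rule before optimizer.amp_init: O2 forces use_fp16_test to True (with a warning), otherwise an explicit use_fp16_test setting is honoured and FP32 evaluation stays the default, matching the commit title. Below is a small sketch with a hypothetical helper name, plus a few example inputs.

```python
# Sketch (not the repository code): resolving the flag that the hunk above
# passes to optimizer.amp_init(..., use_fp16_test=...). The helper name is
# hypothetical.
def resolve_use_fp16_test(amp_cfg: dict) -> bool:
    if amp_cfg.get("level", "O1").upper() == "O2":
        # Pure fp16 parameters: only fp16 evaluation is supported.
        return True
    # fp32 evaluation by default unless explicitly requested.
    return bool(amp_cfg.get("use_fp16_test", False))

print(resolve_use_fp16_test({"scale_loss": 128.0, "level": "O1"}))    # False
print(resolve_use_fp16_test({"level": "O1", "use_fp16_test": True}))  # True
print(resolve_use_fp16_test({"level": "O2"}))                         # True
```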