Unverified commit 59a3dcfc, authored by gaotingquan

fix: amp eval

Parent fea9522a
@@ -99,8 +99,8 @@ class Engine(object):
             paddle.__version__, self.device))
 
         # AMP training and evaluating
-        self.amp = "AMP" in self.config
-        if self.amp and self.config["AMP"] is not None:
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        if self.amp:
             self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
             self.use_dynamic_loss_scaling = self.config["AMP"].get(
                 "use_dynamic_loss_scaling", False)
@@ -228,7 +228,7 @@ class Engine(object):
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
 
-            # for amp training
+            # for amp
             if self.amp:
                 self.scaler = paddle.amp.GradScaler(
                     init_loss_scaling=self.scale_loss,
@@ -239,12 +239,13 @@ class Engine(object):
                     logger.warning(msg)
                     self.config['AMP']["level"] = "O1"
                     amp_level = "O1"
-                self.model, self.optimizer = paddle.amp.decorate(
-                    models=self.model,
-                    optimizers=self.optimizer,
-                    level=amp_level,
-                    save_dtype='float32')
-                if len(self.train_loss_func.parameters()) > 0:
+                self.model = paddle.amp.decorate(
+                    models=self.model, level=amp_level, save_dtype='float32')
+                # TODO(gaotingquan): to compatible with Paddle develop and 2.2
+                if isinstance(self.model, tuple):
+                    self.model = self.model[0]
+                if self.mode == "train" and len(self.train_loss_func.parameters(
+                )) > 0:
                     self.train_loss_func = paddle.amp.decorate(
                         models=self.train_loss_func,
                         level=amp_level,
...
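
The change above removes the optimizer from paddle.amp.decorate, so decoration no longer depends on self.optimizer, which only exists in train mode. As a rough illustration of the resulting call pattern, here is a minimal standalone sketch, not part of this commit; the toy nn.Linear model, the O2 level, and the CUDA assumption are illustrative only:

import paddle
from paddle import nn

# Minimal sketch of the decorate pattern used above (assumes a CUDA build of
# Paddle >= 2.2; the Linear model and the O2 level are placeholders).
model = nn.Linear(8, 2)
model = paddle.amp.decorate(models=model, level="O2", save_dtype="float32")

# Some Paddle versions hand back a tuple from decorate; unwrap it defensively,
# mirroring the isinstance check introduced by this commit.
if isinstance(model, tuple):
    model = model[0]
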
@@ -32,6 +32,15 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]
 
+    if engine.amp:
+        amp_level = engine.config['AMP'].get("level", "O1").upper()
+        if amp_level == "O2" and engine.config["AMP"].get("use_fp16_test",
+                                                          False):
+            engine.config["AMP"]["use_fp16_test"] = True
+            msg = "Only support FP16 evaluation when AMP O2 is enabled."
+            logger.warning(msg)
+        amp_eval = engine.config["AMP"].get("use_fp16_test", False)
+
     metric_key = None
     tic = time.time()
     accum_samples = 0
@@ -58,15 +67,7 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
 
         # image input
-        if engine.amp and (
-                engine.config['AMP'].get("level", "O1").upper() == "O2" or
-                engine.config["AMP"].get("use_fp16_test", False)):
-            amp_level = engine.config['AMP'].get("level", "O1").upper()
-
-            if amp_level == "O2":
-                msg = "Only support FP16 evaluation when AMP O2 is enabled."
-                logger.warning(msg)
-
+        if engine.amp and amp_eval:
             with paddle.amp.auto_cast(
                     custom_black_list={
                         "flatten_contiguous_range", "greater_than"
@@ -119,8 +120,7 @@ def classification_eval(engine, epoch_id=0):
 
         # calc loss
         if engine.eval_loss_func is not None:
-            if engine.amp and engine.config["AMP"].get("use_fp16_test", False):
-                amp_level = engine.config['AMP'].get("level", "O1").upper()
+            if engine.amp and amp_eval:
                 with paddle.amp.auto_cast(
                         custom_black_list={
                             "flatten_contiguous_range", "greater_than"
...
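
For reference, the new amp_eval flag simply gates whether the evaluation forward pass runs under paddle.amp.auto_cast. The following is a minimal standalone sketch of that gating, not part of this commit; the fp16_eval_step helper, the toy model, and the random input are made up for illustration and assume a CUDA build of Paddle:

import paddle
from paddle import nn

def fp16_eval_step(model, images, amp_eval, amp_level="O2"):
    # Run one evaluation forward pass, casting to FP16 only when amp_eval is
    # enabled; the custom_black_list keeps the same ops in FP32 as the
    # evaluation code above.
    if amp_eval:
        with paddle.amp.auto_cast(
                custom_black_list={
                    "flatten_contiguous_range", "greater_than"
                },
                level=amp_level):
            return model(images)
    return model(images)

# Illustrative usage with a toy model and random input.
model = nn.Linear(8, 2)
model.eval()
logits = fp16_eval_step(model, paddle.randn([4, 8]), amp_eval=True)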