Unverified commit 275945df, authored by gaotingquan

fix: compatible with Paddle 2.2, 2.3, and develop.

Parent 59a3dcfc
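
The core of the compatibility fix is collapsing paddle.__version__ into a single integer (for example "2.2.2" becomes 222 and "2.3.0" becomes 230, while develop builds report a zero version and map to 0), so the engine can choose between the pre-2.3 and post-2.3 calling conventions of paddle.amp.decorate. A minimal standalone sketch of that normalization, assuming a plain three-part numeric version string, is shown below; the helper name is illustrative and not part of the patch.

# Sketch of the version normalization used in this commit; the function name
# is hypothetical. Assumes a numeric "major.minor.patch" version string.
def normalize_paddle_version(version: str) -> int:
    parts = version.split(".")[:3]
    return sum(int(x) * 10**(2 - i) for i, x in enumerate(parts))

print(normalize_paddle_version("2.2.2"))  # 222 -> old (pre-2.3) code path
print(normalize_paddle_version("2.3.0"))  # 230 -> new code path
print(normalize_paddle_version("0.0.0"))  # 0   -> develop build, treated as new

Note that this weighting only distinguishes versions cleanly while each component stays below 10; it is a pragmatic check rather than a general version comparison.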
@@ -98,23 +98,6 @@ class Engine(object):
         logger.info('train with paddle {} and device {}'.format(
             paddle.__version__, self.device))
-        # AMP training and evaluating
-        self.amp = "AMP" in self.config and self.config["AMP"] is not None
-        if self.amp:
-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-        else:
-            self.scale_loss = 1.0
-            self.use_dynamic_loss_scaling = False
-        if self.amp:
-            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
-            if paddle.is_compiled_with_cuda():
-                AMP_RELATED_FLAGS_SETTING.update({
-                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
-                })
-            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -228,27 +211,77 @@ class Engine(object):
                 len(self.train_dataloader),
                 [self.model, self.train_loss_func])
+        # AMP training and evaluating
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        self.amp_eval = False
         # for amp
         if self.amp:
+            AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
+            if paddle.is_compiled_with_cuda():
+                AMP_RELATED_FLAGS_SETTING.update({
+                    'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+                })
+            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
+            self.use_dynamic_loss_scaling = self.config["AMP"].get(
+                "use_dynamic_loss_scaling", False)
             self.scaler = paddle.amp.GradScaler(
                 init_loss_scaling=self.scale_loss,
                 use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-            amp_level = self.config['AMP'].get("level", "O1")
-            if amp_level not in ["O1", "O2"]:
+            self.amp_level = self.config['AMP'].get("level", "O1")
+            if self.amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimization level of AMP only supports 'O1' and 'O2'. The level has been set to 'O1'."
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
-                amp_level = "O1"
+                self.amp_level = "O1"
+            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            # TODO(gaotingquan): Paddle does not yet support FP32 evaluation when training with AMP O2
+            if self.config["Global"].get(
+                    "eval_during_train",
+                    True) and self.amp_level == "O2" and not self.amp_eval:
+                msg = "PaddlePaddle currently only supports FP16 evaluation when training with AMP O2."
+                logger.warning(msg)
+                self.config["AMP"]["use_fp16_test"] = True
+                self.amp_eval = True
+            # TODO(gaotingquan): to be compatible with Paddle 2.2, 2.3, develop, and so on
+            paddle_version = sum([
+                int(x) * 10**(2 - i)
+                for i, x in enumerate(paddle.__version__.split(".")[:3])
+            ])
+            # paddle version < 2.3.0 and not develop
+            if paddle_version < 230 and paddle_version != 0:
+                if self.mode == "train":
+                    self.model, self.optimizer = paddle.amp.decorate(
+                        models=self.model,
+                        optimizers=self.optimizer,
+                        level=self.amp_level,
+                        save_dtype='float32')
+                elif self.amp_eval:
+                    if self.amp_level == "O2":
+                        msg = "The installed PaddlePaddle does not support FP16 evaluation in AMP O2. Please use PaddlePaddle >= 2.3.0. FP32 evaluation will be used instead; please make sure the Eval dataset's output_fp16 is 'False'."
+                        logger.warning(msg)
+                        self.amp_eval = False
+                    else:
+                        self.model, self.optimizer = paddle.amp.decorate(
+                            models=self.model,
+                            level=self.amp_level,
+                            save_dtype='float32')
-            self.model = paddle.amp.decorate(
-                models=self.model, level=amp_level, save_dtype='float32')
-            # TODO(gaotingquan): to compatible with Paddle develop and 2.2
-            if isinstance(self.model, tuple):
-                self.model = self.model[0]
+            # paddle version >= 2.3.0 or develop
+            else:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
                 self.train_loss_func = paddle.amp.decorate(
                     models=self.train_loss_func,
-                    level=amp_level,
+                    level=self.amp_level,
                     save_dtype='float32')
         # for distributed
...
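
For context on how the objects configured above are consumed, the training loop typically wraps the forward pass in paddle.amp.auto_cast and routes the backward pass through the GradScaler built here. The sketch below is not part of this diff; it is a simplified, illustrative training step following the documented paddle.amp pattern, and the function and variable names are hypothetical.

import paddle

# Illustrative training step (not from this commit): shows how the GradScaler
# and amp_level prepared in Engine.__init__ are typically used together.
def train_step(model, loss_func, optimizer, scaler, amp_level, image, label):
    with paddle.amp.auto_cast(level=amp_level):
        out = model(image)                  # forward pass runs in mixed precision
        loss = loss_func(out, label)
    scaled = scaler.scale(loss)             # scale the loss to avoid FP16 underflow
    scaled.backward()
    scaler.minimize(optimizer, scaled)      # unscale gradients and apply the update
    optimizer.clear_grad()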
@@ -32,15 +32,6 @@ def classification_eval(engine, epoch_id=0):
     }
     print_batch_step = engine.config["Global"]["print_batch_step"]
-    if engine.amp:
-        amp_level = engine.config['AMP'].get("level", "O1").upper()
-        if amp_level == "O2" and engine.config["AMP"].get("use_fp16_test",
-                                                          False):
-            engine.config["AMP"]["use_fp16_test"] = True
-            msg = "Only support FP16 evaluation when AMP O2 is enabled."
-            logger.warning(msg)
-        amp_eval = engine.config["AMP"].get("use_fp16_test", False)
     metric_key = None
     tic = time.time()
     accum_samples = 0
@@ -67,12 +58,12 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
             # image input
-            if engine.amp and amp_eval:
+            if engine.amp and engine.amp_eval:
                 with paddle.amp.auto_cast(
                         custom_black_list={
                             "flatten_contiguous_range", "greater_than"
                         },
-                        level=amp_level):
+                        level=engine.amp_level):
                     out = engine.model(batch[0])
             else:
                 out = engine.model(batch[0])
@@ -120,12 +111,12 @@ def classification_eval(engine, epoch_id=0):
             # calc loss
             if engine.eval_loss_func is not None:
-                if engine.amp and amp_eval:
+                if engine.amp and engine.amp_eval:
                     with paddle.amp.auto_cast(
                             custom_black_list={
                                 "flatten_contiguous_range", "greater_than"
                             },
-                            level=amp_level):
+                            level=engine.amp_level):
                         loss_dict = engine.eval_loss_func(preds, labels)
                 else:
                     loss_dict = engine.eval_loss_func(preds, labels)
...
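
For reference, the configuration keys read by the code in this commit can be summarized as the following Python dict; only the key names are taken from the diff, and the values are illustrative placeholders rather than recommended settings.

# Keys consumed by the AMP-related code above (values are examples only).
config = {
    "Global": {
        "eval_during_train": True,        # with level "O2", forces use_fp16_test to True
        "print_batch_step": 10,
    },
    "AMP": {
        "scale_loss": 128.0,              # init_loss_scaling for paddle.amp.GradScaler
        "use_dynamic_loss_scaling": True, # dynamic loss scaling switch for the scaler
        "level": "O1",                    # "O1" or "O2"; anything else falls back to "O1"
        "use_fp16_test": False,           # becomes engine.amp_eval (FP16 evaluation)
    },
}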