Commit 8b218b01 authored by gaotingquan, committed by Tingquan Gao

refactor amp auto_cast context manager & loss scaler

Parent f884f288
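Overview: call sites used to branch on `engine.amp` and open `paddle.amp.auto_cast` contexts inline; this commit moves that decision into two helpers, `AutoCast` and `build_scaler` (the new ppcls/utils/amp.py at the end of this diff), so the FP32 and AMP paths share one unconditional code path. A minimal sketch of the pattern change, not repository code (single optimizer for brevity; the real loops iterate a list of optimizers):

    # Before: every call site branched on engine.amp.
    if engine.amp:
        with paddle.amp.auto_cast(level=engine.amp_level):
            loss = engine.train_loss_func(engine.model(batch[0]), batch[1])
        scaled = engine.scaler.scale(loss)
        scaled.backward()
        engine.scaler.minimize(engine.optimizer, scaled)
    else:
        loss = engine.train_loss_func(engine.model(batch[0]), batch[1])
        loss.backward()
        engine.optimizer.step()

    # After: one path; auto_cast yields a real autocast context or a
    # nullcontext, and the scaler is a real GradScaler or a no-op stand-in.
    with engine.auto_cast(is_eval=False):
        loss = engine.train_loss_func(engine.model(batch[0]), batch[1])
    scaled = engine.scaler.scale(loss)
    scaled.backward()
    engine.scaler.minimize(engine.optimizer, scaled)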
@@ -33,6 +33,7 @@ from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
+from ppcls.utils.amp import AutoCast, build_scaler
 from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 from ppcls.utils.save_load import init_model
@@ -459,12 +460,7 @@ class Engine(object):
             if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                 batch_tensor = paddle.to_tensor(batch_data)
-                if self.amp and self.amp_eval:
-                    with paddle.amp.auto_cast(
-                            level=self.amp_level,
-                            use_promote=self.use_promote):
-                        out = self.model(batch_tensor)
-                else:
+                with self.auto_cast(is_eval=True):
                     out = self.model(batch_tensor)
                 if isinstance(out, list):
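Note: `auto_cast(is_eval=True)` only produces a real casting context when FP16 evaluation is enabled; otherwise it degrades to a no-op. A small illustration, assuming the `AutoCast` class added in ppcls/utils/amp.py below:

    cast = AutoCast(use_amp=True, amp_level="O1", amp_eval=False)
    ctx_train = cast(is_eval=False)  # paddle.amp.auto_cast context: training is cast
    ctx_eval = cast(is_eval=True)    # contextlib.nullcontext(): eval stays FP32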
@@ -528,10 +524,13 @@ class Engine(object):
         )

     def _init_amp(self):
-        self.amp = "AMP" in self.config and self.config["AMP"] is not None
-        self.amp_eval = False
-        if self.amp:
+        amp_config = self.config.get("AMP", None)
+        use_amp = True if amp_config else False
+
+        if not use_amp:
+            self.auto_cast = AutoCast(use_amp)
+            self.scaler = build_scaler(use_amp)
+        else:
             AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
             if paddle.is_compiled_with_cuda():
                 AMP_RELATED_FLAGS_SETTING.update({
@@ -539,42 +538,46 @@ class Engine(object):
                 })
             paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
-            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
-            self.use_dynamic_loss_scaling = self.config["AMP"].get(
-                "use_dynamic_loss_scaling", False)
-            self.scaler = paddle.amp.GradScaler(
-                init_loss_scaling=self.scale_loss,
-                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
-            self.use_promote = self.config['AMP'].get("use_promote", False)
-            self.amp_level = self.config['AMP'].get("level", "O1")
-            if self.amp_level not in ["O1", "O2"]:
+            use_promote = amp_config.get("use_promote", False)
+            amp_level = amp_config.get("level", "O1")
+            if amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                 logger.warning(msg)
-                self.config['AMP']["level"] = "O1"
-                self.amp_level = "O1"
+                amp_level = amp_config["level"] = "O1"

-            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            amp_eval = self.config["AMP"].get("use_fp16_test", False)
             # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
             if self.mode == "train" and self.config["Global"].get(
                     "eval_during_train",
-                    True) and self.amp_level == "O2" and self.amp_eval == False:
+                    True) and amp_level == "O2" and amp_eval == False:
                 msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
                 logger.warning(msg)
                 self.config["AMP"]["use_fp16_test"] = True
-                self.amp_eval = True
+                amp_eval = True

+            self.auto_cast = AutoCast(
+                use_amp,
+                amp_level=amp_level,
+                use_promote=use_promote,
+                amp_eval=amp_eval)
+
+            scale_loss = amp_config.get("scale_loss", 1.0)
+            use_dynamic_loss_scaling = amp_config.get(
+                "use_dynamic_loss_scaling", False)
+            self.scaler = build_scaler(
+                use_amp,
+                scale_loss=scale_loss,
+                use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+
             if self.mode == "train":
                 self.model, self.optimizer = paddle.amp.decorate(
                     models=self.model,
                     optimizers=self.optimizer,
-                    level=self.amp_level,
+                    level=amp_level,
                     save_dtype='float32')
             elif self.amp_eval:
                 self.model = paddle.amp.decorate(
-                    models=self.model,
-                    level=self.amp_level,
-                    save_dtype='float32')
+                    models=self.model, level=amp_level, save_dtype='float32')

             if self.mode == "train" and len(self.train_loss_func.parameters(
             )) > 0:
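Note: `_init_amp` now always defines both `self.auto_cast` and `self.scaler`, whether or not the config carries an `AMP` section, so downstream code never re-checks an AMP flag. A sketch of that invariant, with `amp_config` standing in for `config.get("AMP", None)`:

    from ppcls.utils.amp import AutoCast, build_scaler

    amp_config = None                  # no AMP section in the config
    use_amp = True if amp_config else False

    auto_cast = AutoCast(use_amp)      # factory that returns nullcontext objects
    scaler = build_scaler(use_amp)     # no-op scaler
    with auto_cast(is_eval=False):
        pass                           # forward pass runs in plain FP32 here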
@@ -55,10 +55,7 @@ def classification_eval(engine, epoch_id=0):
             batch[1] = batch[1].reshape([-1, 1]).astype("int64")
         # image input
-        if engine.amp and engine.amp_eval:
-            with paddle.amp.auto_cast(level=engine.amp_level):
-                out = engine.model(batch[0])
-        else:
+        with engine.auto_cast(is_eval=True):
             out = engine.model(batch[0])

         # just for DistributedBatchSampler issue: repeat sampling
@@ -109,10 +106,7 @@
         # calc loss
         if engine.eval_loss_func is not None:
-            if engine.amp and engine.amp_eval:
-                with paddle.amp.auto_cast(level=engine.amp_level):
-                    loss_dict = engine.eval_loss_func(preds, labels)
-            else:
+            with engine.auto_cast(is_eval=True):
                 loss_dict = engine.eval_loss_func(preds, labels)

             for key in loss_dict:
@@ -136,10 +136,7 @@ def compute_feature(engine, name="gallery"):
         if len(batch) >= 3:
             has_camera = True
             batch[2] = batch[2].reshape([-1, 1]).astype("int64")
-        if engine.amp and engine.amp_eval:
-            with paddle.amp.auto_cast(level=engine.amp_level):
-                out = engine.model(batch[0])
-        else:
+        with engine.auto_cast(is_eval=True):
             out = engine.model(batch[0])
         if "Student" in out:
             out = out["Student"]
@@ -48,12 +48,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
         engine.global_step += 1

         # image input
-        if engine.amp:
-            amp_level = engine.config["AMP"].get("level", "O1").upper()
-            with paddle.amp.auto_cast(level=amp_level):
-                out = forward(engine, batch)
-                loss_dict = engine.train_loss_func(out, batch[1])
-        else:
+        with engine.auto_cast(is_eval=False):
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
@@ -61,17 +56,12 @@ def train_epoch(engine, epoch_id, print_batch_step):
         loss = loss_dict["loss"] / engine.update_freq

         # backward & step opt
-        if engine.amp:
-            scaled = engine.scaler.scale(loss)
-            scaled.backward()
-            if (iter_id + 1) % engine.update_freq == 0:
-                for i in range(len(engine.optimizer)):
-                    engine.scaler.minimize(engine.optimizer[i], scaled)
-        else:
-            loss.backward()
-            if (iter_id + 1) % engine.update_freq == 0:
-                for i in range(len(engine.optimizer)):
-                    engine.optimizer[i].step()
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+        if (iter_id + 1) % engine.update_freq == 0:
+            for i in range(len(engine.optimizer)):
+                # optimizer.step() with auto amp
+                engine.scaler.minimize(engine.optimizer[i], scaled)

         if (iter_id + 1) % engine.update_freq == 0:
             # clear grad
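Note: the now-unconditional `scaler.scale(loss)` path stays correct in FP32 because `build_scaler(False)` returns a no-op scaler. A sketch of the equivalence, assuming `loss` is a Paddle tensor and `opt` a Paddle optimizer:

    scaler = build_scaler(use_amp=False)
    scaled = scaler.scale(loss)         # returns loss unchanged
    scaled.backward()                   # same as loss.backward()
    scaler.minimize(opt, scaled)        # same as opt.step()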
@@ -62,13 +62,7 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
         inputs = paddle.concat([inputs_x, inputs_u_w, inputs_u_s], axis=0)

         # image input
-        if engine.amp:
-            amp_level = engine.config['AMP'].get("level", "O1").upper()
-            with paddle.amp.auto_cast(level=amp_level):
-                loss_dict, logits_label = get_loss(
-                    engine, inputs, batch_size_label, temperture, threshold,
-                    targets_x)
-        else:
+        with engine.auto_cast(is_eval=False):
             loss_dict, logits_label = get_loss(engine, inputs,
                                                batch_size_label, temperture,
                                                threshold, targets_x)
@@ -77,16 +71,11 @@ def train_epoch_fixmatch(engine, epoch_id, print_batch_step):
         loss = loss_dict["loss"]

         # backward & step opt
-        if engine.amp:
-            scaled = engine.scaler.scale(loss)
-            scaled.backward()
-            for i in range(len(engine.optimizer)):
-                engine.scaler.minimize(engine.optimizer[i], scaled)
-        else:
-            loss.backward()
-            for i in range(len(engine.optimizer)):
-                engine.optimizer[i].step()
+        scaled = engine.scaler.scale(loss)
+        scaled.backward()
+        for i in range(len(engine.optimizer)):
+            engine.scaler.minimize(engine.optimizer[i], scaled)

         # step lr(by step)
         for i in range(len(engine.lr_sch)):
@@ -189,26 +189,19 @@ def get_meta_data(meta_dataloader_iter, num_domain):
 def forward(engine, batch, loss_func):
     batch_info = defaultdict()
     batch_info = {"label": batch[1], "domain": batch[2]}
-    if engine.amp:
-        amp_level = engine.config["AMP"].get("level", "O1").upper()
-        with paddle.amp.auto_cast(level=amp_level):
-            out = engine.model(batch[0], batch[1])
-            loss_dict = loss_func(out, batch_info)
-    else:
+    with engine.auto_cast(is_eval=False):
         out = engine.model(batch[0], batch[1])
         loss_dict = loss_func(out, batch_info)
     return out, loss_dict


 def backward(engine, loss, optimizer):
     optimizer.clear_grad()
-    if engine.amp:
-        scaled = engine.scaler.scale(loss)
-        scaled.backward()
-        engine.scaler.minimize(optimizer, scaled)
-    else:
-        loss.backward()
-        optimizer.step()
+    scaled = engine.scaler.scale(loss)
+    scaled.backward()
+    engine.scaler.minimize(optimizer, scaled)
     for name, layer in engine.model.backbone.named_sublayers():
         if "gate" == name.split('.')[-1]:
             layer.clip_gate()
ppcls/utils/amp.py (new file)
+from functools import partial
+import contextlib
+
+import paddle
+
+
+class AutoCast:
+    def __init__(self,
+                 use_amp=False,
+                 amp_level="O1",
+                 use_promote=False,
+                 amp_eval=False):
+        self.use_amp = use_amp
+        self.amp_eval = amp_eval
+
+        if self.use_amp:
+            self.cast_context = partial(paddle.amp.auto_cast, level=amp_level)
+
+    def __call__(self, is_eval=False):
+        if self.use_amp:
+            # not is_eval: cast for all training
+            # is_eval and self.amp_eval: cast for evaluation only when amp_eval is True
+            if not is_eval or (is_eval and self.amp_eval):
+                return self.cast_context()
+
+        return contextlib.nullcontext()
+
+
+def build_scaler(use_amp=False, scale_loss=1.0,
+                 use_dynamic_loss_scaling=False):
+    class Foo:
+        def __init__(self):
+            pass
+
+        def scale(self, loss):
+            return loss
+
+        def step(self, optimizer):
+            optimizer.step()
+
+        def update(self):
+            return
+
+        def minimize(self, optimizer, loss):
+            optimizer.step()
+
+    if use_amp:
+        return paddle.amp.GradScaler(
+            init_loss_scaling=scale_loss,
+            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+
+    return Foo()
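The `Foo` class is a null-object stand-in exposing the `GradScaler` methods the engine calls (`scale`, `step`, `update`, `minimize`), which is what lets the loops above drop their `if engine.amp:` branches. A hedged usage sketch; `linear`, `sgd`, and `x` are illustrative placeholders, not repository code:

    import paddle
    from ppcls.utils.amp import AutoCast, build_scaler

    linear = paddle.nn.Linear(4, 2)
    sgd = paddle.optimizer.SGD(parameters=linear.parameters())
    auto_cast = AutoCast(use_amp=True, amp_level="O1")
    scaler = build_scaler(use_amp=True, scale_loss=1024.0,
                          use_dynamic_loss_scaling=True)

    x = paddle.rand([8, 4])
    with auto_cast(is_eval=False):    # opens paddle.amp.auto_cast(level="O1")
        loss = linear(x).mean()
    scaled = scaler.scale(loss)       # GradScaler scales the loss
    scaled.backward()
    scaler.minimize(sgd, scaled)      # unscales gradients and steps the optimizer
    sgd.clear_grad()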