Commit aa52682c authored by Tingquan Gao

Revert "rm amp code from train and eval & use decorator for amp training"

This reverts commit d3941dc1.
Parent 85e200ed
...
@@ -29,7 +29,6 @@ from ..utils import logger
 from ..utils.save_load import load_dygraph_pretrain
 from .slim import prune_model, quantize_model
 from .distill.afd_attention import LinearTransformStudent, LinearTransformTeacher
-from ..utils.amp import AMPForwardDecorator
 __all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]
...
@@ -56,12 +55,6 @@ def build_model(config, mode="train"):
     # set @to_static for benchmark, skip this by default.
     model = apply_to_static(config, model)
-    if AMPForwardDecorator.amp_level:
-        model = paddle.amp.decorate(
-            models=model,
-            level=AMPForwardDecorator.amp_level,
-            save_dtype='float32')
     return model
...
-import functools
 def clas_forward_decorator(forward_func):
-    @functools.wraps(forward_func)
     def parse_batch_wrapper(model, batch):
         x, label = batch[0], batch[1]
         return forward_func(model, x)
...
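For context, the decorator-based AMP path being removed by this revert routed the model forward through a wrapper so that mixed precision could be switched on without touching the training loop. Below is a minimal sketch of that idea; the name `amp_forward_decorator` and the plain `amp_level` argument are assumptions for illustration, not the project's actual `AMPForwardDecorator`:

```python
import functools

import paddle


def amp_forward_decorator(forward_func, amp_level="O1"):
    """Illustrative wrapper only; the project's AMPForwardDecorator differs in detail."""

    @functools.wraps(forward_func)
    def wrapper(model, x):
        if amp_level:
            # Run the forward pass with eligible ops cast to FP16.
            with paddle.amp.auto_cast(level=amp_level):
                return forward_func(model, x)
        return forward_func(model, x)

    return wrapper
```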
...
@@ -99,6 +99,15 @@ class Engine(object):
             image_file_list.append(image_file)
             if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                 batch_tensor = paddle.to_tensor(batch_data)
+                if self.amp and self.amp_eval:
+                    with paddle.amp.auto_cast(
+                            custom_black_list={
+                                "flatten_contiguous_range", "greater_than"
+                            },
+                            level=self.amp_level):
+                        out = self.model(batch_tensor)
+                else:
-                out = self.model(batch_tensor)
+                    out = self.model(batch_tensor)
                 if isinstance(out, list):
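The restored branch above wraps inference in paddle.amp.auto_cast when FP16 evaluation is enabled. A self-contained sketch of the same pattern, with a toy model and input standing in for self.model and batch_tensor:

```python
import paddle

# Toy stand-ins; in the engine these are self.model and batch_tensor.
model = paddle.nn.Linear(8, 2)
batch_tensor = paddle.randn([4, 8])
amp_eval = True
amp_level = "O1"

if amp_eval:
    # Ops in the custom black list are kept in FP32 even under AMP.
    with paddle.amp.auto_cast(
            custom_black_list={"flatten_contiguous_range", "greater_than"},
            level=amp_level):
        out = model(batch_tensor)
else:
    out = model(batch_tensor)
```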
...
@@ -200,14 +209,10 @@ class Engine(object):
                 self.config["Global"]["pretrained_model"])
     def _init_amp(self):
-        if "AMP" in self.config and self.config["AMP"] is not None:
-            paddle_version = paddle.__version__[:3]
-            # paddle version < 2.3.0 and not develop
-            if paddle_version not in ["2.3", "2.4", "0.0"]:
-                msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
-                logger.error(msg)
-                raise Exception(msg)
+        self.amp = "AMP" in self.config and self.config["AMP"] is not None
+        self.amp_eval = False
+        # for amp
+        if self.amp:
             AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
             if paddle.is_compiled_with_cuda():
                 AMP_RELATED_FLAGS_SETTING.update({
...
@@ -215,26 +220,51 @@
                 })
             paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
-            amp_level = self.config['AMP'].get("level", "O1").upper()
-            if amp_level not in ["O1", "O2"]:
+            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
+            self.use_dynamic_loss_scaling = self.config["AMP"].get(
+                "use_dynamic_loss_scaling", False)
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=self.scale_loss,
+                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
+            self.amp_level = self.config['AMP'].get("level", "O1")
+            if self.amp_level not in ["O1", "O2"]:
                 msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
                 logger.warning(msg)
                 self.config['AMP']["level"] = "O1"
-                amp_level = "O1"
+                self.amp_level = "O1"
-            amp_eval = self.config["AMP"].get("use_fp16_test", False)
+            self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
             # TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
             if self.mode == "train" and self.config["Global"].get(
                     "eval_during_train",
-                    True) and amp_level == "O2" and amp_eval == False:
+                    True) and self.amp_level == "O2" and self.amp_eval == False:
                 msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
                 logger.warning(msg)
                 self.config["AMP"]["use_fp16_test"] = True
-                amp_eval = True
+                self.amp_eval = True
-            if self.mode == "train" or amp_eval:
-                AMPForwardDecorator.amp_level = amp_level
-                AMPForwardDecorator.amp_eval = amp_eval
+            paddle_version = paddle.__version__[:3]
+            # paddle version < 2.3.0 and not develop
+            if paddle_version not in ["2.3", "2.4", "0.0"]:
+                msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
+                logger.error(msg)
+                raise Exception(msg)
+            if self.mode == "train" or self.amp_eval:
+                self.model = paddle.amp.decorate(
+                    models=self.model,
+                    level=self.amp_level,
+                    save_dtype='float32')
+            if self.mode == "train" and len(self.train_loss_func.parameters(
+            )) > 0:
+                self.train_loss_func = paddle.amp.decorate(
+                    models=self.train_loss_func,
+                    level=self.amp_level,
+                    save_dtype='float32')
+            self.amp_level = engine.config["AMP"].get("level", "O1").upper()
     def _init_dist(self):
         # check the gpu num
...
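The restored _init_amp reads scale_loss and use_dynamic_loss_scaling from the AMP config, builds a GradScaler, and wraps the model with paddle.amp.decorate. A condensed sketch of that setup; amp_config is a hard-coded, illustrative stand-in for self.config["AMP"]:

```python
import paddle

# Hard-coded stand-in for self.config["AMP"]; values are illustrative.
amp_config = {"scale_loss": 128.0, "use_dynamic_loss_scaling": True, "level": "O2"}

scaler = paddle.amp.GradScaler(
    init_loss_scaling=amp_config.get("scale_loss", 1.0),
    use_dynamic_loss_scaling=amp_config.get("use_dynamic_loss_scaling", False))

model = paddle.nn.Linear(8, 2)  # stands in for self.model
amp_level = amp_config.get("level", "O1")

# O2 casts the model's parameters to FP16; save_dtype='float32' keeps
# checkpoints saved in FP32 so they stay loadable without AMP.
model = paddle.amp.decorate(models=model, level=amp_level, save_dtype="float32")
```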
...
@@ -67,6 +67,16 @@ class ClassEval(object):
             if not self.config["Global"].get("use_multilabel", False):
                 batch[1] = batch[1].reshape([-1, 1]).astype("int64")
+            # image input
+            # if engine.amp and engine.amp_eval:
+            #     with paddle.amp.auto_cast(
+            #             custom_black_list={
+            #                 "flatten_contiguous_range", "greater_than"
+            #             },
+            #             level=engine.amp_level):
+            #         out = engine.model(batch)
+            # else:
+            #     out = self.model(batch)
             out = self.model(batch)
             # just for DistributedBatchSampler issue: repeat sampling
...
@@ -117,6 +127,14 @@ class ClassEval(object):
             # calc loss
             if self.eval_loss_func is not None:
+                # if self.amp and self.amp_eval:
+                #     with paddle.amp.auto_cast(
+                #             custom_black_list={
+                #                 "flatten_contiguous_range", "greater_than"
+                #             },
+                #             level=engine.amp_level):
+                #         loss_dict = engine.eval_loss_func(preds, labels)
+                # else:
                 loss_dict = self.eval_loss_func(preds, labels)
                 for key in loss_dict:
...
...
@@ -189,9 +189,31 @@ class ClassTrainer(object):
                 batch[1] = batch[1].reshape([batch_size, -1])
             self.global_step += 1
+            # forward & backward & step opt
+            # if engine.amp:
+            #     with paddle.amp.auto_cast(
+            #             custom_black_list={
+            #                 "flatten_contiguous_range", "greater_than"
+            #             },
+            #             level=engine.amp_level):
+            #         out = engine.model(batch)
+            #         loss_dict = engine.train_loss_func(out, batch[1])
+            #     loss = loss_dict["loss"] / engine.update_freq
+            #     scaled = engine.scaler.scale(loss)
+            #     scaled.backward()
+            #     if (iter_id + 1) % engine.update_freq == 0:
+            #         for i in range(len(engine.optimizer)):
+            #             engine.scaler.minimize(engine.optimizer[i], scaled)
+            # else:
+            #     out = engine.model(batch)
+            #     loss_dict = engine.train_loss_func(out, batch[1])
+            #     loss = loss_dict["loss"] / engine.update_freq
+            #     loss.backward()
+            #     if (iter_id + 1) % engine.update_freq == 0:
+            #         for i in range(len(engine.optimizer)):
+            #             engine.optimizer[i].step()
             out = self.model(batch)
             loss_dict = self.train_loss_func(out, batch[1])
-            # TODO(gaotingquan): mv update_freq to loss and optimizer
             loss = loss_dict["loss"] / self.update_freq
             loss.backward()
...
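The commented-out block restored above is the AMP training step: forward under auto_cast, scale the loss, backward on the scaled loss, then step the optimizer through the scaler. A minimal runnable sketch of that loop body, using toy stand-ins for the engine's model, optimizer, and scaler:

```python
import paddle

# Toy stand-ins for engine.model, engine.optimizer[i] and engine.scaler.
model = paddle.nn.Linear(8, 2)
optimizer = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=128.0)

x = paddle.randn([4, 8])
label = paddle.randint(0, 2, [4, 1])

with paddle.amp.auto_cast(level="O1"):
    out = model(x)
    loss = paddle.nn.functional.cross_entropy(out, label)

scaled = scaler.scale(loss)          # multiply the loss by the current loss scale
scaled.backward()                    # gradients carry the same scale factor
scaler.minimize(optimizer, scaled)   # unscale, skip the step on inf/nan, adjust the scale
optimizer.clear_grad()
```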
...
@@ -2,8 +2,7 @@ import copy
 import paddle
 import paddle.nn as nn
-from ..utils import logger
-from ..utils.amp import AMPForwardDecorator, AMP_forward_decorator
+from ppcls.utils import logger
 from .celoss import CELoss
 from .googlenetloss import GoogLeNetLoss
...
@@ -50,7 +49,7 @@ from .metabinloss import IntraDomainScatterLoss
 class CombinedLoss(nn.Layer):
-    def __init__(self, config_list, amp_config=None):
+    def __init__(self, config_list):
         super().__init__()
         loss_func = []
         self.loss_weight = []
...
@@ -68,13 +67,6 @@ class CombinedLoss(nn.Layer):
         self.loss_func = nn.LayerList(loss_func)
         logger.debug("build loss {} success.".format(loss_func))
-        if amp_config:
-            self.scaler = paddle.amp.GradScaler(
-                init_loss_scaling=config["AMP"].get("scale_loss", 1.0),
-                use_dynamic_loss_scaling=config["AMP"].get(
-                    "use_dynamic_loss_scaling", False))
-    @AMP_forward_decorator
     def __call__(self, input, batch):
         loss_dict = {}
         # just for accelerate classification traing speed
...
@@ -89,49 +81,25 @@
             loss = {key: loss[key] * weight for key in loss}
             loss_dict.update(loss)
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
-        # TODO(gaotingquan): if amp_eval & eval_loss ?
-        if AMPForwardDecorator.amp_level:
-            self.scaler(loss_dict["loss"])
         return loss_dict
 def build_loss(config, mode="train"):
-    train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
     if mode == "train":
         label_loss_info = config["Loss"]["Train"]
         if label_loss_info:
-            train_loss_func = CombinedLoss(
-                copy.deepcopy(label_loss_info), config.get("AMP", None))
+            train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
         unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
         if unlabel_loss_info:
             unlabel_train_loss_func = CombinedLoss(
-                copy.deepcopy(unlabel_loss_info), config.get("AMP", None))
+                copy.deepcopy(unlabel_loss_info))
+        else:
+            unlabel_train_loss_func = None
-        if AMPForwardDecorator.amp_level is not None:
-            train_loss_func = paddle.amp.decorate(
-                models=train_loss_func,
-                level=AMPForwardDecorator.amp_level,
-                save_dtype='float32')
-            # TODO(gaotingquan): unlabel_loss_info may be None
-            unlabel_train_loss_func = paddle.amp.decorate(
-                models=unlabel_train_loss_func,
-                level=AMPForwardDecorator.amp_level,
-                save_dtype='float32')
         return train_loss_func, unlabel_train_loss_func
     if mode == "eval" or (mode == "train" and
                           config["Global"]["eval_during_train"]):
         loss_config = config.get("Loss", None)
         if loss_config is not None:
             loss_config = loss_config.get("Eval")
             if loss_config is not None:
-                eval_loss_func = CombinedLoss(
-                    copy.deepcopy(loss_config), config.get("AMP", None))
+                eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
-        if AMPForwardDecorator.amp_level is not None and AMPForwardDecorator.amp_eval:
-            eval_loss_func = paddle.amp.decorate(
-                models=eval_loss_func,
-                level=AMPForwardDecorator.amp_level,
-                save_dtype='float32')
         return eval_loss_func
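After this revert, AMP decoration of the loss moves back into Engine._init_amp and only happens when the loss actually owns parameters (the len(self.train_loss_func.parameters()) > 0 guard shown earlier). A hedged sketch of that guard; ToyLoss is a made-up loss layer with a trainable projection, used purely for illustration:

```python
import paddle
import paddle.nn as nn


class ToyLoss(nn.Layer):
    """Made-up loss layer that owns trainable parameters (e.g. a projection head)."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(2, 2)

    def forward(self, logits, label):
        return nn.functional.cross_entropy(self.proj(logits), label)


train_loss_func = ToyLoss()
amp_level = "O2"

# Mirrors the restored guard in Engine._init_amp: only decorate a loss that
# actually has parameters, so a parameter-free loss such as CELoss is left untouched.
if len(train_loss_func.parameters()) > 0:
    train_loss_func = paddle.amp.decorate(
        models=train_loss_func, level=amp_level, save_dtype="float32")
```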