Commit aa52682c authored by: Tingquan Gao

Revert "rm amp code from train and eval & use decorator for amp training"

This reverts commit d3941dc1.
Parent 85e200ed
......@@ -29,7 +29,6 @@ from ..utils import logger
from ..utils.save_load import load_dygraph_pretrain
from .slim import prune_model, quantize_model
from .distill.afd_attention import LinearTransformStudent, LinearTransformTeacher
from ..utils.amp import AMPForwardDecorator
__all__ = ["build_model", "RecModel", "DistillationModel", "AttentionModel"]
......@@ -56,12 +55,6 @@ def build_model(config, mode="train"):
# set @to_static for benchmark, skip this by default.
model = apply_to_static(config, model)
if AMPForwardDecorator.amp_level:
model = paddle.amp.decorate(
models=model,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
return model
......
import functools
def clas_forward_decorator(forward_func):
    """Adapt a ``forward(model, x)`` function to accept a batch.

    The returned wrapper takes ``(model, batch)``, where ``batch`` is an
    indexable pair of ``(input, label)``, and forwards only the input to
    the wrapped function. The label is deliberately not passed through —
    loss computation happens elsewhere — but ``batch[1]`` is still read so
    a malformed batch without a label fails fast with an ``IndexError``.

    Args:
        forward_func: callable of the form ``forward_func(model, x)``.

    Returns:
        A wrapper ``parse_batch_wrapper(model, batch)`` with the metadata
        of ``forward_func`` preserved via ``functools.wraps``.
    """

    @functools.wraps(forward_func)
    def parse_batch_wrapper(model, batch):
        # `_label` is intentionally unused; the indexing preserves the
        # original contract that every batch carries a label at index 1.
        x, _label = batch[0], batch[1]
        return forward_func(model, x)

    return parse_batch_wrapper
\ No newline at end of file
return parse_batch_wrapper
......@@ -99,7 +99,16 @@ class Engine(object):
image_file_list.append(image_file)
if len(batch_data) >= batch_size or idx == len(image_list) - 1:
batch_tensor = paddle.to_tensor(batch_data)
out = self.model(batch_tensor)
if self.amp and self.amp_eval:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=self.amp_level):
out = self.model(batch_tensor)
else:
out = self.model(batch_tensor)
if isinstance(out, list):
out = out[0]
......@@ -200,14 +209,10 @@ class Engine(object):
self.config["Global"]["pretrained_model"])
def _init_amp(self):
if "AMP" in self.config and self.config["AMP"] is not None:
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "2.4", "0.0"]:
msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
logger.error(msg)
raise Exception(msg)
self.amp = "AMP" in self.config and self.config["AMP"] is not None
self.amp_eval = False
# for amp
if self.amp:
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
......@@ -215,26 +220,51 @@ class Engine(object):
})
paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
amp_level = self.config['AMP'].get("level", "O1").upper()
if amp_level not in ["O1", "O2"]:
self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
self.use_dynamic_loss_scaling = self.config["AMP"].get(
"use_dynamic_loss_scaling", False)
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=self.scale_loss,
use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)
self.amp_level = self.config['AMP'].get("level", "O1")
if self.amp_level not in ["O1", "O2"]:
msg = "[Parameter Error]: The optimize level of AMP only support 'O1' and 'O2'. The level has been set 'O1'."
logger.warning(msg)
self.config['AMP']["level"] = "O1"
amp_level = "O1"
self.amp_level = "O1"
amp_eval = self.config["AMP"].get("use_fp16_test", False)
self.amp_eval = self.config["AMP"].get("use_fp16_test", False)
# TODO(gaotingquan): Paddle not yet support FP32 evaluation when training with AMPO2
if self.mode == "train" and self.config["Global"].get(
"eval_during_train",
True) and amp_level == "O2" and amp_eval == False:
True) and self.amp_level == "O2" and self.amp_eval == False:
msg = "PaddlePaddle only support FP16 evaluation when training with AMP O2 now. "
logger.warning(msg)
self.config["AMP"]["use_fp16_test"] = True
amp_eval = True
self.amp_eval = True
paddle_version = paddle.__version__[:3]
# paddle version < 2.3.0 and not develop
if paddle_version not in ["2.3", "2.4", "0.0"]:
msg = "When using AMP, PaddleClas release/2.6 and later version only support PaddlePaddle version >= 2.3.0."
logger.error(msg)
raise Exception(msg)
if self.mode == "train" or self.amp_eval:
self.model = paddle.amp.decorate(
models=self.model,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" and len(self.train_loss_func.parameters(
)) > 0:
self.train_loss_func = paddle.amp.decorate(
models=self.train_loss_func,
level=self.amp_level,
save_dtype='float32')
if self.mode == "train" or amp_eval:
AMPForwardDecorator.amp_level = amp_level
AMPForwardDecorator.amp_eval = amp_eval
self.amp_level = engine.config["AMP"].get("level", "O1").upper()
def _init_dist(self):
# check the gpu num
......
......@@ -67,6 +67,16 @@ class ClassEval(object):
if not self.config["Global"].get("use_multilabel", False):
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
# image input
# if engine.amp and engine.amp_eval:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# out = engine.model(batch)
# else:
# out = self.model(batch)
out = self.model(batch)
# just for DistributedBatchSampler issue: repeat sampling
......@@ -117,6 +127,14 @@ class ClassEval(object):
# calc loss
if self.eval_loss_func is not None:
# if self.amp and self.amp_eval:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# loss_dict = engine.eval_loss_func(preds, labels)
# else:
loss_dict = self.eval_loss_func(preds, labels)
for key in loss_dict:
......
......@@ -189,9 +189,31 @@ class ClassTrainer(object):
batch[1] = batch[1].reshape([batch_size, -1])
self.global_step += 1
# forward & backward & step opt
# if engine.amp:
# with paddle.amp.auto_cast(
# custom_black_list={
# "flatten_contiguous_range", "greater_than"
# },
# level=engine.amp_level):
# out = engine.model(batch)
# loss_dict = engine.train_loss_func(out, batch[1])
# loss = loss_dict["loss"] / engine.update_freq
# scaled = engine.scaler.scale(loss)
# scaled.backward()
# if (iter_id + 1) % engine.update_freq == 0:
# for i in range(len(engine.optimizer)):
# engine.scaler.minimize(engine.optimizer[i], scaled)
# else:
# out = engine.model(batch)
# loss_dict = engine.train_loss_func(out, batch[1])
# loss = loss_dict["loss"] / engine.update_freq
# loss.backward()
# if (iter_id + 1) % engine.update_freq == 0:
# for i in range(len(engine.optimizer)):
# engine.optimizer[i].step()
out = self.model(batch)
loss_dict = self.train_loss_func(out, batch[1])
# TODO(gaotingquan): mv update_freq to loss and optimizer
loss = loss_dict["loss"] / self.update_freq
loss.backward()
......@@ -254,4 +276,4 @@ class ClassTrainer(object):
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
self.best_metric.update(metric_info)
\ No newline at end of file
self.best_metric.update(metric_info)
......@@ -2,8 +2,7 @@ import copy
import paddle
import paddle.nn as nn
from ..utils import logger
from ..utils.amp import AMPForwardDecorator, AMP_forward_decorator
from ppcls.utils import logger
from .celoss import CELoss
from .googlenetloss import GoogLeNetLoss
......@@ -50,7 +49,7 @@ from .metabinloss import IntraDomainScatterLoss
class CombinedLoss(nn.Layer):
def __init__(self, config_list, amp_config=None):
def __init__(self, config_list):
super().__init__()
loss_func = []
self.loss_weight = []
......@@ -68,13 +67,6 @@ class CombinedLoss(nn.Layer):
self.loss_func = nn.LayerList(loss_func)
logger.debug("build loss {} success.".format(loss_func))
if amp_config:
self.scaler = paddle.amp.GradScaler(
init_loss_scaling=config["AMP"].get("scale_loss", 1.0),
use_dynamic_loss_scaling=config["AMP"].get(
"use_dynamic_loss_scaling", False))
@AMP_forward_decorator
def __call__(self, input, batch):
loss_dict = {}
# just for accelerate classification traing speed
......@@ -89,49 +81,25 @@ class CombinedLoss(nn.Layer):
loss = {key: loss[key] * weight for key in loss}
loss_dict.update(loss)
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
# TODO(gaotingquan): if amp_eval & eval_loss ?
if AMPForwardDecorator.amp_level:
self.scaler(loss_dict["loss"])
return loss_dict
def build_loss(config, mode="train"):
train_loss_func, unlabel_train_loss_func, eval_loss_func = None, None, None
if mode == "train":
label_loss_info = config["Loss"]["Train"]
if label_loss_info:
train_loss_func = CombinedLoss(
copy.deepcopy(label_loss_info), config.get("AMP", None))
train_loss_func = CombinedLoss(copy.deepcopy(label_loss_info))
unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
if unlabel_loss_info:
unlabel_train_loss_func = CombinedLoss(
copy.deepcopy(unlabel_loss_info), config.get("AMP", None))
else:
unlabel_train_loss_func = None
if AMPForwardDecorator.amp_level is not None:
train_loss_func = paddle.amp.decorate(
models=train_loss_func,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
# TODO(gaotingquan): unlabel_loss_info may be None
unlabel_train_loss_func = paddle.amp.decorate(
models=unlabel_train_loss_func,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
copy.deepcopy(unlabel_loss_info))
return train_loss_func, unlabel_train_loss_func
if mode == "eval" or (mode == "train" and
config["Global"]["eval_during_train"]):
loss_config = config.get("Loss", None)
if loss_config is not None:
loss_config = loss_config.get("Eval")
if loss_config is not None:
eval_loss_func = CombinedLoss(
copy.deepcopy(loss_config), config.get("AMP", None))
if AMPForwardDecorator.amp_level is not None and AMPForwardDecorator.amp_eval:
eval_loss_func = paddle.amp.decorate(
models=eval_loss_func,
level=AMPForwardDecorator.amp_level,
save_dtype='float32')
eval_loss_func = CombinedLoss(copy.deepcopy(loss_config))
return eval_loss_func
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.