提交 79666238 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: adapt the static graph code to the deprecation of MixCELoss

上级 873869dd
...@@ -112,6 +112,7 @@ class Engine(object): ...@@ -112,6 +112,7 @@ class Engine(object):
} }
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
#TODO(gaotingquan): support rec
class_num = config["Arch"].get("class_num", None) class_num = config["Arch"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num}) self.config["DataLoader"].update({"class_num": class_num})
# build dataloader # build dataloader
......
...@@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter ...@@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger, profiler from ppcls.utils import logger, profiler
def create_feeds(image_shape, use_mix=None, dtype="float32"): def create_feeds(image_shape, use_mix=False, class_num=None, dtype="float32"):
""" """
Create feeds as model input Create feeds as model input
Args: Args:
image_shape(list[int]): model input shape, such as [3, 224, 224] image_shape(list[int]): model input shape, such as [3, 224, 224]
use_mix(bool): whether to use mix(include mixup, cutmix, fmix) use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
class_num(int): the class number of network, required if use_mix
Returns: Returns:
feeds(dict): dict of model input variables feeds(dict): dict of model input variables
...@@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"): ...@@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"):
feeds = OrderedDict() feeds = OrderedDict()
feeds['data'] = paddle.static.data( feeds['data'] = paddle.static.data(
name="data", shape=[None] + image_shape, dtype=dtype) name="data", shape=[None] + image_shape, dtype=dtype)
if use_mix: if use_mix:
feeds['y_a'] = paddle.static.data( if class_num is None:
name="y_a", shape=[None, 1], dtype="int64") msg = "When use MixUp, CutMix and so on, you must set class_num."
feeds['y_b'] = paddle.static.data( logger.error(msg)
name="y_b", shape=[None, 1], dtype="int64") raise Exception(msg)
feeds['lam'] = paddle.static.data( feeds['target'] = paddle.static.data(
name="lam", shape=[None, 1], dtype=dtype) name="target", shape=[None, class_num], dtype="float32")
else: else:
feeds['label'] = paddle.static.data( feeds['label'] = paddle.static.data(
name="label", shape=[None, 1], dtype="int64") name="label", shape=[None, 1], dtype="int64")
...@@ -74,6 +76,7 @@ def create_fetchs(out, ...@@ -74,6 +76,7 @@ def create_fetchs(out,
architecture, architecture,
topk=5, topk=5,
epsilon=None, epsilon=None,
class_num=None,
use_mix=False, use_mix=False,
config=None, config=None,
mode="Train"): mode="Train"):
...@@ -88,6 +91,7 @@ def create_fetchs(out, ...@@ -88,6 +91,7 @@ def create_fetchs(out,
name(such as ResNet50) is needed name(such as ResNet50) is needed
topk(int): usually top5 topk(int): usually top5
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0 epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
class_num(int): the class number of network, required if use_mix
use_mix(bool): whether to use mix(include mixup, cutmix, fmix) use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
config(dict): model config config(dict): model config
...@@ -97,18 +101,16 @@ def create_fetchs(out, ...@@ -97,18 +101,16 @@ def create_fetchs(out,
fetchs = OrderedDict() fetchs = OrderedDict()
# build loss # build loss
if use_mix: if use_mix:
y_a = paddle.reshape(feeds['y_a'], [-1, 1]) if class_num is None:
y_b = paddle.reshape(feeds['y_b'], [-1, 1]) msg = "When use MixUp, CutMix and so on, you must set class_num."
lam = paddle.reshape(feeds['lam'], [-1, 1]) logger.error(msg)
raise Exception(msg)
target = paddle.reshape(feeds['target'], [-1, class_num])
else: else:
target = paddle.reshape(feeds['label'], [-1, 1]) target = paddle.reshape(feeds['label'], [-1, 1])
loss_func = build_loss(config["Loss"][mode]) loss_func = build_loss(config["Loss"][mode])
loss_dict = loss_func(out, target)
if use_mix:
loss_dict = loss_func(out, [y_a, y_b, lam])
else:
loss_dict = loss_func(out, target)
loss_out = loss_dict["loss"] loss_out = loss_dict["loss"]
fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True)) fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True))
...@@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer): ...@@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer):
def build(config, def build(config,
main_prog, main_prog,
startup_prog, startup_prog,
class_num=None,
step_each_epoch=100, step_each_epoch=100,
is_train=True, is_train=True,
is_distributed=True): is_distributed=True):
...@@ -233,6 +236,7 @@ def build(config, ...@@ -233,6 +236,7 @@ def build(config,
config(dict): config config(dict): config
main_prog(): main program main_prog(): main program
startup_prog(): startup program startup_prog(): startup program
class_num(int): the class number of network, required if use_mix
is_train(bool): train or eval is_train(bool): train or eval
is_distributed(bool): whether to use distributed training method is_distributed(bool): whether to use distributed training method
...@@ -245,10 +249,10 @@ def build(config, ...@@ -245,10 +249,10 @@ def build(config,
mode = "Train" if is_train else "Eval" mode = "Train" if is_train else "Eval"
use_mix = "batch_transform_ops" in config["DataLoader"][mode][ use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"] "dataset"]
use_dali = config["Global"].get('use_dali', False)
feeds = create_feeds( feeds = create_feeds(
config["Global"]["image_shape"], config["Global"]["image_shape"],
use_mix=use_mix, use_mix,
class_num=class_num,
dtype="float32") dtype="float32")
# build model # build model
...@@ -264,6 +268,7 @@ def build(config, ...@@ -264,6 +268,7 @@ def build(config,
feeds, feeds,
config["Arch"], config["Arch"],
epsilon=config.get('ls_epsilon'), epsilon=config.get('ls_epsilon'),
class_num=class_num,
use_mix=use_mix, use_mix=use_mix,
config=config, config=config,
mode=mode) mode=mode)
......
...@@ -112,6 +112,8 @@ def main(args): ...@@ -112,6 +112,8 @@ def main(args):
eval_dataloader = None eval_dataloader = None
use_dali = global_config.get('use_dali', False) use_dali = global_config.get('use_dali', False)
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
train_dataloader = build_dataloader( train_dataloader = build_dataloader(
config["DataLoader"], "Train", device=device, use_dali=use_dali) config["DataLoader"], "Train", device=device, use_dali=use_dali)
if global_config["eval_during_train"]: if global_config["eval_during_train"]:
...@@ -131,6 +133,7 @@ def main(args): ...@@ -131,6 +133,7 @@ def main(args):
config, config,
train_prog, train_prog,
startup_prog, startup_prog,
class_num,
step_each_epoch=step_each_epoch, step_each_epoch=step_each_epoch,
is_train=True, is_train=True,
is_distributed=global_config.get("is_distributed", True)) is_distributed=global_config.get("is_distributed", True))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册