提交 79666238 编写于 作者: G gaotingquan 提交者: Tingquan Gao

refactor: adapt to static graph in deprecating MixCELoss

上级 873869dd
......@@ -112,6 +112,7 @@ class Engine(object):
}
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
#TODO(gaotingquan): support rec
class_num = config["Arch"].get("class_num", None)
self.config["DataLoader"].update({"class_num": class_num})
# build dataloader
......
......@@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger, profiler
def create_feeds(image_shape, use_mix=None, dtype="float32"):
def create_feeds(image_shape, use_mix=False, class_num=None, dtype="float32"):
"""
Create feeds as model input
Args:
image_shape(list[int]): model input shape, such as [3, 224, 224]
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
class_num(int): the class number of network, required if use_mix
Returns:
feeds(dict): dict of model input variables
......@@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"):
feeds = OrderedDict()
feeds['data'] = paddle.static.data(
name="data", shape=[None] + image_shape, dtype=dtype)
if use_mix:
feeds['y_a'] = paddle.static.data(
name="y_a", shape=[None, 1], dtype="int64")
feeds['y_b'] = paddle.static.data(
name="y_b", shape=[None, 1], dtype="int64")
feeds['lam'] = paddle.static.data(
name="lam", shape=[None, 1], dtype=dtype)
if class_num is None:
msg = "When use MixUp, CutMix and so on, you must set class_num."
logger.error(msg)
raise Exception(msg)
feeds['target'] = paddle.static.data(
name="target", shape=[None, class_num], dtype="float32")
else:
feeds['label'] = paddle.static.data(
name="label", shape=[None, 1], dtype="int64")
......@@ -74,6 +76,7 @@ def create_fetchs(out,
architecture,
topk=5,
epsilon=None,
class_num=None,
use_mix=False,
config=None,
mode="Train"):
......@@ -88,6 +91,7 @@ def create_fetchs(out,
name(such as ResNet50) is needed
topk(int): usually top5
epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0
class_num(int): the class number of network, required if use_mix
use_mix(bool): whether to use mix(include mixup, cutmix, fmix)
config(dict): model config
......@@ -97,17 +101,15 @@ def create_fetchs(out,
fetchs = OrderedDict()
# build loss
if use_mix:
y_a = paddle.reshape(feeds['y_a'], [-1, 1])
y_b = paddle.reshape(feeds['y_b'], [-1, 1])
lam = paddle.reshape(feeds['lam'], [-1, 1])
if class_num is None:
msg = "When use MixUp, CutMix and so on, you must set class_num."
logger.error(msg)
raise Exception(msg)
target = paddle.reshape(feeds['target'], [-1, class_num])
else:
target = paddle.reshape(feeds['label'], [-1, 1])
loss_func = build_loss(config["Loss"][mode])
if use_mix:
loss_dict = loss_func(out, [y_a, y_b, lam])
else:
loss_dict = loss_func(out, target)
loss_out = loss_dict["loss"]
......@@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer):
def build(config,
main_prog,
startup_prog,
class_num=None,
step_each_epoch=100,
is_train=True,
is_distributed=True):
......@@ -233,6 +236,7 @@ def build(config,
config(dict): config
main_prog(): main program
startup_prog(): startup program
class_num(int): the class number of network, required if use_mix
is_train(bool): train or eval
is_distributed(bool): whether to use distributed training method
......@@ -245,10 +249,10 @@ def build(config,
mode = "Train" if is_train else "Eval"
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"]
use_dali = config["Global"].get('use_dali', False)
feeds = create_feeds(
config["Global"]["image_shape"],
use_mix=use_mix,
use_mix,
class_num=class_num,
dtype="float32")
# build model
......@@ -264,6 +268,7 @@ def build(config,
feeds,
config["Arch"],
epsilon=config.get('ls_epsilon'),
class_num=class_num,
use_mix=use_mix,
config=config,
mode=mode)
......
......@@ -112,6 +112,8 @@ def main(args):
eval_dataloader = None
use_dali = global_config.get('use_dali', False)
class_num = config["Arch"].get("class_num", None)
config["DataLoader"].update({"class_num": class_num})
train_dataloader = build_dataloader(
config["DataLoader"], "Train", device=device, use_dali=use_dali)
if global_config["eval_during_train"]:
......@@ -131,6 +133,7 @@ def main(args):
config,
train_prog,
startup_prog,
class_num,
step_each_epoch=step_each_epoch,
is_train=True,
is_distributed=global_config.get("is_distributed", True))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册