diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 38f5b67b8e0a7b3cc3298f568ac8e2bf415d6b98..66794d2075eadc7c4c2133b1d465d0710cd911e0 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -112,6 +112,7 @@ class Engine(object): } paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) + #TODO(gaotingquan): support rec class_num = config["Arch"].get("class_num", None) self.config["DataLoader"].update({"class_num": class_num}) # build dataloader diff --git a/ppcls/static/program.py b/ppcls/static/program.py index f6735a3f765b3396a518e2d289702f9f1cb1225b..9075a359b8ad0d991865d19f413ca250c39368f1 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -41,13 +41,14 @@ from ppcls.utils.misc import AverageMeter from ppcls.utils import logger, profiler -def create_feeds(image_shape, use_mix=None, dtype="float32"): +def create_feeds(image_shape, use_mix=False, class_num=None, dtype="float32"): """ Create feeds as model input Args: image_shape(list[int]): model input shape, such as [3, 224, 224] use_mix(bool): whether to use mix(include mixup, cutmix, fmix) + class_num(int): the class number of the network, required if use_mix Returns: feeds(dict): dict of model input variables @@ -55,13 +56,14 @@ def create_feeds(image_shape, use_mix=None, dtype="float32"): feeds = OrderedDict() feeds['data'] = paddle.static.data( name="data", shape=[None] + image_shape, dtype=dtype) + if use_mix: - feeds['y_a'] = paddle.static.data( - name="y_a", shape=[None, 1], dtype="int64") - feeds['y_b'] = paddle.static.data( - name="y_b", shape=[None, 1], dtype="int64") - feeds['lam'] = paddle.static.data( - name="lam", shape=[None, 1], dtype=dtype) + if class_num is None: + msg = "When using MixUp, CutMix and so on, you must set class_num." 
+ logger.error(msg) + raise Exception(msg) + feeds['target'] = paddle.static.data( + name="target", shape=[None, class_num], dtype="float32") else: feeds['label'] = paddle.static.data( name="label", shape=[None, 1], dtype="int64") @@ -74,6 +76,7 @@ def create_fetchs(out, architecture, topk=5, epsilon=None, + class_num=None, use_mix=False, config=None, mode="Train"): @@ -88,6 +91,7 @@ def create_fetchs(out, name(such as ResNet50) is needed topk(int): usually top5 epsilon(float): parameter for label smoothing, 0.0 <= epsilon <= 1.0 + class_num(int): the class number of the network, required if use_mix use_mix(bool): whether to use mix(include mixup, cutmix, fmix) config(dict): model config @@ -97,18 +101,16 @@ def create_fetchs(out, fetchs = OrderedDict() # build loss if use_mix: - y_a = paddle.reshape(feeds['y_a'], [-1, 1]) - y_b = paddle.reshape(feeds['y_b'], [-1, 1]) - lam = paddle.reshape(feeds['lam'], [-1, 1]) + if class_num is None: + msg = "When using MixUp, CutMix and so on, you must set class_num." 
+ logger.error(msg) + raise Exception(msg) + target = paddle.reshape(feeds['target'], [-1, class_num]) else: target = paddle.reshape(feeds['label'], [-1, 1]) loss_func = build_loss(config["Loss"][mode]) - - if use_mix: - loss_dict = loss_func(out, [y_a, y_b, lam]) - else: - loss_dict = loss_func(out, target) + loss_dict = loss_func(out, target) loss_out = loss_dict["loss"] fetchs['loss'] = (loss_out, AverageMeter('loss', '7.4f', need_avg=True)) @@ -218,6 +220,7 @@ def mixed_precision_optimizer(config, optimizer): def build(config, main_prog, startup_prog, + class_num=None, step_each_epoch=100, is_train=True, is_distributed=True): @@ -233,6 +236,7 @@ def build(config, config(dict): config main_prog(): main program startup_prog(): startup program + class_num(int): the class number of the network, required if use_mix is_train(bool): train or eval is_distributed(bool): whether to use distributed training method @@ -245,10 +249,10 @@ def build(config, mode = "Train" if is_train else "Eval" use_mix = "batch_transform_ops" in config["DataLoader"][mode][ "dataset"] - use_dali = config["Global"].get('use_dali', False) feeds = create_feeds( config["Global"]["image_shape"], - use_mix=use_mix, + use_mix, + class_num=class_num, dtype="float32") # build model @@ -264,6 +268,7 @@ def build(config, feeds, config["Arch"], epsilon=config.get('ls_epsilon'), + class_num=class_num, use_mix=use_mix, config=config, mode=mode) diff --git a/ppcls/static/train.py b/ppcls/static/train.py index 5c56c17cb6d9a1e17bf9dc02bd8ab5a4b89b14fb..e262f27ffdcc827414df459822f45963f3bb0f92 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -115,6 +115,8 @@ def main(args): eval_dataloader = None use_dali = global_config.get('use_dali', False) + class_num = config["Arch"].get("class_num", None) + config["DataLoader"].update({"class_num": class_num}) train_dataloader = build_dataloader( config["DataLoader"], "Train", device=device, use_dali=use_dali) if global_config["eval_during_train"]: @@ -134,6 
+136,7 @@ def main(args): config, train_prog, startup_prog, + class_num, step_each_epoch=step_each_epoch, is_train=True, is_distributed=global_config.get("is_distributed", True))