diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 9915dc2c40536599c7d72a2618b00fed5916f99b..942391d393fe0074e8d57c28da3394fd186e4e31 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -88,32 +88,25 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
     random.seed(worker_seed)
 
 
-def build_dataloader(config, mode, seed=None):
+def build(config, mode, use_dali=False, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
     ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
-    assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(
-        mode)
-
-    dataloader_config = config["DataLoader"][mode]
-    class_num = config["Arch"].get("class_num", None)
-    epochs = config["Global"]["epochs"]
-    use_dali = config["Global"].get("use_dali", False)
-    num_workers = dataloader_config['loader']["num_workers"]
-    use_shared_memory = dataloader_config['loader']["use_shared_memory"]
-
+    assert mode in config.keys(), "{} config not in yaml".format(mode)
     # build dataset
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
         return dali_dataloader(
-            config["DataLoader"],
+            config,
             mode,
             paddle.device.get_device(),
-            num_threads=num_workers,
+            num_threads=config[mode]['loader']["num_workers"],
             seed=seed,
             enable_fuse=True)
 
-    config_dataset = dataloader_config['dataset']
+    class_num = config.get("class_num", None)
+    epochs = config.get("epochs", None)
+    config_dataset = config[mode]['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
     if 'batch_transform_ops' in config_dataset:
@@ -126,7 +119,7 @@ def build_dataloader(config, mode, seed=None):
     logger.debug("build dataset({}) success...".format(dataset))
 
     # build sampler
-    config_sampler = dataloader_config['sampler']
+    config_sampler = config[mode]['sampler']
     if config_sampler and "name" not in config_sampler:
         batch_sampler = None
         batch_size = config_sampler["batch_size"]
@@ -160,6 +153,11 @@ def build_dataloader(config, mode, seed=None):
     else:
         batch_collate_fn = None
 
+    # build dataloader
+    config_loader = config[mode]['loader']
+    num_workers = config_loader["num_workers"]
+    use_shared_memory = config_loader["use_shared_memory"]
+
     init_fn = partial(
         worker_init_fn,
         num_workers=num_workers,
@@ -196,36 +194,78 @@ def build_dataloader(config, mode, seed=None):
     data_loader.max_iter = max_iter
     data_loader.total_samples = total_samples
 
-    # TODO(gaotingquan): mv to build_sampler
-    if mode == "train":
-        if dataloader_config["Train"].get("max_iter", None):
+    logger.debug("build data_loader({}) success...".format(data_loader))
+    return data_loader
+
+
+# TODO(gaotingquan): perf
+class DataIterator(object):
+    def __init__(self, dataloader, use_dali=False):
+        self.dataloader = dataloader
+        self.use_dali = use_dali
+        self.iterator = iter(dataloader)
+        self.max_iter = dataloader.max_iter
+        self.total_samples = dataloader.total_samples
+
+    def get_batch(self):
+        # fetch data batch from dataloader
+        try:
+            batch = next(self.iterator)
+        except Exception:
+            # NOTE: reset DALI dataloader manually
+            if self.use_dali:
+                self.dataloader.reset()
+            self.iterator = iter(self.dataloader)
+            batch = next(self.iterator)
+        return batch
+
+
+def build_dataloader(config, mode):
+    class_num = config["Arch"].get("class_num", None)
+    config["DataLoader"].update({"class_num": class_num})
+    config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
+
+    use_dali = config["Global"].get("use_dali", False)
+    dataloader_dict = {
+        "Train": None,
+        "UnLabelTrain": None,
+        "Eval": None,
+        "Query": None,
+        "Gallery": None,
+        "GalleryQuery": None
+    }
+    if mode == 'train':
+        train_dataloader = build(
+            config["DataLoader"], "Train", use_dali, seed=None)
+
+        if config["DataLoader"]["Train"].get("max_iter", None):
             # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
             max_iter = config["Train"].get("max_iter")
             update_freq = config["Global"].get("update_freq", 1)
-            max_iter = data_loader.max_iter // update_freq * update_freq
-            data_loader.max_iter = max_iter
-
-    logger.debug("build data_loader({}) success...".format(data_loader))
-    return data_loader
+            max_iter = train_dataloader.max_iter // update_freq * update_freq
+            train_dataloader.max_iter = max_iter
+        if config["DataLoader"]["Train"].get("convert_iterator", True):
+            train_dataloader = DataIterator(train_dataloader, use_dali)
+        dataloader_dict["Train"] = train_dataloader
+    if config["DataLoader"].get('UnLabelTrain', None) is not None:
+        dataloader_dict["UnLabelTrain"] = build(
+            config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
 
-# # TODO(gaotingquan): the length of dataloader should be determined by sampler
-# class DataIterator(object):
-#     def __init__(self, dataloader, use_dali=False):
-#         self.dataloader = dataloader
-#         self.use_dali = use_dali
-#         self.iterator = iter(dataloader)
-#         self.max_iter = dataloader.max_iter
-#         self.total_samples = dataloader.total_samples
-#     def get_batch(self):
-#         # fetch data batch from dataloader
-#         try:
-#             batch = next(self.iterator)
-#         except Exception:
-#             # NOTE: reset DALI dataloader manually
-#             if self.use_dali:
-#                 self.dataloader.reset()
-#                 self.iterator = iter(self.dataloader)
-#             batch = next(self.iterator)
-#         return batch
+    if mode == "eval" or (mode == "train" and
+                          config["Global"]["eval_during_train"]):
+        task = config["Global"].get("task", "classification")
+        if task in ["classification", "adaface"]:
+            dataloader_dict["Eval"] = build(
+                config["DataLoader"], "Eval", use_dali, seed=None)
+        elif task == "retrieval":
+            if len(config["DataLoader"]["Eval"].keys()) == 1:
+                key = list(config["DataLoader"]["Eval"].keys())[0]
+                dataloader_dict["GalleryQuery"] = build(
+                    config["DataLoader"]["Eval"], key, use_dali)
+            else:
+                dataloader_dict["Gallery"] = build(
+                    config["DataLoader"]["Eval"], "Gallery", use_dali)
+                dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
+                                                 "Query", use_dali)
+    return dataloader_dict
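The DataIterator added above lets the trainer pull batches by an explicit iteration count and keep fetching past the end of an epoch. A minimal, self-contained sketch of the same wrap-and-restart pattern, using plain Python and a hypothetical LoopingIterator with toy data (nothing here is part of the patch):

```python
class LoopingIterator:
    """Toy stand-in for DataIterator: fetch batches by count and silently
    restart the underlying iterable once it is exhausted."""

    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.iterator = iter(dataloader)

    def get_batch(self):
        try:
            return next(self.iterator)
        except StopIteration:
            # mirrors the except-branch above: rebuild the iterator, retry once
            self.iterator = iter(self.dataloader)
            return next(self.iterator)


loader = LoopingIterator([[0, 1], [2, 3], [4, 5]])
for step in range(5):  # drive by step count, not by epoch boundary
    print(step, loader.get_batch())
```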
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index e8424af2e1ddb0c494cb03d6955cb7ac956e84d9..df2da7a5d921e1f1443d23c41da84936f763a9ff 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -60,8 +60,6 @@ class Engine(object):
         # load_pretrain
         self._init_pretrained()
 
-        self._init_amp()
-
         # init train_func and eval_func
         self.eval = build_eval_func(
             self.config, mode=self.mode, model=self.model)
@@ -71,7 +69,7 @@ class Engine(object):
         # for distributed
         self._init_dist()
 
-        print_config(self.config)
+        print_config(config)
 
     @paddle.no_grad()
     def infer(self):
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index 07f1703f8acdd29b6a69680e366a9824a731a101..ab1378bc23d2f85169cbb3faa98e78a2b9b509a1 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -29,11 +29,10 @@ class ClassEval(object):
     def __init__(self, config, mode, model):
         self.config = config
         self.model = model
-        self.print_batch_step = self.config["Global"]["print_batch_step"]
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.eval_metric_func = build_metrics(self.config, "eval")
-        self.eval_dataloader = build_dataloader(self.config, "Eval")
-        self.eval_loss_func = build_loss(self.config, "Eval")
+        self.eval_metric_func = build_metrics(config, "eval")
+        self.eval_dataloader = build_dataloader(config, "eval")
+        self.eval_loss_func = build_loss(config, "eval")
         self.output_info = dict()
 
     @paddle.no_grad()
@@ -49,12 +48,13 @@ class ClassEval(object):
             "reader_cost": AverageMeter(
                 "reader_cost", ".5f", postfix=" s,"),
         }
+        print_batch_step = self.config["Global"]["print_batch_step"]
 
         tic = time.time()
-        total_samples = self.eval_dataloader.total_samples
+        total_samples = self.eval_dataloader["Eval"].total_samples
         accum_samples = 0
-        max_iter = self.eval_dataloader.max_iter
-        for iter_id, batch in enumerate(self.eval_dataloader):
+        max_iter = self.eval_dataloader["Eval"].max_iter
+        for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
             if iter_id >= max_iter:
                 break
             if iter_id == 5:
@@ -130,7 +130,7 @@ class ClassEval(object):
                 self.eval_metric_func(preds, labels)
             time_info["batch_cost"].update(time.time() - tic)
-            if iter_id % self.print_batch_step == 0:
+            if iter_id % print_batch_step == 0:
                 time_msg = "s, ".join([
                     "{}: {:.5f}".format(key, time_info[key].avg)
                     for key in time_info
@@ -153,7 +153,7 @@ class ClassEval(object):
             tic = time.time()
 
         if self.use_dali:
-            self.eval_dataloader.reset()
+            self.eval_dataloader["Eval"].reset()
 
         if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
             metric_msg = ", ".join([
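ClassEval now looks the evaluation loader up under the "Eval" key of the dict returned by build_dataloader and bounds the loop with its max_iter attribute. A rough sketch of that access pattern with stand-in objects (ToyLoader and the numbers are assumptions for illustration only):

```python
class ToyLoader:
    """Stand-in for the object build_dataloader() stores under "Eval"."""

    def __init__(self, batches, max_iter):
        self.batches = batches
        self.max_iter = max_iter
        self.total_samples = sum(len(b) for b in batches)

    def __iter__(self):
        return iter(self.batches)


eval_dataloader = {"Eval": ToyLoader([[1, 2], [3, 4], [5, 6]], max_iter=2)}
for iter_id, batch in enumerate(eval_dataloader["Eval"]):
    if iter_id >= eval_dataloader["Eval"].max_iter:
        break  # same early-exit guard as in ClassEval.eval()
    print(iter_id, batch)
```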
diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py
index 54ae0cc1ad39f44efe5c7d46819e3f45a80c5a5f..eedba871ef2a86b8c0510b33e42d63fc72d8f04b 100644
--- a/ppcls/engine/train/__init__.py
+++ b/ppcls/engine/train/__init__.py
@@ -25,7 +25,7 @@ def build_train_func(config, mode, model, eval_func):
     train_mode = config["Global"].get("task", None)
     if train_mode is None:
         config["Global"]["task"] = "classification"
-        return ClassTrainer(config, model, eval_func)
+        return ClassTrainer(config, mode, model, eval_func)
     else:
         return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
             config, mode, model, eval_func)
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index d6a5c8228fa6acff2b923361613b4650219f5fa3..d6837fcc3e7ac05cfeab772b98b95680c19d757d 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -28,7 +28,7 @@ from ...utils.save_load import init_model, ModelSaver
 
 
 class ClassTrainer(object):
-    def __init__(self, config, model, eval_func):
+    def __init__(self, config, mode, model, eval_func):
         self.config = config
         self.model = model
         self.eval = eval_func
@@ -41,32 +41,32 @@ class ClassTrainer(object):
         # gradient accumulation
         self.update_freq = self.config["Global"].get("update_freq", 1)
 
-        # TODO(gaotingquan): mv to build_model
-        # build EMA model
-        self.model_ema = self._build_ema_model()
+        # AMP training and evaluating
+        # self._init_amp()
 
         # build dataloader
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.dataloader = build_dataloader(self.config, "Train")
+        self.dataloader_dict = build_dataloader(self.config, mode)
 
         # build loss
-        self.loss_func = build_loss(config, "Train")
+        self.train_loss_func, self.unlabel_train_loss_func = build_loss(
+            self.config, mode)
 
         # build metric
         self.train_metric_func = build_metrics(config, "train")
 
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
-            self.config, self.dataloader.max_iter,
-            [self.model, self.loss_func], self.update_freq)
+            self.config, self.dataloader_dict["Train"].max_iter,
+            [self.model, self.train_loss_func], self.update_freq)
 
         # build model saver
         self.model_saver = ModelSaver(
-            config=self.config,
-            net=self.model,
-            loss=self.loss_func,
-            opt=self.optimizer,
-            model_ema=self.model_ema if self.model_ema else None)
+            self,
+            net_name="model",
+            loss_name="train_loss_func",
+            opt_name="optimizer",
+            model_ema_name="model_ema")
 
         # build best metric
         self.best_metric = {
@@ -84,6 +84,8 @@ class ClassTrainer(object):
                 "reader_cost", ".5f", postfix=" s,"),
         }
 
+        # build EMA model
+        self.model_ema = self._build_ema_model()
         self._init_checkpoints()
 
         # for visualdl
@@ -171,10 +173,11 @@ class ClassTrainer(object):
             }, prefix="latest")
 
     def train_epoch(self, epoch_id):
-        self.model.train()
         tic = time.time()
 
-        for iter_id, batch in enumerate(self.dataloader):
+        for iter_id in range(self.dataloader_dict["Train"].max_iter):
+            batch = self.dataloader_dict["Train"].get_batch()
+
             profiler.add_profiler_step(self.config["profiler_options"])
             if iter_id == 5:
                 for key in self.time_info:
@@ -187,7 +190,7 @@ class ClassTrainer(object):
             self.global_step += 1
 
             out = self.model(batch)
-            loss_dict = self.loss_func(out, batch[1])
+            loss_dict = self.train_loss_func(out, batch[1])
             # TODO(gaotingquan): mv update_freq to loss and optimizer
             loss = loss_dict["loss"] / self.update_freq
             loss.backward()
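In the training loop kept above, the loss is divided by update_freq before backward(). Since gradients are linear in the loss, summing update_freq scaled losses reproduces the gradient of their average; a tiny numeric check with made-up values (no Paddle involved):

```python
update_freq = 4
per_batch_losses = [0.8, 1.2, 0.6, 1.0]  # made-up per-iteration losses

accumulated = sum(loss / update_freq for loss in per_batch_losses)
averaged = sum(per_batch_losses) / len(per_batch_losses)
assert abs(accumulated - averaged) < 1e-12
print(accumulated, averaged)
```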
diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py
index 88ceb9bb3ed9d1bee87c3121f00b7fd4f7537dc3..087e943a49913204e7475fff316a36a3d0a9a477 100644
--- a/ppcls/engine/train/utils.py
+++ b/ppcls/engine/train/utils.py
@@ -55,13 +55,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
         batch_size / trainer.time_info["batch_cost"].avg)
     eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-                ) * len(trainer.dataloader) - iter_id
+                ) * trainer.dataloader_dict["Train"].max_iter - iter_id
                ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"]["epochs"], iter_id,
-        len(trainer.dataloader), lr_msg, metric_msg, time_msg, ips_msg,
-        eta_msg))
+        epoch_id, trainer.config["Global"][
+            "epochs"], iter_id, trainer.dataloader_dict["Train"]
+        .max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
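The ETA printed by log_info is the remaining number of iterations, now taken from dataloader_dict["Train"].max_iter, multiplied by the running average batch cost. Recomputing the same formula with made-up numbers:

```python
import datetime

epochs, epoch_id, iter_id = 100, 3, 250  # toy values
max_iter = 1000                          # trainer.dataloader_dict["Train"].max_iter
avg_batch_cost = 0.42                    # trainer.time_info["batch_cost"].avg

eta_sec = ((epochs - epoch_id + 1) * max_iter - iter_id) * avg_batch_cost
print("eta:", datetime.timedelta(seconds=int(eta_sec)))
```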
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index 006b54eb3f7498b3dd860bc143b51488c7aee340..dbb86f25751a040e7a98310066ad39941da782b0 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -50,9 +50,8 @@ from .metabinloss import IntraDomainScatterLoss
 
 
 class CombinedLoss(nn.Layer):
-    def __init__(self, config_list, mode, amp_config=None):
+    def __init__(self, config_list, amp_config=None):
         super().__init__()
-        self.mode = mode
         loss_func = []
         self.loss_weight = []
         assert isinstance(config_list, list), (
@@ -69,13 +68,11 @@ class CombinedLoss(nn.Layer):
         self.loss_func = nn.LayerList(loss_func)
         logger.debug("build loss {} success.".format(loss_func))
 
-        self.scaler = None
         if amp_config:
-            if self.mode == "Train" or AMPForwardDecorator.amp_eval:
-                self.scaler = paddle.amp.GradScaler(
-                    init_loss_scaling=amp_config.get("scale_loss", 1.0),
-                    use_dynamic_loss_scaling=amp_config.get(
-                        "use_dynamic_loss_scaling", False))
+            self.scaler = paddle.amp.GradScaler(
+                init_loss_scaling=amp_config.get("scale_loss", 1.0),
+                use_dynamic_loss_scaling=amp_config.get(
+                    "use_dynamic_loss_scaling", False))
 
     @AMP_forward_decorator
     def __call__(self, input, batch):
@@ -92,26 +89,49 @@ class CombinedLoss(nn.Layer):
             loss = {key: loss[key] * weight for key in loss}
             loss_dict.update(loss)
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
-
-        if self.scaler:
+        # TODO(gaotingquan): if amp_eval & eval_loss ?
+        if AMPForwardDecorator.amp_level:
             self.scaler(loss_dict["loss"])
         return loss_dict
 
 
-def build_loss(config, mode):
-    if config["Loss"][mode] is None:
-        return None
-    module_class = CombinedLoss(
-        copy.deepcopy(config["Loss"][mode]),
-        mode,
-        amp_config=config.get("AMP", None))
+def build_loss(config, mode="train"):
+    if mode == "train":
+        label_loss_info = config["Loss"]["Train"]
+        if label_loss_info:
+            train_loss_func = CombinedLoss(
+                copy.deepcopy(label_loss_info), config.get("AMP", None))
+        unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
+        if unlabel_loss_info:
+            unlabel_train_loss_func = CombinedLoss(
+                copy.deepcopy(unlabel_loss_info), config.get("AMP", None))
+        else:
+            unlabel_train_loss_func = None
 
-    if AMPForwardDecorator.amp_level is not None:
-        if mode == "Train" or AMPForwardDecorator.amp_eval:
-            module_class = paddle.amp.decorate(
-                models=module_class,
+        if AMPForwardDecorator.amp_level is not None:
+            train_loss_func = paddle.amp.decorate(
+                models=train_loss_func,
+                level=AMPForwardDecorator.amp_level,
+                save_dtype='float32')
+            # TODO(gaotingquan): unlabel_loss_info may be None
+            unlabel_train_loss_func = paddle.amp.decorate(
+                models=unlabel_train_loss_func,
                 level=AMPForwardDecorator.amp_level,
                 save_dtype='float32')
+        return train_loss_func, unlabel_train_loss_func
 
-    logger.debug("build loss {} success.".format(module_class))
-    return module_class
+    if mode == "eval" or (mode == "train" and
+                          config["Global"]["eval_during_train"]):
+        loss_config = config.get("Loss", None)
+        if loss_config is not None:
+            loss_config = loss_config.get("Eval")
+            if loss_config is not None:
+                eval_loss_func = CombinedLoss(
+                    copy.deepcopy(loss_config), config.get("AMP", None))
+
+                if AMPForwardDecorator.amp_level is not None and AMPForwardDecorator.amp_eval:
+                    eval_loss_func = paddle.amp.decorate(
+                        models=eval_loss_func,
+                        level=AMPForwardDecorator.amp_level,
+                        save_dtype='float32')
+        return eval_loss_func
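After this change build_loss is asymmetric: mode="train" returns a (labeled, unlabeled) pair, where the unlabeled entry may be None, while mode="eval" returns a single loss object; that is why ClassTrainer unpacks two values and ClassEval keeps one. A toy model of that contract (build_toy_loss and its inner functions are illustrative, not the real CombinedLoss):

```python
def build_toy_loss(mode="train"):
    def labeled(pred, target):
        return abs(pred - target)

    if mode == "train":
        unlabeled = None  # mirrors unlabel_train_loss_func being optional
        return labeled, unlabeled
    return labeled        # eval: a single callable


train_loss_func, unlabel_train_loss_func = build_toy_loss("train")
eval_loss_func = build_toy_loss("eval")
print(train_loss_func(1.0, 0.5), unlabel_train_loss_func is None, eval_loss_func(2.0, 1.5))
```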
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 7606128248344e279d042662aca9b241c0978f1b..776459fe566110b9b603d3e1a6ea04e5f8046d17 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -150,16 +150,21 @@ def _extract_student_weights(all_params, student_prefix="Student."):
 
 
 class ModelSaver(object):
-    def __init__(self, config, net, loss, opt, model_ema):
+    def __init__(self,
+                 trainer,
+                 net_name="model",
+                 loss_name="train_loss_func",
+                 opt_name="optimizer",
+                 model_ema_name="model_ema"):
         # net, loss, opt, model_ema, output_dir,
-        self.net = net
-        self.loss = loss
-        self.opt = opt
-        self.model_ema = model_ema
-
-        arch_name = config["Arch"]["name"]
-        self.output_dir = os.path.join(config["Global"]["output_dir"],
-                                       arch_name)
+        self.trainer = trainer
+        self.net_name = net_name
+        self.loss_name = loss_name
+        self.opt_name = opt_name
+        self.model_ema_name = model_ema_name
+
+        arch_name = trainer.config["Arch"]["name"]
+        self.output_dir = os.path.join(trainer.output_dir, arch_name)
         _mkdir_if_not_exist(self.output_dir)
 
     def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -169,8 +174,8 @@ class ModelSaver(object):
         save_dir = os.path.join(self.output_dir, prefix)
 
-        params_state_dict = self.net.state_dict()
-        loss = self.loss
+        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
+        loss = getattr(self.trainer, self.loss_name)
         if loss is not None:
             loss_state_dict = loss.state_dict()
             keys_inter = set(params_state_dict.keys()) & set(
@@ -185,11 +190,11 @@ class ModelSaver(object):
             paddle.save(s_params, save_dir + "_student.pdparams")
         paddle.save(params_state_dict, save_dir + ".pdparams")
 
-        model_ema = self.model_ema
+        model_ema = getattr(self.trainer, self.model_ema_name)
         if model_ema is not None:
             paddle.save(model_ema.module.state_dict(),
                         save_dir + ".ema.pdparams")
-        optimizer = self.opt
+        optimizer = getattr(self.trainer, self.opt_name)
         paddle.save([opt.state_dict() for opt in optimizer],
                     save_dir + ".pdopt")
         paddle.save(metric_info, save_dir + ".pdstates")
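ModelSaver now stores attribute names and resolves them on the trainer with getattr at save time, so objects created after the saver is constructed (for example the EMA model, which is now built later in ClassTrainer.__init__) are still found. A self-contained sketch of that late-binding pattern (ToySaver and ToyTrainer are hypothetical):

```python
class ToySaver:
    def __init__(self, trainer, net_name="model", model_ema_name="model_ema"):
        self.trainer = trainer
        self.net_name = net_name
        self.model_ema_name = model_ema_name

    def save(self):
        # resolved at call time, not at construction time
        net = getattr(self.trainer, self.net_name)
        model_ema = getattr(self.trainer, self.model_ema_name)
        print("saving net:", net, "ema:", model_ema)


class ToyTrainer:
    def __init__(self):
        self.model = {"w": 1.0}
        self.model_ema = None          # not built yet when the saver is created
        self.model_saver = ToySaver(self)


trainer = ToyTrainer()
trainer.model_ema = {"w": 0.99}        # built after the saver; still picked up
trainer.model_saver.save()
```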