diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py
index 942391d393fe0074e8d57c28da3394fd186e4e31..9915dc2c40536599c7d72a2618b00fed5916f99b 100644
--- a/ppcls/data/__init__.py
+++ b/ppcls/data/__init__.py
@@ -88,25 +88,32 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
     random.seed(worker_seed)
 
 
-def build(config, mode, use_dali=False, seed=None):
+def build_dataloader(config, mode, seed=None):
     assert mode in [
         'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
     ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
-    assert mode in config.keys(), "{} config not in yaml".format(mode)
+    assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(
+        mode)
+
+    dataloader_config = config["DataLoader"][mode]
+    class_num = config["Arch"].get("class_num", None)
+    epochs = config["Global"]["epochs"]
+    use_dali = config["Global"].get("use_dali", False)
+    num_workers = dataloader_config['loader']["num_workers"]
+    use_shared_memory = dataloader_config['loader']["use_shared_memory"]
+
     # build dataset
     if use_dali:
         from ppcls.data.dataloader.dali import dali_dataloader
         return dali_dataloader(
-            config,
+            config["DataLoader"],
             mode,
             paddle.device.get_device(),
-            num_threads=config[mode]['loader']["num_workers"],
+            num_threads=num_workers,
             seed=seed,
             enable_fuse=True)
 
-    class_num = config.get("class_num", None)
-    epochs = config.get("epochs", None)
-    config_dataset = config[mode]['dataset']
+    config_dataset = dataloader_config['dataset']
     config_dataset = copy.deepcopy(config_dataset)
     dataset_name = config_dataset.pop('name')
     if 'batch_transform_ops' in config_dataset:
@@ -119,7 +126,7 @@ def build(config, mode, use_dali=False, seed=None):
     logger.debug("build dataset({}) success...".format(dataset))
 
     # build sampler
-    config_sampler = config[mode]['sampler']
+    config_sampler = dataloader_config['sampler']
     if config_sampler and "name" not in config_sampler:
         batch_sampler = None
         batch_size = config_sampler["batch_size"]
@@ -153,11 +160,6 @@ def build(config, mode, use_dali=False, seed=None):
     else:
         batch_collate_fn = None
 
-    # build dataloader
-    config_loader = config[mode]['loader']
-    num_workers = config_loader["num_workers"]
-    use_shared_memory = config_loader["use_shared_memory"]
-
     init_fn = partial(
         worker_init_fn,
         num_workers=num_workers,
@@ -194,78 +196,36 @@ def build(config, mode, use_dali=False, seed=None):
     data_loader.max_iter = max_iter
     data_loader.total_samples = total_samples
 
-    logger.debug("build data_loader({}) success...".format(data_loader))
-    return data_loader
-
-
-# TODO(gaotingquan): perf
-class DataIterator(object):
-    def __init__(self, dataloader, use_dali=False):
-        self.dataloader = dataloader
-        self.use_dali = use_dali
-        self.iterator = iter(dataloader)
-        self.max_iter = dataloader.max_iter
-        self.total_samples = dataloader.total_samples
-
-    def get_batch(self):
-        # fetch data batch from dataloader
-        try:
-            batch = next(self.iterator)
-        except Exception:
-            # NOTE: reset DALI dataloader manually
-            if self.use_dali:
-                self.dataloader.reset()
-                self.iterator = iter(self.dataloader)
-                batch = next(self.iterator)
-        return batch
-
-
-def build_dataloader(config, mode):
-    class_num = config["Arch"].get("class_num", None)
-    config["DataLoader"].update({"class_num": class_num})
-    config["DataLoader"].update({"epochs": config["Global"]["epochs"]})
-
-    use_dali = config["Global"].get("use_dali", False)
-    dataloader_dict = {
-        "Train": None,
-        "UnLabelTrain": None,
-        "Eval": None,
-        "Query": None,
-        "Gallery": None,
-        "GalleryQuery": None
-    }
-    if mode == 'train':
-        train_dataloader = build(
-            config["DataLoader"], "Train", use_dali, seed=None)
-
-        if config["DataLoader"]["Train"].get("max_iter", None):
+    # TODO(gaotingquan): mv to build_sampler
+    if mode == "train":
+        if dataloader_config["Train"].get("max_iter", None):
             # set max iteration per epoch mannualy, when training by iteration(s), such as XBM, FixMatch.
             max_iter = config["Train"].get("max_iter")
             update_freq = config["Global"].get("update_freq", 1)
-            max_iter = train_dataloader.max_iter // update_freq * update_freq
-            train_dataloader.max_iter = max_iter
-        if config["DataLoader"]["Train"].get("convert_iterator", True):
-            train_dataloader = DataIterator(train_dataloader, use_dali)
-        dataloader_dict["Train"] = train_dataloader
+            max_iter = data_loader.max_iter // update_freq * update_freq
+            data_loader.max_iter = max_iter
+
+    logger.debug("build data_loader({}) success...".format(data_loader))
+    return data_loader
 
-    if config["DataLoader"].get('UnLabelTrain', None) is not None:
-        dataloader_dict["UnLabelTrain"] = build(
-            config["DataLoader"], "UnLabelTrain", use_dali, seed=None)
 
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        task = config["Global"].get("task", "classification")
-        if task in ["classification", "adaface"]:
-            dataloader_dict["Eval"] = build(
-                config["DataLoader"], "Eval", use_dali, seed=None)
-        elif task == "retrieval":
-            if len(config["DataLoader"]["Eval"].keys()) == 1:
-                key = list(config["DataLoader"]["Eval"].keys())[0]
-                dataloader_dict["GalleryQuery"] = build(
-                    config["DataLoader"]["Eval"], key, use_dali)
-            else:
-                dataloader_dict["Gallery"] = build(
-                    config["DataLoader"]["Eval"], "Gallery", use_dali)
-                dataloader_dict["Query"] = build(config["DataLoader"]["Eval"],
-                                                 "Query", use_dali)
-    return dataloader_dict
+# # TODO(gaotingquan): the length of dataloader should be determined by sampler
+# class DataIterator(object):
+#     def __init__(self, dataloader, use_dali=False):
+#         self.dataloader = dataloader
+#         self.use_dali = use_dali
+#         self.iterator = iter(dataloader)
+#         self.max_iter = dataloader.max_iter
+#         self.total_samples = dataloader.total_samples
+
+#     def get_batch(self):
+#         # fetch data batch from dataloader
+#         try:
+#             batch = next(self.iterator)
+#         except Exception:
+#             # NOTE: reset DALI dataloader manually
+#             if self.use_dali:
+#                 self.dataloader.reset()
+#                 self.iterator = iter(self.dataloader)
+#                 batch = next(self.iterator)
+#         return batch
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index df2da7a5d921e1f1443d23c41da84936f763a9ff..e8424af2e1ddb0c494cb03d6955cb7ac956e84d9 100755
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -60,6 +60,8 @@ class Engine(object):
         # load_pretrain
         self._init_pretrained()
 
+        self._init_amp()
+
         # init train_func and eval_func
         self.eval = build_eval_func(
             self.config, mode=self.mode, model=self.model)
@@ -69,7 +71,7 @@ class Engine(object):
         # for distributed
         self._init_dist()
 
-        print_config(config)
+        print_config(self.config)
 
     @paddle.no_grad()
     def infer(self):
diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py
index ab1378bc23d2f85169cbb3faa98e78a2b9b509a1..07f1703f8acdd29b6a69680e366a9824a731a101 100644
--- a/ppcls/engine/evaluation/classification.py
+++ b/ppcls/engine/evaluation/classification.py
@@ -29,10 +29,11 @@ class ClassEval(object):
     def __init__(self, config, mode, model):
         self.config = config
         self.model = model
+        self.print_batch_step = self.config["Global"]["print_batch_step"]
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.eval_metric_func = build_metrics(config, "eval")
-        self.eval_dataloader = build_dataloader(config, "eval")
-        self.eval_loss_func = build_loss(config, "eval")
+        self.eval_metric_func = build_metrics(self.config, "eval")
+        self.eval_dataloader = build_dataloader(self.config, "Eval")
+        self.eval_loss_func = build_loss(self.config, "Eval")
        self.output_info = dict()
 
     @paddle.no_grad()
@@ -48,13 +49,12 @@ class ClassEval(object):
             "reader_cost": AverageMeter(
                 "reader_cost", ".5f", postfix=" s,"),
         }
-        print_batch_step = self.config["Global"]["print_batch_step"]
 
         tic = time.time()
-        total_samples = self.eval_dataloader["Eval"].total_samples
+        total_samples = self.eval_dataloader.total_samples
         accum_samples = 0
-        max_iter = self.eval_dataloader["Eval"].max_iter
-        for iter_id, batch in enumerate(self.eval_dataloader["Eval"]):
+        max_iter = self.eval_dataloader.max_iter
+        for iter_id, batch in enumerate(self.eval_dataloader):
             if iter_id >= max_iter:
                 break
             if iter_id == 5:
@@ -130,7 +130,7 @@ class ClassEval(object):
                 self.eval_metric_func(preds, labels)
             time_info["batch_cost"].update(time.time() - tic)
 
-            if iter_id % print_batch_step == 0:
+            if iter_id % self.print_batch_step == 0:
                 time_msg = "s, ".join([
                     "{}: {:.5f}".format(key, time_info[key].avg)
                     for key in time_info
@@ -153,7 +153,7 @@ class ClassEval(object):
             tic = time.time()
 
         if self.use_dali:
-            self.eval_dataloader["Eval"].reset()
+            self.eval_dataloader.reset()
 
         if "ATTRMetric" in self.config["Metric"]["Eval"][0]:
             metric_msg = ", ".join([
diff --git a/ppcls/engine/train/__init__.py b/ppcls/engine/train/__init__.py
index eedba871ef2a86b8c0510b33e42d63fc72d8f04b..54ae0cc1ad39f44efe5c7d46819e3f45a80c5a5f 100644
--- a/ppcls/engine/train/__init__.py
+++ b/ppcls/engine/train/__init__.py
@@ -25,7 +25,7 @@ def build_train_func(config, mode, model, eval_func):
     train_mode = config["Global"].get("task", None)
     if train_mode is None:
         config["Global"]["task"] = "classification"
-        return ClassTrainer(config, mode, model, eval_func)
+        return ClassTrainer(config, model, eval_func)
     else:
         return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
             config, mode, model, eval_func)
diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py
index d6837fcc3e7ac05cfeab772b98b95680c19d757d..d6a5c8228fa6acff2b923361613b4650219f5fa3 100644
--- a/ppcls/engine/train/classification.py
+++ b/ppcls/engine/train/classification.py
@@ -28,7 +28,7 @@ from ...utils.save_load import init_model, ModelSaver
 
 
 class ClassTrainer(object):
-    def __init__(self, config, mode, model, eval_func):
+    def __init__(self, config, model, eval_func):
         self.config = config
         self.model = model
         self.eval = eval_func
@@ -41,32 +41,32 @@ class ClassTrainer(object):
         # gradient accumulation
         self.update_freq = self.config["Global"].get("update_freq", 1)
 
-        # AMP training and evaluating
-        # self._init_amp()
+        # TODO(gaotingquan): mv to build_model
+        # build EMA model
+        self.model_ema = self._build_ema_model()
 
         # build dataloader
         self.use_dali = self.config["Global"].get("use_dali", False)
-        self.dataloader_dict = build_dataloader(self.config, mode)
+        self.dataloader = build_dataloader(self.config, "Train")
 
         # build loss
-        self.train_loss_func, self.unlabel_train_loss_func = build_loss(
-            self.config, mode)
+        self.loss_func = build_loss(config, "Train")
 
         # build metric
         self.train_metric_func = build_metrics(config, "train")
 
         # build optimizer
         self.optimizer, self.lr_sch = build_optimizer(
-            self.config, self.dataloader_dict["Train"].max_iter,
-            [self.model, self.train_loss_func], self.update_freq)
+            self.config, self.dataloader.max_iter,
+            [self.model, self.loss_func], self.update_freq)
 
         # build model saver
         self.model_saver = ModelSaver(
-            self,
-            net_name="model",
-            loss_name="train_loss_func",
-            opt_name="optimizer",
-            model_ema_name="model_ema")
+            config=self.config,
+            net=self.model,
+            loss=self.loss_func,
+            opt=self.optimizer,
+            model_ema=self.model_ema if self.model_ema else None)
 
         # build best metric
         self.best_metric = {
@@ -84,8 +84,6 @@ class ClassTrainer(object):
                 "reader_cost", ".5f", postfix=" s,"),
         }
 
-        # build EMA model
-        self.model_ema = self._build_ema_model()
         self._init_checkpoints()
 
         # for visualdl
@@ -173,11 +171,10 @@ class ClassTrainer(object):
             }, prefix="latest")
 
     def train_epoch(self, epoch_id):
+        self.model.train()
        tic = time.time()
 
-        for iter_id in range(self.dataloader_dict["Train"].max_iter):
-            batch = self.dataloader_dict["Train"].get_batch()
-
+        for iter_id, batch in enumerate(self.dataloader):
             profiler.add_profiler_step(self.config["profiler_options"])
             if iter_id == 5:
                 for key in self.time_info:
@@ -190,7 +187,7 @@ class ClassTrainer(object):
             self.global_step += 1
 
             out = self.model(batch)
-            loss_dict = self.train_loss_func(out, batch[1])
+            loss_dict = self.loss_func(out, batch[1])
             # TODO(gaotingquan): mv update_freq to loss and optimizer
             loss = loss_dict["loss"] / self.update_freq
             loss.backward()
diff --git a/ppcls/engine/train/utils.py b/ppcls/engine/train/utils.py
index 087e943a49913204e7475fff316a36a3d0a9a477..88ceb9bb3ed9d1bee87c3121f00b7fd4f7537dc3 100644
--- a/ppcls/engine/train/utils.py
+++ b/ppcls/engine/train/utils.py
@@ -55,13 +55,13 @@ def log_info(trainer, batch_size, epoch_id, iter_id):
         batch_size / trainer.time_info["batch_cost"].avg)
     eta_sec = ((trainer.config["Global"]["epochs"] - epoch_id + 1
-                ) * trainer.dataloader_dict["Train"].max_iter - iter_id
+                ) * len(trainer.dataloader) - iter_id
                ) * trainer.time_info["batch_cost"].avg
     eta_msg = "eta: {:s}".format(str(datetime.timedelta(seconds=int(eta_sec))))
     logger.info("[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".format(
-        epoch_id, trainer.config["Global"][
-            "epochs"], iter_id, trainer.dataloader_dict["Train"]
-        .max_iter, lr_msg, metric_msg, time_msg, ips_msg, eta_msg))
+        epoch_id, trainer.config["Global"]["epochs"], iter_id,
+        len(trainer.dataloader), lr_msg, metric_msg, time_msg, ips_msg,
+        eta_msg))
 
     for i, lr in enumerate(trainer.lr_sch):
         logger.scaler(
diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py
index dbb86f25751a040e7a98310066ad39941da782b0..006b54eb3f7498b3dd860bc143b51488c7aee340 100644
--- a/ppcls/loss/__init__.py
+++ b/ppcls/loss/__init__.py
@@ -50,8 +50,9 @@ from .metabinloss import IntraDomainScatterLoss
 
 
 class CombinedLoss(nn.Layer):
-    def __init__(self, config_list, amp_config=None):
+    def __init__(self, config_list, mode, amp_config=None):
         super().__init__()
+        self.mode = mode
         loss_func = []
         self.loss_weight = []
         assert isinstance(config_list, list), (
@@ -68,11 +69,13 @@ class CombinedLoss(nn.Layer):
         self.loss_func = nn.LayerList(loss_func)
         logger.debug("build loss {} success.".format(loss_func))
 
+        self.scaler = None
         if amp_config:
-            self.scaler = paddle.amp.GradScaler(
-                init_loss_scaling=config["AMP"].get("scale_loss", 1.0),
-                use_dynamic_loss_scaling=config["AMP"].get(
-                    "use_dynamic_loss_scaling", False))
+            if self.mode == "Train" or AMPForwardDecorator.amp_eval:
+                self.scaler = paddle.amp.GradScaler(
+                    init_loss_scaling=amp_config.get("scale_loss", 1.0),
+                    use_dynamic_loss_scaling=amp_config.get(
+                        "use_dynamic_loss_scaling", False))
 
     @AMP_forward_decorator
     def __call__(self, input, batch):
@@ -89,49 +92,26 @@ class CombinedLoss(nn.Layer):
             loss = {key: loss[key] * weight for key in loss}
             loss_dict.update(loss)
         loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
-        # TODO(gaotingquan): if amp_eval & eval_loss ?
-        if AMPForwardDecorator.amp_level:
+
+        if self.scaler:
             self.scaler(loss_dict["loss"])
         return loss_dict
 
 
-def build_loss(config, mode="train"):
-    if mode == "train":
-        label_loss_info = config["Loss"]["Train"]
-        if label_loss_info:
-            train_loss_func = CombinedLoss(
-                copy.deepcopy(label_loss_info), config.get("AMP", None))
-        unlabel_loss_info = config.get("UnLabelLoss", {}).get("Train", None)
-        if unlabel_loss_info:
-            unlabel_train_loss_func = CombinedLoss(
-                copy.deepcopy(unlabel_loss_info), config.get("AMP", None))
-        else:
-            unlabel_train_loss_func = None
+def build_loss(config, mode):
+    if config["Loss"][mode] is None:
+        return None
+    module_class = CombinedLoss(
+        copy.deepcopy(config["Loss"][mode]),
+        mode,
+        amp_config=config.get("AMP", None))
 
-        if AMPForwardDecorator.amp_level is not None:
-            train_loss_func = paddle.amp.decorate(
-                models=train_loss_func,
-                level=AMPForwardDecorator.amp_level,
-                save_dtype='float32')
-            # TODO(gaotingquan): unlabel_loss_info may be None
-            unlabel_train_loss_func = paddle.amp.decorate(
-                models=unlabel_train_loss_func,
+    if AMPForwardDecorator.amp_level is not None:
+        if mode == "Train" or AMPForwardDecorator.amp_eval:
+            module_class = paddle.amp.decorate(
+                models=module_class,
                 level=AMPForwardDecorator.amp_level,
                 save_dtype='float32')
-        return train_loss_func, unlabel_train_loss_func
-    if mode == "eval" or (mode == "train" and
-                          config["Global"]["eval_during_train"]):
-        loss_config = config.get("Loss", None)
-        if loss_config is not None:
-            loss_config = loss_config.get("Eval")
-            if loss_config is not None:
-                eval_loss_func = CombinedLoss(
-                    copy.deepcopy(loss_config), config.get("AMP", None))
-
-                if AMPForwardDecorator.amp_level is not None and AMPForwardDecorator.amp_eval:
-                    eval_loss_func = paddle.amp.decorate(
-                        models=eval_loss_func,
-                        level=AMPForwardDecorator.amp_level,
-                        save_dtype='float32')
-        return eval_loss_func
+    logger.debug("build loss {} success.".format(module_class))
+    return module_class
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 776459fe566110b9b603d3e1a6ea04e5f8046d17..7606128248344e279d042662aca9b241c0978f1b 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -150,21 +150,16 @@ def _extract_student_weights(all_params, student_prefix="Student."):
 
 
 class ModelSaver(object):
-    def __init__(self,
-                 trainer,
-                 net_name="model",
-                 loss_name="train_loss_func",
-                 opt_name="optimizer",
-                 model_ema_name="model_ema"):
+    def __init__(self, config, net, loss, opt, model_ema):
         # net, loss, opt, model_ema, output_dir,
-        self.trainer = trainer
-        self.net_name = net_name
-        self.loss_name = loss_name
-        self.opt_name = opt_name
-        self.model_ema_name = model_ema_name
-
-        arch_name = trainer.config["Arch"]["name"]
-        self.output_dir = os.path.join(trainer.output_dir, arch_name)
+        self.net = net
+        self.loss = loss
+        self.opt = opt
+        self.model_ema = model_ema
+
+        arch_name = config["Arch"]["name"]
+        self.output_dir = os.path.join(config["Global"]["output_dir"],
+                                       arch_name)
         _mkdir_if_not_exist(self.output_dir)
 
     def save(self, metric_info, prefix='ppcls', save_student_model=False):
@@ -174,8 +169,8 @@ class ModelSaver(object):
         save_dir = os.path.join(self.output_dir, prefix)
 
-        params_state_dict = getattr(self.trainer, self.net_name).state_dict()
-        loss = getattr(self.trainer, self.loss_name)
+        params_state_dict = self.net.state_dict()
+        loss = self.loss
         if loss is not None:
             loss_state_dict = loss.state_dict()
             keys_inter = set(params_state_dict.keys()) & set(
@@ -190,11 +185,11 @@ class ModelSaver(object):
             paddle.save(s_params, save_dir + "_student.pdparams")
 
         paddle.save(params_state_dict, save_dir + ".pdparams")
-        model_ema = getattr(self.trainer, self.model_ema_name)
+        model_ema = self.model_ema
         if model_ema is not None:
             paddle.save(model_ema.module.state_dict(),
                         save_dir + ".ema.pdparams")
-        optimizer = getattr(self.trainer, self.opt_name)
+        optimizer = self.opt
         paddle.save([opt.state_dict() for opt in optimizer],
                     save_dir + ".pdopt")
         paddle.save(metric_info, save_dir + ".pdstates")