From 33e3f479286e123caa88013bb361a247d1890a81 Mon Sep 17 00:00:00 2001
From: Manan Goel
Date: Thu, 12 May 2022 08:38:13 +0530
Subject: [PATCH] Integration of the Weights & Biases Metric Logger (#5886)

* Implementation of WandbLogger callback to log metrics and model checkpoints

* Added documentation on how to use WandbLogger and VisualDL

* Minor bug fix and pull from upstream

* Removed pass statement for cleanup

* Update logging_en.md

Fixed documentation

* Update trainer.py
---
 README_en.md                 |   2 +
 docs/tutorials/logging_en.md |  46 +++++++++++
 ppdet/engine/callbacks.py    | 147 ++++++++++++++++++++++++++++++++++-
 ppdet/engine/trainer.py      |   4 +-
 tools/train.py               |   5 ++
 5 files changed, 202 insertions(+), 2 deletions(-)
 create mode 100644 docs/tutorials/logging_en.md

diff --git a/README_en.md b/README_en.md
index f30cae8c9..2028cfc5e 100644
--- a/README_en.md
+++ b/README_en.md
@@ -284,6 +284,8 @@ The relationship between COCO mAP and FPS on Qualcomm Snapdragon 865 of represen
 
 - [Prune/Quant/Distill](configs/slim)
 
+- [Metric Logging during Model Training](docs/tutorials/logging_en.md)
+
 - Inference and Deployment
 
 - [Export model for inference](deploy/EXPORT_MODEL_en.md)
diff --git a/docs/tutorials/logging_en.md b/docs/tutorials/logging_en.md
new file mode 100644
index 000000000..b45ceba69
--- /dev/null
+++ b/docs/tutorials/logging_en.md
@@ -0,0 +1,46 @@
+# Logging
+
+This document describes how to track metrics and visualize model performance during training. The library currently supports [VisualDL](https://www.paddlepaddle.org.cn/documentation/docs/en/guides/03_VisualDL/visualdl_usage_en.html) and [Weights & Biases](https://docs.wandb.ai).
+
+## VisualDL
+Logging to VisualDL is supported only in Python >= 3.5. To install VisualDL:
+
+```
+pip install visualdl
+```
+
+PaddleDetection uses a callback to log training metrics at the end of every step and validation metrics at the end of every epoch. To visualize them with VisualDL, add the `--use_vdl` flag to the training command and `--vdl_log_dir` to set the directory that stores the records.
+
+For example
+
+```
+python tools/train.py -c config.yml --use_vdl --vdl_log_dir ./logs
+```
+
+Another possible way is to add the aforementioned flags to the `config.yml` file.
+
+## Weights & Biases
+W&B is an MLOps tool for experiment tracking, dataset/model versioning, visualizing results, and collaborating with colleagues. A W&B logger is integrated directly into PaddleDetection; to use it, first install the wandb SDK and log in to your wandb account:
+
+```
+pip install wandb
+wandb login
+```
+
+To log metrics to wandb during training, add the `--use_wandb` flag to the training command. Any other arguments for the W&B logger can be provided like this:
+
+```
+python tools/train.py -c config.yml --use_wandb -o wandb_project=MyDetector wandb_entity=MyTeam wandb_save_dir=./logs
+```
+
+The arguments to the W&B logger must be preceded by `-o`, and each individual argument must carry the prefix `wandb_`, which is stripped off before the values are passed to wandb.
+
+If this is too tedious, an alternative is to add the arguments to the `config.yml` file under the `wandb` header. For example
+
+```
+use_wandb: True
+wandb:
+  project: MyProject
+  entity: MyTeam
+  save_dir: ./logs
+```
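To make the `-o wandb_*` plumbing above concrete: the `WandbCallback` introduced in the next file collects every config key carrying the `wandb_` prefix, strips the prefix, and passes the result to `wandb.init`. A minimal standalone sketch of that mapping, using an illustrative flat dict rather than PaddleDetection's real config object:

```python
# Hypothetical flat config, as produced by `-o wandb_project=MyDetector ...`.
cfg = {
    "use_wandb": True,
    "wandb_project": "MyDetector",
    "wandb_entity": "MyTeam",
    "wandb_save_dir": "./logs",
}

# Collect the "wandb_"-prefixed keys and strip the prefix, the same way the
# callback assembles its wandb.init() keyword arguments.
wandb_params = {
    k[len("wandb_"):]: v
    for k, v in cfg.items() if k.startswith("wandb_")
}

print(wandb_params)
# {'project': 'MyDetector', 'entity': 'MyTeam', 'save_dir': './logs'}
# wandb.init(**wandb_params) would then start the run with these settings.
```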
diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py
index 77ca94602..09683d18b 100644
--- a/ppdet/engine/callbacks.py
+++ b/ppdet/engine/callbacks.py
@@ -198,7 +198,7 @@ class Checkpointer(Callback):
                                 "training iterations being too few or not " \
                                 "loading the correct weights.")
                             return
-                        if map_res[key][0] > self.best_ap:
+                        if map_res[key][0] >= self.best_ap:
                             self.best_ap = map_res[key][0]
                             save_name = 'best_model'
                             weight = self.weight.state_dict()
@@ -288,6 +288,151 @@ class VisualDLWriter(Callback):
                                              self.vdl_mAP_step)
                 self.vdl_mAP_step += 1
 
+class WandbCallback(Callback):
+    def __init__(self, model):
+        super(WandbCallback, self).__init__(model)
+
+        try:
+            import wandb
+            self.wandb = wandb
+        except Exception as e:
+            logger.error('wandb not found, please install wandb. '
+                         'Use: `pip install wandb`.')
+            raise e
+
+        self.wandb_params = model.cfg.get('wandb', None)
+        self.save_dir = os.path.join(self.model.cfg.save_dir,
+                                     self.model.cfg.filename)
+        if self.wandb_params is None:
+            self.wandb_params = {}
+        for k, v in model.cfg.items():
+            if k.startswith("wandb_"):
+                # Slice off the prefix instead of str.lstrip: lstrip treats
+                # its argument as a character set and would also eat leading
+                # characters of the remaining key (e.g. "wandb_name" -> "me").
+                self.wandb_params.update({k[len("wandb_"):]: v})
+
+        self._run = None
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
+            _ = self.run  # touch the property so the run is created eagerly
+            self.run.config.update(self.model.cfg)
+            self.run.define_metric("epoch")
+            self.run.define_metric("eval/*", step_metric="epoch")
+
+        self.best_ap = 0
+
+    @property
+    def run(self):
+        if self._run is None:
+            if self.wandb.run is not None:
+                logger.info("There is an ongoing wandb run which will be used "
+                            "for logging. Please use `wandb.finish()` to end it "
+                            "if this behaviour is not intended.")
+                self._run = self.wandb.run
+            else:
+                self._run = self.wandb.init(**self.wandb_params)
+        return self._run
+
+    def save_model(self,
+                   optimizer,
+                   save_dir,
+                   save_name,
+                   last_epoch,
+                   ema_model=None,
+                   ap=None,
+                   tags=None):
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
+            model_path = os.path.join(save_dir, save_name)
+            metadata = {}
+            metadata["last_epoch"] = last_epoch
+            if ap:
+                metadata["ap"] = ap
+            if ema_model:
+                # The .pdema snapshot only exists when EMA is enabled, so log
+                # the EMA weights alongside the regular weights.
+                ema_artifact = self.wandb.Artifact(
+                    name="ema_model-{}".format(self.run.id),
+                    type="model",
+                    metadata=metadata)
+                model_artifact = self.wandb.Artifact(
+                    name="model-{}".format(self.run.id),
+                    type="model",
+                    metadata=metadata)
+
+                ema_artifact.add_file(model_path + ".pdema", name="model_ema")
+                model_artifact.add_file(model_path + ".pdparams", name="model")
+
+                self.run.log_artifact(ema_artifact, aliases=tags)
+                self.run.log_artifact(model_artifact, aliases=tags)
+            else:
+                model_artifact = self.wandb.Artifact(
+                    name="model-{}".format(self.run.id),
+                    type="model",
+                    metadata=metadata)
+                model_artifact.add_file(model_path + ".pdparams", name="model")
+                self.run.log_artifact(model_artifact, aliases=tags)
+
+    def on_step_end(self, status):
+        mode = status['mode']
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
+            if mode == 'train':
+                # 'training_staus' (sic) is the key used for the smoothed
+                # training statistics throughout ppdet/engine.
+                training_status = status['training_staus'].get()
+                metrics = {
+                    "train/" + k: float(v)
+                    for k, v in training_status.items()
+                }
+                self.run.log(metrics)
+
+    def on_epoch_end(self, status):
+        mode = status['mode']
+        epoch_id = status['epoch_id']
+        save_name = None
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
+            if mode == 'train':
+                end_epoch = self.model.cfg.epoch
+                if (epoch_id + 1) % self.model.cfg.snapshot_epoch == 0 \
+                        or epoch_id == end_epoch - 1:
+                    save_name = str(
+                        epoch_id) if epoch_id != end_epoch - 1 else "model_final"
+                    tags = ["latest", "epoch_{}".format(epoch_id)]
+                    self.save_model(
+                        self.model.optimizer,
+                        self.save_dir,
+                        save_name,
+                        epoch_id + 1,
+                        self.model.use_ema,
+                        tags=tags)
+            if mode == 'eval':
+                merged_dict = {}
+                for metric in self.model._metrics:
+                    for key, map_value in metric.get_results().items():
+                        merged_dict["eval/{}-mAP".format(key)] = map_value[0]
+                merged_dict["epoch"] = status["epoch_id"]
+                self.run.log(merged_dict)
+
+                if 'save_best_model' in status and status['save_best_model']:
+                    for metric in self.model._metrics:
+                        map_res = metric.get_results()
+                        if 'bbox' in map_res:
+                            key = 'bbox'
+                        elif 'keypoint' in map_res:
+                            key = 'keypoint'
+                        else:
+                            key = 'mask'
+                        if key not in map_res:
+                            logger.warning("Evaluation results empty, this may be due to " \
+                                           "training iterations being too few or not " \
+                                           "loading the correct weights.")
+                            return
+                        if map_res[key][0] >= self.best_ap:
+                            self.best_ap = map_res[key][0]
+                            save_name = 'best_model'
+                            tags = ["best", "epoch_{}".format(epoch_id)]
+                            self.save_model(
+                                self.model.optimizer,
+                                self.save_dir,
+                                save_name,
+                                last_epoch=epoch_id + 1,
+                                ema_model=self.model.use_ema,
+                                ap=self.best_ap,
+                                tags=tags)
+
+    def on_train_end(self, status):
+        if dist.get_world_size() < 2 or dist.get_rank() == 0:
+            self.run.finish()
+
 
 class SniperProposalsGenerator(Callback):
     def __init__(self, model):
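Because `save_model` above publishes checkpoints as W&B artifacts named `model-<run id>` (and `ema_model-<run id>` when EMA is enabled) with aliases such as `latest`, `best`, and `epoch_{N}`, a logged checkpoint can later be fetched back through the public artifact API. A minimal sketch; the entity, project, and run id below are placeholders:

```python
import wandb

api = wandb.Api()
# Artifact path: <entity>/<project>/model-<run id>:<alias>; all placeholders.
artifact = api.artifact("MyTeam/MyDetector/model-abc123xy:best")

print(artifact.metadata)         # e.g. {'last_epoch': 12, 'ap': 0.51}
local_dir = artifact.download()  # directory containing the file saved as "model"
```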
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index de188e8e1..29ca286ed 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -46,7 +46,7 @@ from ppdet.data.source.category import get_categories
 import ppdet.utils.stats as stats
 from ppdet.utils import profiler
 
-from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator
+from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback
 from .export_utils import _dump_infer_config, _prune_input_spec
 
 from ppdet.utils.logger import setup_logger
@@ -184,6 +184,8 @@ class Trainer(object):
                 self._callbacks.append(VisualDLWriter(self))
             if self.cfg.get('save_proposals', False):
                 self._callbacks.append(SniperProposalsGenerator(self))
+            if self.cfg.get('use_wandb', False) or 'wandb' in self.cfg:
+                self._callbacks.append(WandbCallback(self))
             self._compose_callback = ComposeCallback(self._callbacks)
         elif self.mode == 'eval':
             self._callbacks = [LogPrinter(self)]
diff --git a/tools/train.py b/tools/train.py
index 8e4977e77..43f883592 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -76,6 +76,11 @@ def parse_args():
         type=str,
         default="vdl_log_dir/scalar",
         help='VisualDL logging directory for scalar.')
+    parser.add_argument(
+        "--use_wandb",
+        action='store_true',
+        default=False,
+        help="whether to log training metrics to Weights & Biases.")
     parser.add_argument(
         '--save_prediction_only',
         action='store_true',
-- 
GitLab
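A closing note on the two `define_metric` calls in `WandbCallback.__init__`: they tell W&B to plot every `eval/*` series against the logged `epoch` value rather than the default global step, which is what keeps per-epoch evaluation metrics aligned with training epochs. The same pattern in isolation, with an illustrative project name and made-up values:

```python
import wandb

run = wandb.init(project="define-metric-demo")    # illustrative project name
run.define_metric("epoch")                        # declare the custom x-axis
run.define_metric("eval/*", step_metric="epoch")  # plot eval/* against it

for epoch in range(3):
    # Mirrors WandbCallback.on_epoch_end: eval metrics are logged together
    # with the epoch they belong to.
    run.log({"eval/bbox-mAP": 0.2 + 0.1 * epoch, "epoch": epoch})
run.finish()
```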