import logging
import os

from .base_logger import BaseLogger

# Module-level logger used by WandbLogger.run. It was previously referenced
# without ever being defined, which raised NameError whenever an existing
# wandb run was reused.
logger = logging.getLogger(__name__)


class WandbLogger(BaseLogger):
    """Logger backend that sends metrics and model checkpoints to wandb.

    A wandb run is obtained eagerly during construction (see :attr:`run`):
    either the run already active in this process is reused, or a new one is
    started with ``wandb.init``.
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """Create the logger and start (or attach to) a wandb run.

        Args:
            project: wandb project name, forwarded to ``wandb.init``.
            name: run display name, forwarded to ``wandb.init``.
            id: run id, forwarded to ``wandb.init`` (used together with
                ``resume="allow"``).
            entity: wandb entity, forwarded to ``wandb.init``.
            save_dir: forwarded to ``wandb.init`` as ``dir``; also the
                directory ``log_model`` reads checkpoints from.
            config: if truthy, passed to ``run.config.update`` once the run
                is available.
            **kwargs: extra ``wandb.init`` arguments; they override the
                defaults assembled from the parameters above.

        Raises:
            ModuleNotFoundError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`") from err
        self.wandb = wandb

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity

        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow")
        # Caller-supplied kwargs take precedence over the defaults above.
        self._wandb_init.update(kwargs)

        # Touch the property so the run is created/attached immediately.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run, creating or attaching to it on first access.

        If ``wandb.run`` is already set in this process, that run is reused
        (and an info message is logged); otherwise ``wandb.init`` is called
        with the arguments collected in ``__init__``. The result is cached
        on the instance.
        """
        if self._run is None:
            if self.wandb.run is not None:
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`.")
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metrics to the run.

        Args:
            metrics: mapping of metric name to value.
            prefix: optional section name. When non-empty, every key is logged
                as ``"<prefix>/<key>"``; when empty or None, keys are logged
                unchanged (previously they gained a spurious leading ``"/"``).
            step: optional step forwarded to ``run.log``.
        """
        if prefix:
            updated_metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
        else:
            # Copy so the caller's dict is never handed to wandb directly,
            # matching the original, which always built a fresh dict.
            updated_metrics = dict(metrics)
        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a wandb model artifact.

        The artifact is named ``model-<run id>`` and the file is stored inside
        it as ``model_ckpt.pdparams``. It is tagged with the alias ``prefix``,
        plus ``"best"`` when ``is_best`` is true.

        Args:
            is_best: whether to add the ``"best"`` alias.
            prefix: checkpoint file stem, also used as an alias.
            metadata: optional metadata attached to the artifact.

        Note:
            ``save_dir`` must have been provided at construction time;
            ``os.path.join`` fails if it is None.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            f'model-{self.run.id}', type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")

        aliases = [prefix]
        if is_best:
            aliases.append("best")

        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run."""
        self.run.finish()