# wandb_logger.py
import logging
import os

from .base_logger import BaseLogger

class WandbLogger(BaseLogger):
    """Logger that sends metrics and model checkpoints to Weights & Biases.

    Constructing the logger immediately attaches to a wandb run: the run that
    is already active in the process if there is one, otherwise a new run
    started with ``wandb.init``.

    Args:
        project: Passed to ``wandb.init`` as ``project``.
        name: Passed to ``wandb.init`` as ``name``.
        id: Passed to ``wandb.init`` as ``id``.
        entity: Passed to ``wandb.init`` as ``entity``.
        save_dir: Passed to ``wandb.init`` as ``dir``; also the directory
            that ``log_model`` reads checkpoint files from.
        config: Optional mapping merged into ``run.config`` when truthy.
        **kwargs: Extra keyword arguments for ``wandb.init``. They are applied
            last, so they override the values above (including the default
            ``resume="allow"``).

    Raises:
        ModuleNotFoundError: If the ``wandb`` package is not installed.
    """

    def __init__(self,
        project=None,
        name=None,
        id=None,
        entity=None,
        save_dir=None,
        config=None,
        **kwargs):
        # wandb is imported lazily so the package is only required when this
        # logger is actually used.
        try:
            import wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            ) from err
        self.wandb = wandb

        self.project = project
        self.name = name
        self.id = id
        self.entity = entity
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self._run = None

        # Arguments for `wandb.init`; only used when no run is active yet.
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow",
        )
        self._wandb_init.update(**kwargs)

        # Touch the property once so the run exists right after construction.
        active_run = self.run

        if self.config:
            active_run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run used by this logger, creating it on first use.

        If ``wandb.run`` is already set, that run is reused and the stored
        init arguments are ignored; otherwise ``wandb.init`` is called with
        them. The result is cached on the instance.
        """
        if self._run is None:
            if self.wandb.run is not None:
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metrics to the run.

        Each key is rewritten to ``"<prefix lower-cased>/<key>"``. When
        ``prefix`` is falsy the prefix part is empty, so keys become
        ``"/<key>"``.

        Args:
            metrics: Mapping of metric name to value.
            prefix: Optional group name prepended (lower-cased) to every key.
            step: Forwarded to ``run.log`` as ``step``.
        """
        group = prefix.lower() if prefix else ""
        updated_metrics = {group + "/" + k: v for k, v in metrics.items()}
        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a wandb model artifact.

        The artifact is named ``model-<run id>`` with type ``model``, and the
        file is stored inside it as ``model_ckpt.pdparams``. It is logged with
        the alias ``prefix``, plus ``"best"`` when ``is_best`` is true.

        Args:
            is_best: Whether to additionally tag the artifact as ``"best"``.
            prefix: Checkpoint file stem inside ``save_dir``; also an alias.
            metadata: Optional metadata passed to ``wandb.Artifact``.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            f'model-{self.run.id}', type='model', metadata=metadata
        )
        artifact.add_file(model_path, name="model_ckpt.pdparams")

        aliases = [prefix, "best"] if is_best else [prefix]
        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run used by this logger."""
        self.run.finish()