# base_trainer.py
import os
import glob
from paddle import fluid
from paddle.fluid import dygraph
import pickle


class BaseTrainer:
    """Base trainer class. Contains functions for training and saving/loading checkpoints.
    Trainer classes should inherit from this one and overload the train_epoch function."""

    def __init__(self, actor, loaders, optimizer, settings, lr_scheduler=None):
        """
        args:
            actor - The actor for training the network
            loaders - list of dataset loaders, e.g. [train_loader, val_loader]. In each epoch, the trainer runs one
                        epoch for each loader.
            optimizer - The optimizer used for training, e.g. Adam
            settings - Training settings
            lr_scheduler - Learning rate scheduler (only stored here; this base class never steps it)
        """
        self.actor = actor
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.loaders = loaders

        self.update_settings(settings)

        # Index of the last epoch that was started; 0 means "nothing trained yet".
        self.epoch = 0
        self.stats = {}

    def update_settings(self, settings=None):
        """Updates the trainer settings. Must be called to update internal settings.

        args:
            settings - New settings object. If None, the currently stored settings are re-applied.

        Side effects: expands '~' in settings.env.workspace_dir (in place) and creates the
        '<workspace_dir>/checkpoints' directory. Without a workspace dir, checkpointing is disabled
        (self._checkpoint_dir is None).
        """
        if settings is not None:
            self.settings = settings

        if self.settings.env.workspace_dir is not None:
            self.settings.env.workspace_dir = os.path.expanduser(
                self.settings.env.workspace_dir)
            self._checkpoint_dir = os.path.join(self.settings.env.workspace_dir,
                                                'checkpoints')
            # exist_ok avoids the check-then-create race when several jobs share a workspace.
            os.makedirs(self._checkpoint_dir, exist_ok=True)
        else:
            self._checkpoint_dir = None

    def train(self, max_epochs, load_latest=False, fail_safe=True):
        """Do training for the given number of epochs.
        args:
            max_epochs - Max number of training epochs,
            load_latest - Bool indicating whether to resume from latest epoch.
            fail_safe - Bool indicating whether the training to automatically restart in case of any crashes.

        A checkpoint is saved once, after the last epoch, if a workspace dir is configured.
        With fail_safe enabled, a crash is retried up to 10 times; the crashed epoch is run again
        rather than skipped. With fail_safe disabled, the original exception is re-raised.
        """
        import traceback

        num_tries = 10
        for attempt in range(1, num_tries + 1):
            # Epoch to roll back to if this attempt crashes part-way through an epoch.
            last_completed = self.epoch
            try:
                if load_latest:
                    self.load_checkpoint()
                    last_completed = self.epoch

                for epoch in range(self.epoch + 1, max_epochs + 1):
                    self.epoch = epoch
                    self.train_epoch()
                    last_completed = epoch

                if self._checkpoint_dir:
                    self.save_checkpoint()
            except Exception:
                # Deliberately not a bare except: KeyboardInterrupt / SystemExit must stop
                # training instead of triggering a restart.
                print('Training crashed at epoch {}'.format(self.epoch))
                if not fail_safe:
                    raise
                print(traceback.format_exc())
                # Retry the epoch that failed instead of silently skipping it.
                self.epoch = last_completed
                load_latest = True
                if attempt < num_tries:
                    print('Restarting training from last epoch ...')
            else:
                # Success: leave the retry loop so the checkpoint is not re-saved on every
                # remaining attempt.
                print('Finished training!')
                return

        print('Training failed after {} attempts'.format(num_tries))

    def train_epoch(self):
        """Runs a single training epoch. Must be implemented by subclasses."""
        raise NotImplementedError

    def save_checkpoint(self):
        """Saves a checkpoint of the network and other variables.

        Network and optimizer state dicts are handed to fluid.save_dygraph using
        '<checkpoint_dir>/<project_path>/<NetType>_ep<NNNN>' as the path; the remaining
        bookkeeping (epoch, stats, settings, ...) is pickled to '_custom_state.pickle' inside
        a directory with that same name.
        """
        actor_type = type(self.actor).__name__
        net_type = type(self.actor.net).__name__
        state = {
            'epoch': self.epoch,
            'actor_type': actor_type,
            'net_type': net_type,
            'net_info': getattr(self.actor.net, 'info', None),
            'constructor': getattr(self.actor.net, 'constructor', None),
            'stats': self.stats,
            'settings': self.settings
        }

        directory = '{}/{}/{}_ep{:04d}'.format(self._checkpoint_dir,
                                               self.settings.project_path,
                                               net_type, self.epoch)
        os.makedirs(directory, exist_ok=True)

        fluid.save_dygraph(self.actor.net.state_dict(), directory)
        fluid.save_dygraph(self.optimizer.state_dict(), directory)
        with open(os.path.join(directory, '_custom_state.pickle'), 'wb') as f:
            pickle.dump(state, f)

    def load_checkpoint(self, checkpoint=None):
        """Loads a network checkpoint file.

        Can be called in three different ways:
            load_checkpoint():
                Loads the latest epoch from the workspace. Use this to continue training.
            load_checkpoint(epoch_num):
                Loads the network at the given epoch number (int).
            load_checkpoint(path_to_checkpoint):
                Loads the file from the given absolute path (str).

        Besides the network and optimizer parameters, the epoch counter and stats stored by
        save_checkpoint are restored so that training resumes where the checkpoint left off.

        returns:
            True when a checkpoint was loaded, None when no checkpoint matched.
        raises:
            TypeError if checkpoint is neither None, int nor str.
        """

        net_type = type(self.actor.net).__name__

        if checkpoint is None:
            # Load most recent checkpoint
            checkpoint_list = sorted(
                glob.glob('{}/{}/{}_ep*'.format(self._checkpoint_dir,
                                                self.settings.project_path,
                                                net_type)))
            if not checkpoint_list:
                print('No matching checkpoint file found')
                return
            # Drop only the extension of the last path component; splitting on the first '.'
            # would truncate paths such as './ws/...' or '~/.cache/...'.
            checkpoint_path = os.path.splitext(checkpoint_list[-1])[0]
        elif isinstance(checkpoint, int):
            # Checkpoint is the epoch number
            checkpoint_path = '{}/{}/{}_ep{:04d}'.format(
                self._checkpoint_dir, self.settings.project_path, net_type,
                checkpoint)
        elif isinstance(checkpoint, str):
            # checkpoint is the path
            checkpoint_path = os.path.expanduser(checkpoint)
            # Accept a path that points at one of the paddle files rather than the prefix.
            for suffix in ('.pdparams', '.pdopt'):
                if checkpoint_path.endswith(suffix):
                    checkpoint_path = checkpoint_path[:-len(suffix)]
                    break
        else:
            raise TypeError(
                'checkpoint must be None, an int epoch or a str path, got {}'.
                format(type(checkpoint).__name__))

        # paddle load network
        net_params, opt_params = fluid.load_dygraph(checkpoint_path)
        self.actor.net.load_dict(net_params)
        # The optimizer state is None when no optimizer file was stored next to the parameters.
        if opt_params is not None:
            self.optimizer.set_dict(opt_params)

        # paddle load state: save_checkpoint writes it inside the checkpoint directory.
        # NOTE: pickle can execute arbitrary code; only load checkpoints from trusted sources.
        state_path = os.path.join(checkpoint_path, '_custom_state.pickle')
        with open(state_path, 'rb') as f:
            current_state = pickle.load(f)

        # Restore the bookkeeping needed to continue training from this checkpoint.
        self.epoch = current_state.get('epoch', self.epoch)
        self.stats = current_state.get('stats', self.stats)

        print("\nload checkpoint done !! Current states are as follows:")
        for key, value in current_state.items():
            print(key, value)

        return True