# policy.py
import math
import torch
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

from ding.policy import Policy
from ding.model import model_wrap
from ding.torch_utils import to_device
from ding.utils import EasyTimer


class ImageClassificationPolicy(Policy):
    """
    Overview:
        Policy for supervised image classification. The learn mode trains
        ``self._model`` with SGD and a cross-entropy loss; the eval mode runs a
        forward pass without gradients. The collect-side hooks are empty stubs.
    """
    config = {
        'type': 'image_classification',
        'on_policy': False,
    }

    def _init_learn(self):
        """
        Overview:
            Create the objects used by the learn mode: the SGD optimizer, a timer,
            the ``LambdaLR`` schedule, the wrapped learn model and the loss.
        """
        learn_cfg = self._cfg.learn
        self._optimizer = SGD(
            self._model.parameters(),
            lr=learn_cfg.learning_rate,
            momentum=0.9,
            weight_decay=learn_cfg.weight_decay,
        )
        self._timer = EasyTimer(cuda=True)

        def lr_factor(epoch):
            # Multiplier applied by LambdaLR to the optimizer's base learning rate.
            # The config is read on every call instead of being captured once.
            cfg = self._cfg.learn
            if epoch <= cfg.warmup_epoch:
                # Warm-up: base_lr * (warmup_lr / base_lr) == warmup_lr.
                return cfg.warmup_lr / cfg.learning_rate
            # Step decay: one more power of decay_rate every decay_epoch epochs.
            return math.pow(cfg.decay_rate, epoch // cfg.decay_epoch)

        self._lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_factor)
        # NOTE(review): the scheduler is stepped once here, before any optimizer
        # step. Presumably this makes the schedule count epochs from 1 - confirm.
        # Recent PyTorch versions emit a warning for this call order.
        self._lr_scheduler.step()

        self._learn_model = model_wrap(self._model, 'base')
        self._learn_model.reset()
        self._ce_loss = nn.CrossEntropyLoss()

    def _forward_learn(self, data):
        """
        Overview:
            Perform one optimization step on a batch and report the loss, the
            mean learning rate and the time spent in each phase.
        Arguments:
            - data: A batch that unpacks into exactly two items, ``(img, target)``.
        Returns:
            - info (dict): Keys ``cur_lr``, ``total_loss``, ``forward_time``, \
                ``backward_time`` and ``sync_time``.
        """
        if self._cuda:
            data = to_device(data, self._device)
        self._learn_model.train()

        # Phase 1: forward pass and loss computation.
        with self._timer:
            images, labels = data
            logits = self._learn_model.forward(images)
            loss = self._ce_loss(logits, labels)
        forward_time = self._timer.value

        # Phase 2: clear stale gradients, then back-propagate.
        with self._timer:
            self._optimizer.zero_grad()
            loss.backward()
        backward_time = self._timer.value

        # Phase 3: gradient synchronisation, only when multi_gpu is enabled.
        with self._timer:
            if self._cfg.learn.multi_gpu:
                self.sync_gradients(self._learn_model)
        sync_time = self._timer.value

        # The parameter update itself is not included in any of the timings.
        self._optimizer.step()

        # Average the learning rate over all parameter groups.
        group_lrs = [group['lr'] for group in self._optimizer.param_groups]
        return dict(
            cur_lr=sum(group_lrs) / len(group_lrs),
            total_loss=loss.item(),
            forward_time=forward_time,
            backward_time=backward_time,
            sync_time=sync_time,
        )

    def _monitor_vars_learn(self):
        """
        Overview:
            Names of the variables to monitor; they match the keys of the dict
            returned by ``_forward_learn``.
        """
        return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']

    def _init_eval(self):
        """
        Overview:
            Wrap ``self._model`` with the ``'base'`` wrapper for evaluation.
        """
        self._eval_model = model_wrap(self._model, 'base')

    def _forward_eval(self, data):
        """
        Overview:
            Run the eval model on ``data`` with gradients disabled.
        Arguments:
            - data: Input passed unchanged (apart from device transfer) to the model.
        Returns:
            - output: The model output, moved to CPU when ``self._cuda`` is set.
        """
        if self._cuda:
            data = to_device(data, self._device)
        self._eval_model.eval()
        with torch.no_grad():
            prediction = self._eval_model.forward(data)
        return to_device(prediction, 'cpu') if self._cuda else prediction

    def _init_collect(self):
        """Collect mode needs no initialization in this policy; does nothing."""
        return None

    def _forward_collect(self, data):
        """Collect-mode forward is not implemented; ``data`` is ignored."""
        return None

    def _process_transition(self):
        """Transition processing is not implemented; does nothing."""
        return None

    def _get_train_sample(self):
        """Train-sample extraction is not implemented; does nothing."""
        return None