# amp.py
from functools import partial
import contextlib
import paddle


class AutoCast:
    """Callable factory for the autocast context used around forward passes.

    Calling an instance returns either a ``paddle.amp.auto_cast`` context
    (configured with the requested ``level``) or a do-nothing
    ``contextlib.nullcontext``, depending on whether AMP is enabled and
    whether the call is for evaluation.
    """

    def __init__(self,
                 use_amp=False,
                 amp_level="O1",
                 use_promote=False,
                 amp_eval=False):
        """Store the AMP switches and, if enabled, prepare the cast factory.

        Args:
            use_amp: Master switch; when falsy no autocast is ever applied.
            amp_level: Forwarded as ``level`` to ``paddle.amp.auto_cast``.
            use_promote: Accepted for interface compatibility; this
                implementation does not read it.
            amp_eval: Whether autocast should also be applied when the
                instance is called with ``is_eval=True``.
        """
        self.amp_eval = amp_eval
        self.use_amp = use_amp

        if not self.use_amp:
            # ``cast_context`` is deliberately left undefined when AMP is off.
            return

        self.cast_context = partial(paddle.amp.auto_cast, level=amp_level)

    def __call__(self, is_eval=False):
        """Return the context manager to wrap a forward pass with.

        Args:
            is_eval: True when the wrapped computation is an evaluation step.

        Returns:
            An autocast context for training calls, or for evaluation calls
            when ``amp_eval`` is set; otherwise a ``contextlib.nullcontext``.
        """
        # AMP disabled entirely: nothing to cast.
        if not self.use_amp:
            return contextlib.nullcontext()

        # Evaluation only gets autocast when explicitly requested.
        if is_eval and not self.amp_eval:
            return contextlib.nullcontext()

        return self.cast_context()


def build_scaler(use_amp=False, scale_loss=1.0,
                 use_dynamic_loss_scaling=False):
    """Build the loss scaler used by the training loop.

    Args:
        use_amp: When truthy, a ``paddle.amp.GradScaler`` is returned.
        scale_loss: Passed to ``GradScaler`` as ``init_loss_scaling``.
        use_dynamic_loss_scaling: Passed through to ``GradScaler``.

    Returns:
        A ``paddle.amp.GradScaler`` when ``use_amp`` is truthy; otherwise a
        pass-through object offering ``scale``, ``step``, ``update`` and
        ``minimize`` methods that perform no loss scaling.
    """
    if use_amp:
        return paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)

    # Stand-in used when AMP is off, so callers can drive the same
    # scale/step/update/minimize sequence without branching.
    class Foo:
        """Pass-through scaler: leaves the loss as is and steps directly."""

        def scale(self, loss):
            # No scaling is applied; hand the loss back unchanged.
            return loss

        def step(self, optimizer):
            optimizer.step()

        def update(self):
            # There is no scaling state to refresh.
            return None

        def minimize(self, optimizer, loss):
            # ``loss`` is not used here; only the optimizer step is run.
            optimizer.step()

    return Foo()