# amp.py (891 bytes)
# Last committed by gaotingquan.
import functools
import paddle


def AMP_forward_decorator(func):
    """Wrap a forward function so it can run under Paddle automatic mixed precision.

    The AMP setting is looked up on every call from the class attribute
    ``AMPForwardDecorator.amp_level``.

    - If it is falsy, ``func`` is called unchanged.
    - Otherwise the call runs inside ``paddle.amp.auto_cast`` at that level,
      with ``flatten_contiguous_range`` and ``greater_than`` in the custom
      black list (ops Paddle's ``auto_cast`` keeps out of low-precision
      casting).

    Args:
        func: Callable invoked as ``func(model, *args, **kwargs)``.

    Returns:
        A wrapper with ``func``'s metadata that returns whatever ``func``
        returns.
    """

    @functools.wraps(func)
    def wrapper(model, *args, **kwargs):
        level = AMPForwardDecorator.amp_level
        if not level:
            # AMP disabled: plain pass-through.
            return func(model, *args, **kwargs)

        black_list = {"flatten_contiguous_range", "greater_than"}
        with paddle.amp.auto_cast(custom_black_list=black_list, level=level):
            return func(model, *args, **kwargs)

    return wrapper


class AMPForwardDecorator(object):
    """Callable wrapper around a model's forward function.

    This class also holds process-wide AMP settings as class attributes.
    ``amp_level`` is read by ``AMP_forward_decorator`` in this module to
    decide whether, and at which level, to run forward under
    ``paddle.amp.auto_cast``.

    An instance can be used in two ways:

    - Called directly as ``wrapper(model_obj, *args, **kwargs)``.
    - As a method decorator in a class body, in which case the owning
      instance is passed as ``model_obj``.

    NOTE(review): despite its name, this wrapper does not itself enter an AMP
    context; it only forwards the call. Confirm whether that is intended.
    """

    # Level passed to ``paddle.amp.auto_cast`` by ``AMP_forward_decorator``;
    # a falsy value (the default ``None``) disables autocasting there.
    amp_level = None
    # Not read anywhere in this module's visible code; presumably an
    # AMP-during-evaluation switch -- TODO confirm against callers.
    amp_eval = None

    def __init__(self, forward_func):
        """Wrap ``forward_func`` and copy its metadata onto this instance.

        Args:
            forward_func: Callable invoked as
                ``forward_func(model_obj, *args, **kwargs)``.
        """
        # Copy __name__, __doc__, __wrapped__, etc. so the wrapper introspects
        # like the function it wraps. A bare ``@functools.wraps`` on __call__
        # (missing its ``wrapped`` argument) turns __call__ into a partial of
        # update_wrapper and never reaches forward_func.
        functools.update_wrapper(self, forward_func)
        # Assigned after update_wrapper so a copied __dict__ cannot clobber it.
        self.forward_func = forward_func

    def __get__(self, instance, owner=None):
        """Bind to ``instance`` when used as a method decorator.

        Accessed on the class itself, the wrapper is returned unchanged.
        """
        if instance is None:
            return self
        # Pre-fill model_obj so that ``model.forward(x)`` calls
        # ``self(model, x)``.
        return functools.partial(self, instance)

    def __call__(self, model_obj, *args, **kwargs):
        """Run the wrapped forward function.

        Args:
            model_obj: Model instance, passed as the first argument.
            *args: Forwarded unchanged.
            **kwargs: Forwarded unchanged.

        Returns:
            Whatever ``forward_func`` returns.
        """
        return self.forward_func(model_obj, *args, **kwargs)