diff --git a/ppcls/utils/amp.py b/ppcls/utils/amp.py deleted file mode 100644 index 187c9b42fdd176e3cfd8a6db32ea1712d1a19bde..0000000000000000000000000000000000000000 --- a/ppcls/utils/amp.py +++ /dev/null @@ -1,32 +0,0 @@ -import functools -import paddle - - -def AMP_forward_decorator(func): - @functools.wraps(func) - def wrapper(model, *args, **kwargs): - if AMPForwardDecorator.amp_level: - with paddle.amp.auto_cast( - custom_black_list={ - "flatten_contiguous_range", "greater_than" - }, - level=AMPForwardDecorator.amp_level): - return func(model, *args, **kwargs) - else: - return func(model, *args, **kwargs) - - return wrapper - - -class AMPForwardDecorator(object): - amp_level = None - amp_eval = None - - def __init__(self, forward_func): - self.forward_func = forward_func - - @functools.wraps - def __call__(self, model_obj, *args, **kwargs): - # print(type(self)) - # print(type(model_obj)) - return self.forward_func(model_obj, *args, **kwargs)