提交 fd7ef078 编写于 作者: T Tingquan Gao

Revert "add the amp decorator"

This reverts commit d3b7690f.
上级 03795249
import functools
import paddle
def AMP_forward_decorator(func):
@functools.wraps(func)
def wrapper(model, *args, **kwargs):
if AMPForwardDecorator.amp_level:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=AMPForwardDecorator.amp_level):
return func(model, *args, **kwargs)
else:
return func(model, *args, **kwargs)
return wrapper
class AMPForwardDecorator(object):
amp_level = None
amp_eval = None
def __init__(self, forward_func):
self.forward_func = forward_func
@functools.wraps
def __call__(self, model_obj, *args, **kwargs):
# print(type(self))
# print(type(model_obj))
return self.forward_func(model_obj, *args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册