Unverified commit f80cee11, authored by zhangbo9674, committed by GitHub

add float_only for layer_to (#43760)

Parent 178b2440
@@ -173,7 +173,9 @@ def pure_fp16_initialize(models):
                             paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
                             paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm)):
                 continue
-            layer._to_impl(dtype='float16', include_sublayers=False)
+            layer._to_impl(dtype='float16',
+                           include_sublayers=False,
+                           floating_only=True)
     return models
......
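For context, a minimal sketch (not part of this commit) of what floating_only=True changes for AMP O2 decoration: floating point parameters of the decorated layer are still cast to float16, while non-floating buffers such as an int32 counter keep their dtype instead of being force-cast. The buffer name "step" is arbitrary; the APIs used are the same public ones exercised by the new unit test at the bottom of this diff.

    import numpy as np
    import paddle

    model = paddle.nn.Linear(2, 4)
    # Register a non-floating buffer; before this change it would also be cast.
    model.register_buffer("step", paddle.to_tensor(np.array([5]).astype("int32")))

    model = paddle.amp.decorate(models=model, level='O2')

    print(model.weight.dtype)            # paddle.float16, cast by pure_fp16_initialize
    print(model._buffers["step"].dtype)  # paddle.int32, left untouched by floating_only=True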
@@ -1576,7 +1576,8 @@ class Layer(object):
         return self._to_impl(device=device,
                              dtype=dtype,
                              blocking=blocking,
-                             include_sublayers=True)
+                             include_sublayers=True,
+                             floating_only=False)
 
     def _apply(self, func, device, dtype, blocking, include_sublayers=True):
         if include_sublayers:
@@ -1599,53 +1600,7 @@ class Layer(object):
         self._dtype = dtype
 
-    def _to_impl(self,
-                 device=None,
-                 dtype=None,
-                 blocking=None,
-                 include_sublayers=True):
-        '''
-        Cast the parameters and buffers of Layer by the given device, dtype and blocking.
-
-        Parameters:
-            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device on which to store the Layer.
-                If None, the device is the same as that of the original Tensor. If device is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
-                index of the GPU or XPU. Default: None.
-            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same as that of the original Tensor. Default: None.
-            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
-                asynchronous with respect to the host. Otherwise, the argument has no effect. If None, blocking is set to True. Default: None.
-            include_sublayers(bool, optional): If True, cast the parameters and buffers of self and all sublayers; otherwise, cast only self's parameters and buffers. Default: True.
-
-        Returns:
-            self
-        '''
-
-        if device is None and dtype is None and blocking is None:
-            return self
-
-        if device is not None:
-            if isinstance(device, str):
-                device = paddle.device._convert_to_place(device)
-            elif isinstance(device, (core.CPUPlace, core.CUDAPlace,
-                                     core.CUDAPinnedPlace, core.XPUPlace)):
-                pass
-            else:
-                raise ValueError(
-                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
-                    + type(device).__name__)
-
-        if blocking is None:
-            blocking = True
-        else:
-            assert isinstance(
-                blocking,
-                bool), "blocking value error, must be True, False or None"
-
-        def transform(t, device, dtype, blocking):
+    def _transform(self, t, device, dtype, blocking):
         if device is None:
             device = t.place
         if dtype is None:
@@ -1695,6 +1650,60 @@ class Layer(object):
             return t
 
+    def _to_impl(self,
+                 device=None,
+                 dtype=None,
+                 blocking=None,
+                 include_sublayers=True,
+                 floating_only=False):
+        '''
+        Cast the parameters and buffers of Layer by the given device, dtype and blocking.
+
+        Parameters:
+            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device on which to store the Layer.
+                If None, the device is the same as that of the original Tensor. If device is a string, it can be ``cpu``, ``gpu:x`` or ``xpu:x``, where ``x`` is the
+                index of the GPU or XPU. Default: None.
+            dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same as that of the original Tensor. Default: None.
+            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
+                asynchronous with respect to the host. Otherwise, the argument has no effect. If None, blocking is set to True. Default: None.
+            include_sublayers(bool, optional): If True, cast the parameters and buffers of self and all sublayers; otherwise, cast only self's parameters and buffers. Default: True.
+            floating_only(bool, optional): If True, cast only the floating point parameters and buffers of the Layer by the given device, dtype and blocking. Default: False.
+
+        Returns:
+            self
+        '''
+
+        if device is None and dtype is None and blocking is None:
+            return self
+
+        if device is not None:
+            if isinstance(device, str):
+                device = paddle.device._convert_to_place(device)
+            elif isinstance(device, (core.CPUPlace, core.CUDAPlace,
+                                     core.CUDAPinnedPlace, core.XPUPlace)):
+                pass
+            else:
+                raise ValueError(
+                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
+                    + type(device).__name__)
+
+        if blocking is None:
+            blocking = True
+        else:
+            assert isinstance(
+                blocking,
+                bool), "blocking value error, must be True, False or None"
+
+        def transform(t, device, dtype, blocking):
+            if floating_only and (not paddle.is_floating_point(t)):
+                return t
+            return self._transform(t, device, dtype, blocking)
+
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=UserWarning)
             self._apply(transform, device, dtype, blocking, include_sublayers)
......
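As a standalone illustration of the filtering idea (a sketch, not Paddle's actual Layer._to_impl; it only assumes paddle.is_floating_point and paddle.cast behave as documented), the helper below casts floating point tensors and passes everything else through unchanged, which is the same check the nested transform performs before delegating to _transform.

    import paddle

    def cast_floating_only(tensors, dtype="float16"):
        """Cast floating point tensors to `dtype`; return non-floating tensors as-is."""
        out = []
        for t in tensors:
            if not paddle.is_floating_point(t):
                # Integer/bool tensors (step counters, index buffers, ...) are skipped.
                out.append(t)
            else:
                out.append(paddle.cast(t, dtype))
        return out

    vals = [paddle.ones([2, 2], dtype="float32"),
            paddle.to_tensor([7], dtype="int32")]
    print([v.dtype for v in cast_floating_only(vals)])  # [paddle.float16, paddle.int32]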
@@ -707,6 +707,14 @@ class TestAmpDecorator(unittest.TestCase):
         for param in model.parameters():
             self.assertEqual((param.dtype == paddle.float32), True)
 
+    def test_floating_only(self):
+        model = paddle.nn.Linear(2, 4)
+        buffer = paddle.to_tensor(np.array([5]).astype("int32"))
+        model.register_buffer("buffer_name", buffer, persistable=True)
+        model = paddle.amp.decorate(models=model, level='O2')
+        self.assertEqual((model._buffers["buffer_name"].dtype == paddle.int32),
+                         True)
+
 class TestStateDictHookForAMP(unittest.TestCase):
......