diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 642ac5e26a25e7f8f5baee8055f926a0dd369ccb..14a2bce63f3b391d80cf218b9f84190d74b988b2 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -128,12 +128,10 @@ def pure_fp16_initialize(models): for idx in range(len(models)): for layer in models[idx].sublayers(include_self=True): layer._casted_by_pure_fp16 = True - if len(layer._sub_layers) is 0: - - if (layer._dtype is 'float16') or isinstance(layer, ( - paddle.nn.BatchNorm, paddle.nn.LayerNorm)): - continue - layer.to(dtype='float16') + if (layer._dtype is 'float16') or isinstance(layer, ( + paddle.nn.BatchNorm, paddle.nn.LayerNorm)): + continue + layer._to_impl(dtype='float16', include_sublayers=False) return models diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 1985a4ebb9256f47c6c46681c392cd9e398b12d0..caa30df5c1ffd4e6ce2043b7952db2194b3d5c76 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1465,23 +1465,6 @@ class Layer(object): for param, state in matched_param_state: _set_var(param, state) - def _apply(self, func, device, dtype, blocking): - for layer in self.children(): - layer._apply(func, device, dtype, blocking) - - for key, param in self._parameters.items(): - if param is not None: - with no_grad(): - param_applied = func(param, device, dtype, blocking) - - if param.grad is not None: - with no_grad(): - grad_applied = func(param._grad_ivar(), device, dtype, - blocking) - - for key, buf in self._buffers.items(): - self._buffers[key] = func(buf, device, dtype, blocking) - def to(self, device=None, dtype=None, blocking=None): ''' Cast the parameters and buffers of Layer by the give device, dtype and blocking. @@ -1495,7 +1478,7 @@ class Layer(object): blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None. - + Returns: self @@ -1529,6 +1512,53 @@ class Layer(object): # [[-0.04989364, -0.56889004], # [ 0.33960250, 0.96878713]]) + ''' + return self._to_impl( + device=device, + dtype=dtype, + blocking=blocking, + include_sublayers=True) + + def _apply(self, func, device, dtype, blocking, include_sublayers=True): + if include_sublayers: + for layer in self.children(): + layer._apply(func, device, dtype, blocking, include_sublayers) + + for key, param in self._parameters.items(): + if param is not None: + with no_grad(): + param_applied = func(param, device, dtype, blocking) + + if param.grad is not None: + with no_grad(): + grad_applied = func(param._grad_ivar(), device, dtype, + blocking) + + for key, buf in self._buffers.items(): + self._buffers[key] = func(buf, device, dtype, blocking) + + def _to_impl(self, + device=None, + dtype=None, + blocking=None, + include_sublayers=True): + ''' + Cast the parameters and buffers of Layer by the give device, dtype and blocking. + + Parameters: + device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device of the Layer which want to be stored. + If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the + index of the GPUs or XPUs. Default: None. + + dtype(str|numpy.dtype|paddle.dtype|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None. + + blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be + asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None. + + include_sublayers(bool|True, optional): If True, deal with self and all sublayers parameters and buffers, if not only deal with self parameters and buffers. Default: True. + + Returns: + self ''' @@ -1605,7 +1635,7 @@ class Layer(object): with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning) - self._apply(transform, device, dtype, blocking) + self._apply(transform, device, dtype, blocking, include_sublayers) self._dtype = dtype return self