diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h
index f87db415768a181db53516d5b6918b65eb1f98e2..e43921636d961966bc51d640e3e5a37d7479bd73 100644
--- a/paddle/fluid/imperative/layer.h
+++ b/paddle/fluid/imperative/layer.h
@@ -108,6 +108,12 @@ class VarBase {
 
   void ClearGradVarBase() { grad_var_ = nullptr; }
 
+  // Copies `grad_var` into the gradient VarBase returned by
+  // MutableGradVarBase(), passing `true` as CopyFrom's second argument.
+  void SetGradVarBase(VarBase& grad_var) {
+    MutableGradVarBase()->CopyFrom(grad_var, true);
+  }
+
   const std::shared_ptr<VarBase>& MutableGradVarBase() {
     if (grad_var_ == nullptr) {
       if (auto grad_var_wrapper = var_->GetGradVar()) {
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 4ab507fe367254490a48846a28b6f5fe7692437d..68c6b855572a78a5335f531a1320657c6468072f 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -1032,6 +1032,10 @@ void BindImperative(py::module *m_ptr) {
              return std::shared_ptr<imperative::VarBase>(nullptr);
            },
            py::return_value_policy::copy)
+      .def("_set_grad_ivar",
+           [](imperative::VarBase &self, imperative::VarBase &grad) {
+             self.SetGradVarBase(grad);
+           })
       .def("_is_sparse",
            [](imperative::VarBase &self) {
              return self.Var().IsType<framework::SelectedRows>();
@@ -1278,6 +1282,18 @@ void BindImperative(py::module *m_ptr) {
              return new_var;
            },
            py::return_value_policy::copy)
+      .def("_copy_to",
+           [](const std::shared_ptr<imperative::VarBase> &self,
+              const platform::Place &place, bool blocking) {
+             auto new_var = self->NewVarBase(place, blocking);
+             // Only non-blocking copies take the extra reference-count step
+             // below; blocking copies skip it.
+             if (!blocking) {
+               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
+             }
+             return new_var;
+           },
+           py::return_value_policy::copy)
       .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
            py::return_value_policy::reference)
       .def_property("name", &imperative::VarBase::Name,
diff --git a/python/paddle/device.py b/python/paddle/device.py
index 20453998fb7ae9b588cf625a9b7d571ab42050dd..035d240e713fe8ff90a7fb40a1c5ad58d10bb4a3 100644
--- a/python/paddle/device.py
+++ b/python/paddle/device.py
@@ -119,28 +119,14 @@ def get_cudnn_version():
     return _cudnn_version
 
 
-def set_device(device):
-    """
-    Paddle supports running calculations on various types of devices, including CPU, GPU and XPU.
-    They are represented by string identifiers. This function can specify the global device
-    which the OP will run.
-
-    Parameters:
-        device(str): This parameter determines the specific running device.
-                It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
-                index of the GPUs or XPUs.
-
-    Examples:
-
-     .. code-block:: python
-
-        import paddle
-
-        paddle.set_device("cpu")
-        x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32')
-        x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32')
-        data = paddle.stack([x1,x2], axis=1)
-    """
+def _convert_to_place(device):
+    """
+    Translate a device string such as ``cpu``, ``gpu:x`` or ``xpu:x`` into the
+    corresponding place object and return it.
+
+    ``set_device`` calls this helper and then hands the returned place to
+    ``framework._set_expected_place``.
+    """
     lower_device = device.lower()
     if lower_device == 'cpu':
         place = core.CPUPlace()
@@ -183,7 +169,32 @@ def set_device(device):
             device_id = device_info_list[1]
             device_id = int(device_id)
             place = core.XPUPlace(device_id)
+    return place
 
+
+def set_device(device):
+    """
+    Paddle supports running calculations on various types of devices, including CPU, GPU and XPU.
+    They are represented by string identifiers. This function can specify the global device
+    which the OP will run.
+
+    Parameters:
+        device(str): This parameter determines the specific running device.
+                It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
+                index of the GPUs or XPUs.
+
+    Examples:
+
+     .. code-block:: python
+
+        import paddle
+
+        paddle.set_device("cpu")
+        x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32')
+        x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32')
+        data = paddle.stack([x1,x2], axis=1)
+    """
+    place = _convert_to_place(device)
     framework._set_expected_place(place)
     return place
 
diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py
index 36637abc6d0b85049bbf68b081aea94ccaf385c4..b495976474221322dd7c7f2991d9c78273cdba0f 100644
--- a/python/paddle/fluid/dygraph/layers.py
+++ b/python/paddle/fluid/dygraph/layers.py
@@ -36,6 +36,7 @@ from ..param_attr import ParamAttr
 from paddle.fluid.executor import Executor, global_scope
 from paddle.fluid.framework import in_dygraph_mode
 from paddle.fluid.framework import _current_expected_place as _get_device
+from paddle.fluid.dygraph import no_grad
 import paddle.utils.deprecated as deprecated
 
 __all__ = ['Layer']
@@ -1343,6 +1344,131 @@ class Layer(core.Layer):
         for param, state in matched_param_state:
             _set_var(param, state)
 
+    def _apply(self, func, device, dtype, blocking):
+        """
+        Replace every parameter, parameter gradient and buffer of this Layer
+        and of its sublayers with ``func(tensor, device, dtype, blocking)``.
+
+        Parameters:
+            func(callable): Invoked as ``func(tensor, device, dtype, blocking)``;
+                the returned tensor is stored in place of ``tensor``.
+            device, dtype, blocking: Passed through to ``func`` unchanged.
+        """
+        for layer in self.children():
+            layer._apply(func, device, dtype, blocking)
+
+        for key, param in self._parameters.items():
+            if param is not None:
+                with no_grad():
+                    param_applied = func(param, device, dtype, blocking)
+                    assert param.is_leaf
+                    param_applied.stop_gradient = param.stop_gradient
+                    self._parameters[key] = param_applied
+
+                if param.grad is not None:
+                    with no_grad():
+                        grad_applied = func(param._grad_ivar(), device, dtype,
+                                            blocking)
+
+                        grad_applied.stop_gradient = param._grad_ivar(
+                        ).stop_gradient
+                        self._parameters[key]._set_grad_ivar(grad_applied)
+
+        for key, buf in self._buffers.items():
+            # Skip empty buffer slots, mirroring the ``param is not None``
+            # guard above; ``func`` would otherwise be handed None.
+            if buf is not None:
+                self._buffers[key] = func(buf, device, dtype, blocking)
+
+    def to(self, device=None, dtype=None, blocking=None):
+        '''
+        Cast the parameters and buffers of Layer by the given device, dtype and blocking.
+
+        Parameters:
+            device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device on which the Layer is to be stored.
+            If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
+            index of the GPUs or XPUs. Default: None.
+
+            dtype(str|core.VarDesc.VarType|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.
+
+            blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
+              asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.
+
+        Returns:
+            None
+
+        Raises:
+            ValueError: If ``device`` is neither a str nor one of the supported place types.
+            AssertionError: If ``blocking`` is neither a bool nor None.
+
+        Examples:
+            .. code-block:: python
+
+                import paddle
+
+                linear=paddle.nn.Linear(2, 2)
+                linear.weight
+                #Parameter containing:
+                #Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
+                #       [[-0.32770029, 0.38653070],
+                #        [ 0.46030545, 0.08158520]])
+
+                linear.to(dtype='float64')
+                linear.weight
+                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
+                #       [[-0.32770029, 0.38653070],
+                #        [ 0.46030545, 0.08158520]])
+
+                linear.to(device='cpu')
+                linear.weight
+                #Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
+                #       [[-0.32770029, 0.38653070],
+                #        [ 0.46030545, 0.08158520]])
+                linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
+                linear.weight
+                #Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
+                #       [[-0.32770029, 0.38653070],
+                #        [ 0.46030545, 0.08158520]])
+
+
+        '''
+
+        if device is None and dtype is None and blocking is None:
+            return
+
+        if device is not None:
+            if isinstance(device, str):
+                device = paddle.device._convert_to_place(device)
+            elif isinstance(device, (core.CPUPlace, core.CUDAPlace,
+                                     core.CUDAPinnedPlace, core.XPUPlace)):
+                pass
+            else:
+                raise ValueError(
+                    "device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
+                    + type(device).__name__)
+
+        if blocking is None:
+            blocking = True
+        else:
+            assert isinstance(
+                blocking,
+                bool), "blocking value error, must be the True, False or None"
+
+        def transform(t, device, dtype, blocking):
+            if device is None:
+                device = t.place
+            if dtype is None:
+                dtype = t.dtype
+
+            # Copy to the target device first, then cast only if dtype differs.
+            new_t = t._copy_to(device, blocking)
+            if dtype is not None and dtype != t.dtype:
+                new_t = new_t.cast(dtype=dtype)
+
+            return new_t
+
+        self._apply(transform, device, dtype, blocking)
+
     # [aliases] Compatible with old method names
     set_dict = set_state_dict
     load_dict = set_state_dict
diff --git a/python/paddle/fluid/tests/unittests/test_base_layer.py b/python/paddle/fluid/tests/unittests/test_base_layer.py
index 31879dae0dad06d75a1ad5c6b6780ef2c3d2b93b..e6e15575f2ca639a05153b3da9f712ebe1d55476 100644
--- a/python/paddle/fluid/tests/unittests/test_base_layer.py
+++ b/python/paddle/fluid/tests/unittests/test_base_layer.py
@@ -331,5 +331,74 @@ class TestModifiedBuffer(unittest.TestCase):
                 np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy()))
 
 
+class TestLayerTo(unittest.TestCase):
+    """Checks that Layer.to() converts dtype/device of parameters, their
+    gradients and buffers, and rejects invalid ``device``/``blocking``."""
+    def setUp(self):
+        paddle.disable_static()
+        self.linear = paddle.nn.Linear(2, 2)
+        self.new_grad = np.random.random([2, 2])
+        self.linear.weight._set_grad_ivar(paddle.to_tensor(self.new_grad))
+        buffer = paddle.to_tensor([0.0], dtype='float32')
+        self.linear.register_buffer("buf_name", buffer, persistable=True)
+
+        sublayer = paddle.nn.Conv1D(3, 2, 3)
+        self.linear.add_sublayer(1, sublayer)
+
+    def test_to_api(self):
+        self.linear.to(dtype='double')
+        self.assertEqual(self.linear.weight.dtype,
+                         paddle.fluid.core.VarDesc.VarType.FP64)
+        self.assertEqual(self.linear.buf_name.dtype,
+                         paddle.fluid.core.VarDesc.VarType.FP64)
+        self.assertTrue(np.allclose(self.linear.weight.grad, self.new_grad))
+        self.assertEqual(self.linear.weight._grad_ivar().dtype,
+                         paddle.fluid.core.VarDesc.VarType.FP64)
+
+        self.linear.to()
+        self.assertEqual(self.linear.weight.dtype,
+                         paddle.fluid.core.VarDesc.VarType.FP64)
+        self.assertEqual(self.linear.buf_name.dtype,
+                         paddle.fluid.core.VarDesc.VarType.FP64)
+        self.assertTrue(np.allclose(self.linear.weight.grad, self.new_grad))
+        self.assertEqual(self.linear.weight._grad_ivar().dtype,
+                         paddle.fluid.core.VarDesc.VarType.FP64)
+
+        if paddle.fluid.is_compiled_with_cuda():
+            self.linear.to(device=paddle.CUDAPlace(0))
+            self.assertTrue(self.linear.weight.place.is_gpu_place())
+            self.assertEqual(self.linear.weight.place.gpu_device_id(), 0)
+            self.assertTrue(self.linear.buf_name.place.is_gpu_place())
+            self.assertEqual(self.linear.buf_name.place.gpu_device_id(), 0)
+            self.assertTrue(self.linear.weight._grad_ivar().place.is_gpu_place(
+            ))
+            self.assertEqual(
+                self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
+
+            self.linear.to(device='gpu:0')
+            self.assertTrue(self.linear.weight.place.is_gpu_place())
+            self.assertEqual(self.linear.weight.place.gpu_device_id(), 0)
+            self.assertTrue(self.linear.buf_name.place.is_gpu_place())
+            self.assertEqual(self.linear.buf_name.place.gpu_device_id(), 0)
+            self.assertTrue(self.linear.weight._grad_ivar().place.is_gpu_place(
+            ))
+            self.assertEqual(
+                self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
+
+        self.linear.to(device=paddle.CPUPlace())
+        self.assertTrue(self.linear.weight.place.is_cpu_place())
+        self.assertTrue(self.linear.buf_name.place.is_cpu_place())
+        self.assertTrue(self.linear.weight._grad_ivar().place.is_cpu_place())
+
+        self.linear.to(device='cpu')
+        self.assertTrue(self.linear.weight.place.is_cpu_place())
+        self.assertTrue(self.linear.buf_name.place.is_cpu_place())
+        self.assertTrue(self.linear.weight._grad_ivar().place.is_cpu_place())
+
+        self.assertRaises(ValueError, self.linear.to, device=1)
+
+        self.assertRaises(AssertionError, self.linear.to, blocking=1)
+
+
 if __name__ == '__main__':
     unittest.main()