未验证 提交 6e946e9d 编写于 作者: C chentianyu03 提交者: GitHub

add layer.to api (#32040)

* add layer.to api

* add layer.to api

* add layer.to api

* add the doc for Layer.to

* add input type checking

* modify assert and import bug

* format code style

* format code style

* make place support str type

* add SetGradVarBase method to set the gradient after conversion

* modify argument palce to device

* modify argument palce to device

* modify doc of layers.to API

* add xpuplace to device argument
上级 693c7629
......@@ -108,6 +108,10 @@ class VarBase {
void ClearGradVarBase() { grad_var_ = nullptr; }
void SetGradVarBase(VarBase& grad_var) {
MutableGradVarBase()->CopyFrom(grad_var, true);
}
const std::shared_ptr<VarBase>& MutableGradVarBase() {
if (grad_var_ == nullptr) {
if (auto grad_var_wrapper = var_->GetGradVar()) {
......
......@@ -1032,6 +1032,10 @@ void BindImperative(py::module *m_ptr) {
return std::shared_ptr<imperative::VarBase>(nullptr);
},
py::return_value_policy::copy)
.def("_set_grad_ivar",
[](imperative::VarBase &self, imperative::VarBase &grad) {
self.SetGradVarBase(grad);
})
.def("_is_sparse",
[](imperative::VarBase &self) {
return self.Var().IsType<framework::SelectedRows>();
......@@ -1278,6 +1282,16 @@ void BindImperative(py::module *m_ptr) {
return new_var;
},
py::return_value_policy::copy)
.def("_copy_to",
[](const std::shared_ptr<imperative::VarBase> &self,
const platform::Place &place, bool blocking) {
auto new_var = self->NewVarBase(place, blocking);
if (!blocking) {
IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
}
return new_var;
},
py::return_value_policy::copy)
.def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
py::return_value_policy::reference)
.def_property("name", &imperative::VarBase::Name,
......
......@@ -119,28 +119,7 @@ def get_cudnn_version():
return _cudnn_version
def set_device(device):
"""
Paddle supports running calculations on various types of devices, including CPU, GPU and XPU.
They are represented by string identifiers. This function can specify the global device
which the OP will run.
Parameters:
device(str): This parameter determines the specific running device.
It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
index of the GPUs or XPUs.
Examples:
.. code-block:: python
import paddle
paddle.set_device("cpu")
x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32')
x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32')
data = paddle.stack([x1,x2], axis=1)
"""
def _convert_to_place(device):
lower_device = device.lower()
if lower_device == 'cpu':
place = core.CPUPlace()
......@@ -183,7 +162,32 @@ def set_device(device):
device_id = device_info_list[1]
device_id = int(device_id)
place = core.XPUPlace(device_id)
return place
def set_device(device):
"""
Paddle supports running calculations on various types of devices, including CPU, GPU and XPU.
They are represented by string identifiers. This function can specify the global device
which the OP will run.
Parameters:
device(str): This parameter determines the specific running device.
It can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
index of the GPUs or XPUs.
Examples:
.. code-block:: python
import paddle
paddle.set_device("cpu")
x1 = paddle.ones(name='x1', shape=[1, 2], dtype='int32')
x2 = paddle.zeros(name='x2', shape=[1, 2], dtype='int32')
data = paddle.stack([x1,x2], axis=1)
"""
place = _convert_to_place(device)
framework._set_expected_place(place)
return place
......
......@@ -36,6 +36,7 @@ from ..param_attr import ParamAttr
from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.dygraph import no_grad
import paddle.utils.deprecated as deprecated
__all__ = ['Layer']
......@@ -1343,6 +1344,114 @@ class Layer(core.Layer):
for param, state in matched_param_state:
_set_var(param, state)
def _apply(self, func, device, dtype, blocking):
for layer in self.children():
layer._apply(func, device, dtype, blocking)
for key, param in self._parameters.items():
if param is not None:
with no_grad():
param_applied = func(param, device, dtype, blocking)
assert param.is_leaf
param_applied.stop_gradient = param.stop_gradient
self._parameters[key] = param_applied
if param.grad is not None:
with no_grad():
grad_applied = func(param._grad_ivar(), device, dtype,
blocking)
grad_applied.stop_gradient = param._grad_ivar(
).stop_gradient
self._parameters[key]._set_grad_ivar(grad_applied)
for key, buf in self._buffers.items():
self._buffers[key] = func(buf, device, dtype, blocking)
def to(self, device=None, dtype=None, blocking=None):
'''
Cast the parameters and buffers of Layer by the give device, dtype and blocking.
Parameters:
device(str|paddle.CPUPlace()|paddle.CUDAPlace()|paddle.CUDAPinnedPlace()|paddle.XPUPlace()|None, optional): The device of the Layer which want to be stored.
If None, the device is the same with the original Tensor. If device is string, it can be ``cpu``, ``gpu:x`` and ``xpu:x``, where ``x`` is the
index of the GPUs or XPUs. Default: None.
dtype(str|core.VarDesc.VarType|None, optional): The type of the data. If None, the dtype is the same with the original Tensor. Default: None.
blocking(bool|None, optional): If False and the source is in pinned memory, the copy will be
asynchronous with respect to the host. Otherwise, the argument has no effect. If None, the blocking is set True. Default: None.
Returns:
None
Examples:
.. code-block:: python
import paddle
linear=paddle.nn.Linear(2, 2)
linear.weight
#Parameter containing:
#Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[-0.32770029, 0.38653070],
# [ 0.46030545, 0.08158520]])
linear.to(dtype='float64')
linear.weight
#Tenor(shape=[2, 2], dtype=float64, place=CUDAPlace(0), stop_gradient=False,
# [[-0.32770029, 0.38653070],
# [ 0.46030545, 0.08158520]])
linear.to(device='cpu')
linear.weight
#Tensor(shape=[2, 2], dtype=float64, place=CPUPlace, stop_gradient=False,
# [[-0.32770029, 0.38653070],
# [ 0.46030545, 0.08158520]])
linear.to(device=paddle.CUDAPinnedPlace(), blocking=False)
linear.weight
#Tensor(shape=[2, 2], dtype=float64, place=CUDAPinnedPlace, stop_gradient=False,
# [[-0.04989364, -0.56889004],
# [ 0.33960250, 0.96878713]])
'''
if device is None and dtype is None and blocking is None:
return
if device is not None:
if isinstance(device, str):
device = paddle.device._convert_to_place(device)
elif isinstance(device, (core.CPUPlace, core.CUDAPlace,
core.CUDAPinnedPlace, core.XPUPlace)):
pass
else:
raise ValueError(
"device value error, must be str, paddle.CPUPlace(), paddle.CUDAPlace(), paddle.CUDAPinnedPlace() or paddle.XPUPlace(), but the type of device is "
+ type(device).__name__)
if blocking is None:
blocking = True
else:
assert isinstance(
blocking,
bool), "blocking value error, must be the True, False or None"
def transform(t, device, dtype, blocking):
if device is None:
device = t.place
if dtype is None:
dtype = t.dtype
new_t = t._copy_to(device, blocking)
if dtype is not None and dtype != t.dtype:
new_t = new_t.cast(dtype=dtype)
return new_t
self._apply(transform, device, dtype, blocking)
# [aliases] Compatible with old method names
set_dict = set_state_dict
load_dict = set_state_dict
......@@ -331,5 +331,72 @@ class TestModifiedBuffer(unittest.TestCase):
np.array_equal(dy_outs[i].numpy(), st_outs[i].numpy()))
class TestLayerTo(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.linear = paddle.nn.Linear(2, 2)
self.new_grad = np.random.random([2, 2])
self.linear.weight._set_grad_ivar(paddle.to_tensor(self.new_grad))
buffer = paddle.to_tensor([0.0], dtype='float32')
self.linear.register_buffer("buf_name", buffer, persistable=True)
sublayer = paddle.nn.Conv1D(3, 2, 3)
self.linear.add_sublayer(1, sublayer)
def test_to_api(self):
self.linear.to(dtype='double')
self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.buf_name.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(np.allclose(self.linear.weight.grad, self.new_grad))
self.assertTrue(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.linear.to()
self.assertEqual(self.linear.weight.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertEqual(self.linear.buf_name.dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
self.assertTrue(np.allclose(self.linear.weight.grad, self.new_grad))
self.assertTrue(self.linear.weight._grad_ivar().dtype,
paddle.fluid.core.VarDesc.VarType.FP64)
if paddle.fluid.is_compiled_with_cuda():
self.linear.to(device=paddle.CUDAPlace(0))
self.assertTrue(self.linear.weight.place.is_gpu_place())
self.assertEqual(self.linear.weight.place.gpu_device_id(), 0)
self.assertTrue(self.linear.buf_name.place.is_gpu_place())
self.assertEqual(self.linear.buf_name.place.gpu_device_id(), 0)
self.assertTrue(self.linear.weight._grad_ivar().place.is_gpu_place(
))
self.assertEqual(
self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
self.linear.to(device='gpu:0')
self.assertTrue(self.linear.weight.place.is_gpu_place())
self.assertEqual(self.linear.weight.place.gpu_device_id(), 0)
self.assertTrue(self.linear.buf_name.place.is_gpu_place())
self.assertEqual(self.linear.buf_name.place.gpu_device_id(), 0)
self.assertTrue(self.linear.weight._grad_ivar().place.is_gpu_place(
))
self.assertEqual(
self.linear.weight._grad_ivar().place.gpu_device_id(), 0)
self.linear.to(device=paddle.CPUPlace())
self.assertTrue(self.linear.weight.place.is_cpu_place())
self.assertTrue(self.linear.buf_name.place.is_cpu_place())
self.assertTrue(self.linear.weight._grad_ivar().place.is_cpu_place())
self.linear.to(device='cpu')
self.assertTrue(self.linear.weight.place.is_cpu_place())
self.assertTrue(self.linear.buf_name.place.is_cpu_place())
self.assertTrue(self.linear.weight._grad_ivar().place.is_cpu_place())
self.assertRaises(ValueError, self.linear.to, device=1)
self.assertRaises(AssertionError, self.linear.to, blocking=1)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册