未验证 提交 fb7f8529 编写于 作者: Z Zhou Wei 提交者: GitHub

fix print tensor place,add cpu/cuda/pin_memory API for Tensor (#28200)

上级 99408718
......@@ -836,6 +836,127 @@ void BindImperative(py::module *m_ptr) {
}
},
py::call_guard<py::gil_scoped_release>())
.def("cpu",
[](const std::shared_ptr<imperative::VarBase> &self) {
if (platform::is_cpu_place(self->Place())) {
return self;
} else {
auto new_var = self->NewVarBase(platform::CPUPlace(), true);
new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var;
}
},
R"DOC(
Returns a copy of this Tensor in CPU memory.
If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
print(x.place) # CUDAPlace(0)
y = x.cpu()
print(y.place) # CPUPlace
)DOC")
.def("pin_memory",
[](const std::shared_ptr<imperative::VarBase> &self) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot copy this Tensor to pinned memory in CPU version "
"Paddle, "
"Please recompile or reinstall Paddle with CUDA support."));
#endif
if (platform::is_cuda_pinned_place(self->Place())) {
return self;
} else {
auto new_var =
self->NewVarBase(platform::CUDAPinnedPlace(), true);
new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var;
}
},
R"DOC(
Returns a copy of this Tensor in pin memory.
If this Tensor is already in pin memory, then no copy is performed and the original Tensor is returned.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
print(x.place) # CUDAPlace(0)
y = x.pin_memory()
print(y.place) # CUDAPinnedPlace
)DOC")
.def("cuda",
[](const std::shared_ptr<imperative::VarBase> &self, int device_id,
bool blocking) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW(platform::errors::PermissionDenied(
"Cannot copy this Tensor to GPU in CPU version Paddle, "
"Please recompile or reinstall Paddle with CUDA support."));
#else
int device_count = platform::GetCUDADeviceCount();
if (device_id == -1) {
if (platform::is_gpu_place(self->Place())) {
return self;
} else {
device_id = 0;
}
}
PADDLE_ENFORCE_GE(
device_id, 0,
platform::errors::InvalidArgument(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)",
device_id, device_count));
PADDLE_ENFORCE_LT(
device_id, device_count,
platform::errors::InvalidArgument(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)",
device_id, device_count));
platform::CUDAPlace place = platform::CUDAPlace(device_id);
if (platform::is_same_place(self->Place(), place)) {
return self;
} else {
auto new_var = self->NewVarBase(place, blocking);
new_var->SetOverridedStopGradient(self->OverridedStopGradient());
return new_var;
}
#endif
},
py::arg("device_id") = -1, py::arg("blocking") = true, R"DOC(
Returns a copy of this Tensor in GPU memory.
If this Tensor is already in GPU memory and device_id is default,
then no copy is performed and the original Tensor is returned.
Args:
device_id(int, optional): The destination GPU device id. Defaults to the current device.
blocking(bool, optional): If False and the source is in pinned memory, the copy will be
asynchronous with respect to the host. Otherwise, the argument has no effect. Default: False.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
print(x.place) # CPUPlace
y = x.cuda()
print(y.place) # CUDAPlace(0)
y = x.cuda(1)
print(y.place) # CUDAPlace(1)
)DOC")
.def("_copy_to",
[](const imperative::VarBase &self, const platform::CPUPlace &place,
bool blocking) { return self.NewVarBase(place, blocking); },
......@@ -950,12 +1071,14 @@ void BindImperative(py::module *m_ptr) {
[](imperative::Tracer &self,
std::unordered_set<std::string> &allow_ops,
std::unordered_set<std::string> &block_ops) {
// NOTE(zhiqiu): The automatic conversion in pybind11 between c++
// NOTE(zhiqiu): The automatic conversion in pybind11 between
// c++
// STL and python set/list/dict involve a copy operation that
// prevents pass-by-reference semantics, so it is ok to swap.
// The reaseon why not directly pass
// std::shared_ptr<std::unordered_set<std::string>>
// is that pybind11 forbid shared_ptr<T> where T is not custom type.
// is that pybind11 forbid shared_ptr<T> where T is not custom
// type.
imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops);
imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops);
})
......
......@@ -1421,6 +1421,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("_get_device_id",
[](platform::CUDAPlace &self) -> int { return self.GetDeviceId(); })
#endif
.def("__repr__", string::to_string<const platform::CUDAPlace &>)
.def("__str__", string::to_string<const platform::CUDAPlace &>);
py::class_<platform::XPUPlace>(m, "XPUPlace", R"DOC(
......@@ -1479,6 +1480,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("get_device_id",
[](const platform::XPUPlace &self) { return self.GetDeviceId(); })
#endif
.def("__repr__", string::to_string<const platform::XPUPlace &>)
.def("__str__", string::to_string<const platform::XPUPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace", R"DOC(
......@@ -1500,6 +1502,7 @@ All parameter, weight, gradient are variables in Paddle.
.def("_equals", &IsSamePlace<platform::CPUPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CPUPlace, platform::CUDAPinnedPlace>)
.def("__repr__", string::to_string<const platform::CPUPlace &>)
.def("__str__", string::to_string<const platform::CPUPlace &>);
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace", R"DOC(
......@@ -1536,6 +1539,7 @@ All parameter, weight, gradient are variables in Paddle.
&IsSamePlace<platform::CUDAPinnedPlace, platform::CPUPlace>)
.def("_equals",
&IsSamePlace<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>)
.def("__repr__", string::to_string<const platform::CUDAPinnedPlace &>)
.def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
py::class_<platform::Place>(m, "Place")
......@@ -1578,10 +1582,13 @@ All parameter, weight, gradient are variables in Paddle.
[](platform::Place &self, const platform::CUDAPlace &gpu_place) {
self = gpu_place;
})
.def("set_place", [](platform::Place &self,
const platform::CUDAPinnedPlace &cuda_pinned_place) {
self = cuda_pinned_place;
});
.def("set_place",
[](platform::Place &self,
const platform::CUDAPinnedPlace &cuda_pinned_place) {
self = cuda_pinned_place;
})
.def("__repr__", string::to_string<const platform::Place &>)
.def("__str__", string::to_string<const platform::Place &>);
py::class_<OperatorBase>(m, "Operator")
.def_static(
......
......@@ -1785,8 +1785,6 @@ class ComplexVariable(object):
**Notes**:
**The constructor of ComplexTensor should not be invoked directly.**
**Only support dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **to create a dygraph ComplexTensor with complex number data.**
Args:
real (Tensor): The Tensor holding real-part data.
imag (Tensor): The Tensor holding imaginery-part data.
......@@ -1795,14 +1793,14 @@ class ComplexVariable(object):
.. code-block:: python
import paddle
import numpy as np
paddle.enable_imperative()
x = paddle.to_tensor([1.0+2.0j, 0.2])
print(x.name, x.dtype, x.shape)
# ({'real': 'generated_tensor_0.real', 'imag': 'generated_tensor_0.imag'}, 'complex128', [2L])
print(x.numpy())
# [1. +2.j 0.2+0.j]
# ({'real': 'generated_tensor_0.real', 'imag': 'generated_tensor_0.imag'}, complex64, [2])
print(x)
# ComplexTensor[real](shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [ 1., 0.20000000])
# ComplexTensor[imag](shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [2., 0.])
print(type(x))
# <class 'paddle.ComplexTensor'>
"""
......@@ -1858,9 +1856,10 @@ class ComplexVariable(object):
return self.real.numpy() + 1j * self.imag.numpy()
def __str__(self):
return "ComplexTensor[real]: %s\n%s\nComplexTensor[imag]: %s\n%s" % (
self.real.name, str(self.real.value().get_tensor()), self.imag.name,
str(self.imag.value().get_tensor()))
from paddle.tensor.to_string import to_string
return "ComplexTensor containing:\n{real}\n{imag}".format(
real=to_string(self.real, "[real part]Tensor"),
imag=to_string(self.imag, "[imag part]Tensor"))
__repr__ = __str__
......@@ -5335,16 +5334,13 @@ class ParamBase(core.VarBase):
.. code-block:: python
import paddle
paddle.disable_static()
conv = paddle.nn.Conv2D(3, 3, 5)
print(conv.weight)
# Parameter: conv2d_0.w_0
# - place: CUDAPlace(0)
# - shape: [3, 3, 5, 5]
# - layout: NCHW
# - dtype: float
# - data: [...]
paddle.enable_static()
linear = paddle.nn.Linear(3, 3)
print(linear.weight)
# Parameter containing:
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[ 0.48948765, 0.05829060, -0.25524026],
# [-0.70368278, 0.52986908, -0.68742192],
# [-0.54217887, 0.48439729, 0.34082305]])
"""
return "Parameter containing:\n{tensor}".format(
tensor=super(ParamBase, self).__str__())
......
......@@ -64,6 +64,15 @@ class TestVarBase(unittest.TestCase):
y.backward()
self.assertTrue(
np.array_equal(x.grad, np.array([2.4]).astype('float32')))
y = x.cpu()
self.assertEqual(y.place.__repr__(), "CPUPlace")
if core.is_compiled_with_cuda():
y = x.pin_memory()
self.assertEqual(y.place.__repr__(), "CUDAPinnedPlace")
y = x.cuda(blocking=False)
self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
y = x.cuda(blocking=True)
self.assertEqual(y.place.__repr__(), "CUDAPlace(0)")
# set_default_dtype take effect on complex
x = paddle.to_tensor(1 + 2j, place=place, stop_gradient=False)
......
......@@ -90,61 +90,38 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True):
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
type(paddle.to_tensor(1))
# <class 'paddle.Tensor'>
paddle.to_tensor(1)
# Tensor: generated_tensor_0
# - place: CUDAPlace(0) # allocate on global default place CPU:0
# - shape: [1]
# - layout: NCHW
# - dtype: int64_t
# - data: [1]
# Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [1])
x = paddle.to_tensor(1)
paddle.to_tensor(x, dtype='int32', place=paddle.CPUPlace()) # A new tensor will be constructed due to different dtype or place
# Tensor: generated_tensor_01
# - place: CPUPlace
# - shape: [1]
# - layout: NCHW
# - dtype: int
# - data: [1]
# Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True,
# [1])
paddle.to_tensor((1.1, 2.2), place=paddle.CUDAPinnedPlace())
# Tensor: generated_tensor_1
# - place: CUDAPinnedPlace
# - shape: [2]
# - layout: NCHW
# - dtype: double
# - data: [1.1 2.2]
# Tensor(shape=[1], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True,
# [1])
paddle.to_tensor([[0.1, 0.2], [0.3, 0.4]], place=paddle.CUDAPlace(0), stop_gradient=False)
# Tensor: generated_tensor_2
# - place: CUDAPlace(0)
# - shape: [2, 2]
# - layout: NCHW
# - dtype: double
# - data: [0.1 0.2 0.3 0.4]
# Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False,
# [[0.10000000, 0.20000000],
# [0.30000001, 0.40000001]])
type(paddle.to_tensor([[1+1j, 2], [3+2j, 4]]), dtype='complex64')
# <class 'paddle.ComplexTensor'>
paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64')
# ComplexTensor[real]: generated_tensor_0.real
# - place: CUDAPlace(0)
# - shape: [2, 2]
# - layout: NCHW
# - dtype: float
# - data: [1 2 3 4]
# ComplexTensor[imag]: generated_tensor_0.imag
# - place: CUDAPlace(0)
# - shape: [2, 2]
# - layout: NCHW
# - dtype: float
# - data: [1 0 2 0]
# ComplexTensor[real](shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 2.],
# [3., 4.]])
# ComplexTensor[imag](shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[1., 0.],
# [2., 0.]])
"""
if place is None:
......
......@@ -24,6 +24,7 @@
}
],
"wlist_temp_api":[
"to_tensor",
"LRScheduler",
"ReduceOnPlateau",
"append_LARS",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册