diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 4c46af3199e29e50e661c660e81665a404beaa67..4d68afeede4e5140db0487fd814ba9589e977eb1 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -836,6 +836,127 @@ void BindImperative(py::module *m_ptr) { } }, py::call_guard()) + .def("cpu", + [](const std::shared_ptr &self) { + if (platform::is_cpu_place(self->Place())) { + return self; + } else { + auto new_var = self->NewVarBase(platform::CPUPlace(), true); + new_var->SetOverridedStopGradient(self->OverridedStopGradient()); + return new_var; + } + }, + R"DOC( + Returns a copy of this Tensor in CPU memory. + + If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned. + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0)) + print(x.place) # CUDAPlace(0) + + y = x.cpu() + print(y.place) # CPUPlace + + )DOC") + .def("pin_memory", + [](const std::shared_ptr &self) { +#ifndef PADDLE_WITH_CUDA + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot copy this Tensor to pinned memory in CPU version " + "Paddle, " + "Please recompile or reinstall Paddle with CUDA support.")); +#endif + if (platform::is_cuda_pinned_place(self->Place())) { + return self; + } else { + auto new_var = + self->NewVarBase(platform::CUDAPinnedPlace(), true); + new_var->SetOverridedStopGradient(self->OverridedStopGradient()); + return new_var; + } + }, + R"DOC( + Returns a copy of this Tensor in pin memory. + + If this Tensor is already in pin memory, then no copy is performed and the original Tensor is returned. + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0)) + print(x.place) # CUDAPlace(0) + + y = x.pin_memory() + print(y.place) # CUDAPinnedPlace + + )DOC") + .def("cuda", + [](const std::shared_ptr &self, int device_id, + bool blocking) { +#ifndef PADDLE_WITH_CUDA + PADDLE_THROW(platform::errors::PermissionDenied( + "Cannot copy this Tensor to GPU in CPU version Paddle, " + "Please recompile or reinstall Paddle with CUDA support.")); +#else + int device_count = platform::GetCUDADeviceCount(); + if (device_id == -1) { + if (platform::is_gpu_place(self->Place())) { + return self; + } else { + device_id = 0; + } + } + PADDLE_ENFORCE_GE( + device_id, 0, + platform::errors::InvalidArgument( + "Can not copy Tensor to Invalid CUDAPlace(%d), device id " + "must inside [0, %d)", + device_id, device_count)); + PADDLE_ENFORCE_LT( + device_id, device_count, + platform::errors::InvalidArgument( + "Can not copy Tensor to Invalid CUDAPlace(%d), device id " + "must inside [0, %d)", + device_id, device_count)); + platform::CUDAPlace place = platform::CUDAPlace(device_id); + if (platform::is_same_place(self->Place(), place)) { + return self; + } else { + auto new_var = self->NewVarBase(place, blocking); + new_var->SetOverridedStopGradient(self->OverridedStopGradient()); + return new_var; + } +#endif + }, + py::arg("device_id") = -1, py::arg("blocking") = true, R"DOC( + Returns a copy of this Tensor in GPU memory. + + If this Tensor is already in GPU memory and device_id is default, + then no copy is performed and the original Tensor is returned. + + Args: + device_id(int, optional): The destination GPU device id. Defaults to the current device. + blocking(bool, optional): If False and the source is in pinned memory, the copy will be + asynchronous with respect to the host. Otherwise, the argument has no effect. Default: False. + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor(1.0, place=paddle.CPUPlace()) + print(x.place) # CPUPlace + + y = x.cuda() + print(y.place) # CUDAPlace(0) + + y = x.cuda(1) + print(y.place) # CUDAPlace(1) + )DOC") .def("_copy_to", [](const imperative::VarBase &self, const platform::CPUPlace &place, bool blocking) { return self.NewVarBase(place, blocking); }, @@ -950,12 +1071,14 @@ void BindImperative(py::module *m_ptr) { [](imperative::Tracer &self, std::unordered_set &allow_ops, std::unordered_set &block_ops) { - // NOTE(zhiqiu): The automatic conversion in pybind11 between c++ + // NOTE(zhiqiu): The automatic conversion in pybind11 between + // c++ // STL and python set/list/dict involve a copy operation that // prevents pass-by-reference semantics, so it is ok to swap. // The reaseon why not directly pass // std::shared_ptr> - // is that pybind11 forbid shared_ptr where T is not custom type. + // is that pybind11 forbid shared_ptr where T is not custom + // type. imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops); imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops); }) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3d9d204991f7984ac88989bc8393dde41aca4162..8ff7e9006533062989d475b0434cacaa0348c2b8 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1421,6 +1421,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_get_device_id", [](platform::CUDAPlace &self) -> int { return self.GetDeviceId(); }) #endif + .def("__repr__", string::to_string) .def("__str__", string::to_string); py::class_(m, "XPUPlace", R"DOC( @@ -1479,6 +1480,7 @@ All parameter, weight, gradient are variables in Paddle. .def("get_device_id", [](const platform::XPUPlace &self) { return self.GetDeviceId(); }) #endif + .def("__repr__", string::to_string) .def("__str__", string::to_string); py::class_(m, "CPUPlace", R"DOC( @@ -1500,6 +1502,7 @@ All parameter, weight, gradient are variables in Paddle. .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("__repr__", string::to_string) .def("__str__", string::to_string); py::class_(m, "CUDAPinnedPlace", R"DOC( @@ -1536,6 +1539,7 @@ All parameter, weight, gradient are variables in Paddle. &IsSamePlace) .def("_equals", &IsSamePlace) + .def("__repr__", string::to_string) .def("__str__", string::to_string); py::class_(m, "Place") @@ -1578,10 +1582,13 @@ All parameter, weight, gradient are variables in Paddle. [](platform::Place &self, const platform::CUDAPlace &gpu_place) { self = gpu_place; }) - .def("set_place", [](platform::Place &self, - const platform::CUDAPinnedPlace &cuda_pinned_place) { - self = cuda_pinned_place; - }); + .def("set_place", + [](platform::Place &self, + const platform::CUDAPinnedPlace &cuda_pinned_place) { + self = cuda_pinned_place; + }) + .def("__repr__", string::to_string) + .def("__str__", string::to_string); py::class_(m, "Operator") .def_static( diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 6be7fe0612e5afbd5eeda6261835395dfc6c89fe..904622caf45fc7b9a21b58cb5a591c2cf0b31f24 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1785,8 +1785,6 @@ class ComplexVariable(object): **Notes**: **The constructor of ComplexTensor should not be invoked directly.** - **Only support dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **to create a dygraph ComplexTensor with complex number data.** - Args: real (Tensor): The Tensor holding real-part data. imag (Tensor): The Tensor holding imaginery-part data. @@ -1795,14 +1793,14 @@ class ComplexVariable(object): .. code-block:: python import paddle - import numpy as np - - paddle.enable_imperative() x = paddle.to_tensor([1.0+2.0j, 0.2]) print(x.name, x.dtype, x.shape) - # ({'real': 'generated_tensor_0.real', 'imag': 'generated_tensor_0.imag'}, 'complex128', [2L]) - print(x.numpy()) - # [1. +2.j 0.2+0.j] + # ({'real': 'generated_tensor_0.real', 'imag': 'generated_tensor_0.imag'}, complex64, [2]) + print(x) + # ComplexTensor[real](shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [ 1., 0.20000000]) + # ComplexTensor[imag](shape=[2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [2., 0.]) print(type(x)) # """ @@ -1858,9 +1856,10 @@ class ComplexVariable(object): return self.real.numpy() + 1j * self.imag.numpy() def __str__(self): - return "ComplexTensor[real]: %s\n%s\nComplexTensor[imag]: %s\n%s" % ( - self.real.name, str(self.real.value().get_tensor()), self.imag.name, - str(self.imag.value().get_tensor())) + from paddle.tensor.to_string import to_string + return "ComplexTensor containing:\n{real}\n{imag}".format( + real=to_string(self.real, "[real part]Tensor"), + imag=to_string(self.imag, "[imag part]Tensor")) __repr__ = __str__ @@ -5335,16 +5334,13 @@ class ParamBase(core.VarBase): .. code-block:: python import paddle - paddle.disable_static() - conv = paddle.nn.Conv2D(3, 3, 5) - print(conv.weight) - # Parameter: conv2d_0.w_0 - # - place: CUDAPlace(0) - # - shape: [3, 3, 5, 5] - # - layout: NCHW - # - dtype: float - # - data: [...] - paddle.enable_static() + linear = paddle.nn.Linear(3, 3) + print(linear.weight) + # Parameter containing: + # Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [[ 0.48948765, 0.05829060, -0.25524026], + # [-0.70368278, 0.52986908, -0.68742192], + # [-0.54217887, 0.48439729, 0.34082305]]) """ return "Parameter containing:\n{tensor}".format( tensor=super(ParamBase, self).__str__()) diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index ecbf2415247b1fed6eee1775030c5eeed4cb6f5b..42fd2de864d0882eb93cec338dbf42cc2012640c 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -64,6 +64,15 @@ class TestVarBase(unittest.TestCase): y.backward() self.assertTrue( np.array_equal(x.grad, np.array([2.4]).astype('float32'))) + y = x.cpu() + self.assertEqual(y.place.__repr__(), "CPUPlace") + if core.is_compiled_with_cuda(): + y = x.pin_memory() + self.assertEqual(y.place.__repr__(), "CUDAPinnedPlace") + y = x.cuda(blocking=False) + self.assertEqual(y.place.__repr__(), "CUDAPlace(0)") + y = x.cuda(blocking=True) + self.assertEqual(y.place.__repr__(), "CUDAPlace(0)") # set_default_dtype take effect on complex x = paddle.to_tensor(1 + 2j, place=place, stop_gradient=False) diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 65a33ade27a220853bef6d0fdc780e4e51c2b262..8aa94ae4203426623961545b40c81d5dee9ae977 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -90,61 +90,38 @@ def to_tensor(data, dtype=None, place=None, stop_gradient=True): .. code-block:: python import paddle - import numpy as np - paddle.disable_static() type(paddle.to_tensor(1)) # paddle.to_tensor(1) - # Tensor: generated_tensor_0 - # - place: CUDAPlace(0) # allocate on global default place CPU:0 - # - shape: [1] - # - layout: NCHW - # - dtype: int64_t - # - data: [1] + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [1]) x = paddle.to_tensor(1) paddle.to_tensor(x, dtype='int32', place=paddle.CPUPlace()) # A new tensor will be constructed due to different dtype or place - # Tensor: generated_tensor_01 - # - place: CPUPlace - # - shape: [1] - # - layout: NCHW - # - dtype: int - # - data: [1] + # Tensor(shape=[1], dtype=int32, place=CPUPlace, stop_gradient=True, + # [1]) paddle.to_tensor((1.1, 2.2), place=paddle.CUDAPinnedPlace()) - # Tensor: generated_tensor_1 - # - place: CUDAPinnedPlace - # - shape: [2] - # - layout: NCHW - # - dtype: double - # - data: [1.1 2.2] + # Tensor(shape=[1], dtype=float32, place=CUDAPinnedPlace, stop_gradient=True, + # [1]) paddle.to_tensor([[0.1, 0.2], [0.3, 0.4]], place=paddle.CUDAPlace(0), stop_gradient=False) - # Tensor: generated_tensor_2 - # - place: CUDAPlace(0) - # - shape: [2, 2] - # - layout: NCHW - # - dtype: double - # - data: [0.1 0.2 0.3 0.4] + # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [[0.10000000, 0.20000000], + # [0.30000001, 0.40000001]]) type(paddle.to_tensor([[1+1j, 2], [3+2j, 4]]), dtype='complex64') # paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64') - # ComplexTensor[real]: generated_tensor_0.real - # - place: CUDAPlace(0) - # - shape: [2, 2] - # - layout: NCHW - # - dtype: float - # - data: [1 2 3 4] - # ComplexTensor[imag]: generated_tensor_0.imag - # - place: CUDAPlace(0) - # - shape: [2, 2] - # - layout: NCHW - # - dtype: float - # - data: [1 0 2 0] + # ComplexTensor[real](shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[1., 2.], + # [3., 4.]]) + # ComplexTensor[imag](shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True, + # [[1., 0.], + # [2., 0.]]) """ if place is None: diff --git a/tools/wlist.json b/tools/wlist.json index 9844fa486cc044490c8b03891a5a99275def02a0..648cbf6c3b77b2b2b88cb9bcbba3b47123d82ea3 100644 --- a/tools/wlist.json +++ b/tools/wlist.json @@ -24,6 +24,7 @@ } ], "wlist_temp_api":[ + "to_tensor", "LRScheduler", "ReduceOnPlateau", "append_LARS",