diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 50637a0c3d3f9c6975578e94e6ddc2c898c926e0..a56ca342ad1a867ee768d70c95e3f96c76ceda4a 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -60,6 +60,10 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr, size); + } else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_cuda_pinned_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size); } else if (platform::is_gpu_place(src_place) && // NOLINT platform::is_cpu_place(dst_place)) { auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place); @@ -82,6 +86,28 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place, auto stream = reinterpret_cast(ctx).stream(); memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); + } else if (platform::is_gpu_place(src_place) && // NOLINT + platform::is_cuda_pinned_place(dst_place)) { + auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place); + auto dst_cuda_pinned_place = + BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place); + auto ctx_place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx_place), true, + platform::errors::PreconditionNotMet( + "Device context place mismatch. When copying Tensor " + "data from GPU memory to CUDA Pinned memory, current " + "device context place should be GPU.")); + auto ctx_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx_place); + PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place, + platform::errors::PreconditionNotMet( + "The source GPU device and current device context do " + "not match. 
The source GPU device number is %d, but " + "device context GPU number is %d.", + src_gpu_place.device, ctx_gpu_place.device)); + auto stream = + reinterpret_cast(ctx).stream(); + memory::Copy(dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, + stream); } else if (platform::is_cuda_pinned_place(src_place) && platform::is_gpu_place(dst_place)) { auto src_cuda_pinned_place = @@ -180,6 +206,15 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, memory::Copy(BOOST_GET_CONST(platform::CPUPlace, dst_place), dst_ptr, BOOST_GET_CONST(platform::CUDAPinnedPlace, src_place), src_ptr, size); + } else if (platform::is_cpu_place(src_place) && // NOLINT + platform::is_cuda_pinned_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CPUPlace, src_place), src_ptr, size); + } else if (platform::is_gpu_place(src_place) && // NOLINT + platform::is_cuda_pinned_place(dst_place)) { + memory::Copy(BOOST_GET_CONST(platform::CUDAPinnedPlace, dst_place), dst_ptr, + BOOST_GET_CONST(platform::CUDAPlace, src_place), src_ptr, size, + nullptr); } else if (platform::is_gpu_place(src_place) && // NOLINT platform::is_cpu_place(dst_place)) { auto src_gpu_place = BOOST_GET_CONST(platform::CUDAPlace, src_place); diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index ac1d2bc1f31d62a2ca9ccb9378bc17ac37d09ec9..be55201595e34ef7fa520908b56208b8a6ff1895 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -78,11 +78,15 @@ static void InitTensorForVarBase(imperative::VarBase *self, const py::array &array, const platform::Place place, bool persistable = false, - bool zero_copy = false, - std::string name = "") { + bool zero_copy = false, std::string name = "", + int stop_gradient = -1) { if (name == "") { - name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var"); + name = + imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"); } + VLOG(5) << "Init Tensor as: / name: " << name + << " / persistable: " << persistable << " / zero_copy: " << zero_copy + << " / stop_gradient: " << stop_gradient; new (self) imperative::VarBase(name); auto *tensor = self->MutableVar()->GetMutable(); if (platform::is_cpu_place(place)) { @@ -99,6 +103,9 @@ static void InitTensorForVarBase(imperative::VarBase *self, PADDLE_THROW(platform::errors::InvalidArgument( "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace")); } + if (stop_gradient != -1) { + self->SetOverridedStopGradient(stop_gradient); + } self->SetPersistable(persistable); self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetDataType(tensor->type()); @@ -106,12 +113,11 @@ static void InitTensorForVarBase(imperative::VarBase *self, static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self, const py::kwargs &kwargs) { - VLOG(4) << "Init VarBase"; + VLOG(4) << "Init VarBase from kwargs: "; PADDLE_ENFORCE_EQ( kwargs.contains("value"), true, platform::errors::NotFound( "The kwargs used to create Varbase misses argument: value")); - auto persistable = kwargs.contains("persistable") ? kwargs["persistable"].cast() : false; @@ -120,10 +126,14 @@ static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self, auto zero_copy = kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast() : false; auto name = kwargs.contains("name") ? kwargs["name"].cast() : ""; + auto stop_gradient = kwargs.contains("stop_gradient") + ? 
kwargs["stop_gradient"].cast() + : -1; auto default_place = imperative::GetCurrentTracer()->ExpectedPlace(); auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"]) : default_place; - InitTensorForVarBase(self, array, place, persistable, zero_copy, name); + InitTensorForVarBase(self, array, place, persistable, zero_copy, name, + stop_gradient); } template @@ -131,15 +141,24 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self, const py::array &array, const P &place, bool persistable = false, bool zero_copy = false, - std::string name = "") { - VLOG(4) << "Init VarBase"; - // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name + std::string name = "", + int stop_gradient = -1) { + VLOG(4) << "Init VarBase from Arg: "; + // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name , 6: + // stop_gradient if (name == "") { - name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var"); + name = + imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"); } + VLOG(5) << "Init Tensor as: / name: " << name + << " / persistable: " << persistable << " / zero_copy: " << zero_copy + << " / stop_gradient: " << stop_gradient; new (self) imperative::VarBase(name); self->SetPersistable(persistable); auto *tensor = self->MutableVar()->GetMutable(); + if (stop_gradient != -1) { + self->SetOverridedStopGradient(stop_gradient); + } SetTensorFromPyArray
<P>
(tensor, array, place, zero_copy); self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetDataType(tensor->type()); @@ -147,7 +166,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self, static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self, const py::array &array) { - VLOG(4) << "Init VarBase"; + VLOG(4) << "Init VarBase from numpy: "; auto place = imperative::GetCurrentTracer()->ExpectedPlace(); InitTensorForVarBase(self, array, place); } @@ -157,7 +176,7 @@ static void InitVarBaseFromTensorWithArgDefault( VLOG(4) << "Init VarBase"; auto place = imperative::GetCurrentTracer()->ExpectedPlace(); new (self) imperative::VarBase( - imperative::GetCurrentTracer()->GenerateUniqueName("generated_var")); + imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor")); self->SetPersistable(false); self->SetType(framework::proto::VarType::LOD_TENSOR); self->SetDataType(tensor.type()); @@ -551,7 +570,7 @@ void BindImperative(py::module *m_ptr) { std::string act_name = ""; if (!name.ptr() || name.ptr() == Py_None) { act_name = imperative::GetCurrentTracer()->GenerateUniqueName( - "generated_var"); + "generated_tensor"); } else { act_name = name.cast(); } @@ -567,13 +586,16 @@ void BindImperative(py::module *m_ptr) { }) .def("__init__", &InitVarBaseFromNumpyWithArg, py::arg("value"), py::arg("place"), py::arg("persistable") = false, - py::arg("zero_copy") = false, py::arg("name") = "") + py::arg("zero_copy") = false, py::arg("name") = "", + py::arg("stop_gradient") = -1) .def("__init__", &InitVarBaseFromNumpyWithArg, py::arg("value"), py::arg("place"), py::arg("persistable") = false, - py::arg("zero_copy") = false, py::arg("name") = "") + py::arg("zero_copy") = false, py::arg("name") = "", + py::arg("stop_gradient") = -1) .def("__init__", &InitVarBaseFromNumpyWithArg, py::arg("value"), py::arg("place"), py::arg("persistable") = false, - py::arg("zero_copy") = false, py::arg("name") = "") + py::arg("zero_copy") = false, py::arg("name") = "", + py::arg("stop_gradient") = -1) .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value")) .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor")) .def("__init__", &InitVarBaseFromNumpyWithKwargs) @@ -796,6 +818,11 @@ void BindImperative(py::module *m_ptr) { [](const imperative::VarBase &self, const platform::CPUPlace &place, bool blocking) { return self.NewVarBase(place, blocking); }, py::return_value_policy::copy) + .def("_copy_to", + [](const imperative::VarBase &self, + const platform::CUDAPinnedPlace &place, + bool blocking) { return self.NewVarBase(place, blocking); }, + py::return_value_policy::copy) .def("_copy_to", [](const imperative::VarBase &self, const platform::CUDAPlace &place, bool blocking) { return self.NewVarBase(place, blocking); }, @@ -824,6 +851,9 @@ void BindImperative(py::module *m_ptr) { return std::vector(); } }) + .def_property_readonly( + "place", [](imperative::VarBase &self) { return self.Place(); }, + py::return_value_policy::copy) .def_property_readonly("type", &imperative::VarBase::Type) .def_property_readonly("dtype", &imperative::VarBase::DataType); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d68e225849e7fcfa7c7297942df96e2fede30f8e..2bfd4ff49cf0d244c8c54b12cde24dff20511b1a 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1070,7 +1070,7 @@ All parameter, weight, gradient are variables in Paddle. 
.def("find_var", &Scope::FindVar, py::arg("name"), R"DOC( Find variable named :code:`name` in the current scope or - its parent scope. Return None if not found. + its parent scope. Return None if not found. Args: name (str): the variable name. @@ -1319,12 +1319,16 @@ All parameter, weight, gradient are variables in Paddle. std::exit(-1); #endif }) +#ifdef PADDLE_WITH_CUDA .def("_type", &PlaceIndex) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) .def("_equals", &IsSamePlace) + .def("_get_device_id", + [](platform::CUDAPlace &self) -> int { return self.GetDeviceId(); }) +#endif .def("__str__", string::to_string); py::class_(m, "CPUPlace", R"DOC( diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index ba79c4b44374eb9b50ad4982a2eacd664fc6e75e..c16b22b9fc3aed95d6f55045a90ddad278cbd92e 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -211,7 +211,7 @@ void SetTensorFromPyArrayT( } #else PADDLE_THROW(platform::errors::PermissionDenied( - "Cannot use CUDAPlace in CPU only version, " + "Cannot use CUDAPlace or CUDAPinnedPlace in CPU only version, " "Please recompile or reinstall Paddle with CUDA support.")); #endif } diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 402980a0c1ea4cf0f2c8bac2fc851efb9e7e7403..b3efaff42c4562da225b49bf67c1abd1d39b7c44 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -31,13 +31,15 @@ import paddle.reader import paddle.dataset import paddle.batch batch = batch.batch +import paddle.framework +from .framework import VarBase as Tensor +from .framework import ComplexVariable as ComplexTensor import paddle.compat import paddle.distributed import paddle.sysconfig import paddle.tensor import paddle.nn import paddle.distributed.fleet -import paddle.framework import paddle.optimizer import paddle.metric import paddle.incubate.complex as complex @@ -48,9 +50,7 @@ from .tensor.random import randperm from .tensor.attribute import rank #DEFINE_ALIAS from .tensor.attribute import shape #DEFINE_ALIAS -from .tensor.creation import create_tensor #DEFINE_ALIAS -# from .tensor.creation import create_lod_tensor #DEFINE_ALIAS -# from .tensor.creation import create_random_int_lodtensor #DEFINE_ALIAS +from .tensor.creation import to_tensor #DEFINE_ALIAS from .tensor.creation import crop_tensor #DEFINE_ALIAS from .tensor.creation import diag #DEFINE_ALIAS from .tensor.creation import eye #DEFINE_ALIAS @@ -231,7 +231,6 @@ from .tensor.stat import reduce_mean #DEFINE_ALIAS from .tensor.stat import std #DEFINE_ALIAS from .tensor.stat import var #DEFINE_ALIAS from .fluid.data import data -# from .tensor.tensor import Tensor #DEFINE_ALIAS # from .tensor.tensor import LoDTensor #DEFINE_ALIAS # from .tensor.tensor import LoDTensorArray #DEFINE_ALIAS diff --git a/python/paddle/fluid/data_feeder.py b/python/paddle/fluid/data_feeder.py index e8d708e04ce54bf6589ada0a55de13f06f0ba2a9..45aa85d4168a55e206460ce2e39292013caa9ce0 100644 --- a/python/paddle/fluid/data_feeder.py +++ b/python/paddle/fluid/data_feeder.py @@ -50,14 +50,15 @@ def convert_dtype(dtype): elif isinstance(dtype, type): if dtype in [ np.bool, np.float16, np.float32, np.float64, np.int8, np.int16, - np.int32, np.int64, np.uint8 + np.int32, np.int64, np.uint8, np.complex64, np.complex128 ]: return dtype.__name__ else: if dtype in [ 'bool', 'float16', 'float32', 'float64', 'int8', 'int16', - 'int32', 'int64', 'uint8', u'bool', u'float16', u'float32', - u'float64', u'int8', u'int16', 
u'int32', u'int64', u'uint8' + 'int32', 'int64', 'uint8', 'complex64', 'complex128', u'bool', + u'float16', u'float32', u'float64', u'int8', u'int16', u'int32', + u'int64', u'uint8', u'complex64', u'complex128' ]: # this code is a little bit dangerous, since error could happen # when casting no-ascii code to str in python2. @@ -68,7 +69,7 @@ def convert_dtype(dtype): raise TypeError( "dtype must be any of [bool, float16, float32, float64, int8, int16, " - "int32, int64, uint8], but received %s" % dtype) + "int32, int64, uint8, complex64, complex128], but received %s" % dtype) def check_variable_and_dtype(input, diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 7b4390c7a7b4e32fcb7937d47bedd875f1236006..9dbaab2580d21397fa7a4e03b03a5f1c4ac887f2 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -50,14 +50,19 @@ def monkey_patch_varbase(): static_var = var_base._to_static_var() """ + + # Note: getattr(self, attr, None) will call x.grad=x.gradient(), but gradient() only available in dygraph. + # It will fail. So, for propery in dygraph only, should not let it getattr(self, attr, None). + attr_not_need_keys = ['grad'] if isinstance(self, ParamBase): attr_kwargs = self.__dict__.copy() else: - attr_names = [ - name for name in dir(self) - if not (inspect.ismethod(getattr(self, name)) or - name.startswith('_')) - ] + attr_names = [] + for name in dir(self): + if name not in attr_not_need_keys and not ( + inspect.ismethod(getattr(self, name)) or + name.startswith('_')): + attr_names.append(name) attr_kwargs = {name: getattr(self, name) for name in attr_names} attr_keys = ['block', 'shape', 'dtype', 'type', 'name', 'persistable'] @@ -216,6 +221,14 @@ def monkey_patch_varbase(): else: return np.array(new_ivar.value().get_tensor()) + @property + def grad(self): + """ + The alias of gradient(). + """ + + return self.gradient() + def __str__(self): """ Convert a VarBase object to a readable string. @@ -239,9 +252,9 @@ def monkey_patch_varbase(): """ tensor = self.value().get_tensor() if tensor._is_initialized(): - return 'Variable: %s\n%s' % (self.name, str(tensor)) + return 'Tensor: %s\n%s' % (self.name, str(tensor)) else: - return 'Variable: %s, not initialized' % (self.name) + return 'Tensor: %s, not initialized' % (self.name) @property def block(self): @@ -260,8 +273,9 @@ def monkey_patch_varbase(): for method_name, method in ( ("__bool__", __bool__), ("__nonzero__", __nonzero__), ("_to_static_var", _to_static_var), ("set_value", set_value), - ("block", block), ("backward", backward), ("gradient", gradient), - ("__str__", __str__)): + ("block", block), ("backward", backward), ("grad", grad), + ("gradient", gradient), ("__str__", __str__), ("__repr__", __str__), + ("__module__", "paddle"), ("__name__", "Tensor")): setattr(core.VarBase, method_name, method) # patch math methods for varbase diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 8fe22024e6f12238e1b5bdb5adab052aff811b04..fe0aba6f243609d79eb8e11664711c33c5d1ac2f 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1689,34 +1689,40 @@ def get_all_op_protos(): class ComplexVariable(object): """ - The Variable defined on the complex number domain. It contains two common - real number Variables as its members, :attr:`real` and :attr:`imag` + The ComplexTensor defined on the complex number domain. 
It contains two common + real number Tensors as its members, :attr:`real` and :attr:`imag` holding the real part and imaginary part of complex numbers respectively. **Notes**: - **The constructor of ComplexVariable should not be invoked directly.** + **The constructor of ComplexTensor should not be invoked directly.** - **Only support dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **to create a dygraph ComplexVariable with complex number data.** + **Only supports dygraph mode at present. Please use** :ref:`api_fluid_dygraph_to_variable` **to create a dygraph ComplexTensor with complex number data.** Args: - real (Variable): The Variable holding real-part data. - imag (Variable): The Variable holding imaginery-part data. + real (Tensor): The Tensor holding real-part data. + imag (Tensor): The Tensor holding imaginary-part data. Examples: .. code-block:: python - import paddle.fluid as fluid + import paddle import numpy as np - a = np.array([1.0+2.0j, 0.2]) - with fluid.dygraph.guard(): - var = fluid.dygraph.to_variable(a, name="new_var") - print(var.name, var.dtype, var.shape) - # ({'real': u'new_var.real', 'imag': u'new_var.imag'}, 'complex128', [2L]) - print(var.numpy()) - # [1. +2.j 0.2+0.j] + paddle.enable_imperative() + x = paddle.to_tensor([1.0+2.0j, 0.2]) + print(x.name, x.dtype, x.shape) + # ({'real': 'generated_tensor_0.real', 'imag': 'generated_tensor_0.imag'}, 'complex128', [2L]) + print(x.numpy()) + # [1. +2.j 0.2+0.j] + print(type(x)) + # <class 'paddle.ComplexTensor'> """ + def __new__(cls, *arg, **kwargs): + cls.__module__ = "paddle" + cls.__name__ = "ComplexTensor" + return super(ComplexVariable, cls).__new__(cls) + def __init__(self, real, imag): assert real.shape == imag.shape, "The real part and imaginary part " \ "of a ComplexVariable should have the same shape!" @@ -1763,7 +1769,9 @@ class ComplexVariable(object): return self.real.numpy() + 1j * self.imag.numpy() def __str__(self): - return "REAL: " + self.real.__str__() + "IMAG: " + self.imag.__str__() + return "ComplexTensor[real]: %s\n%s\nComplexTensor[imag]: %s\n%s" % ( + self.real.name, str(self.real.value().get_tensor()), self.imag.name, + str(self.imag.value().get_tensor())) __repr__ = __str__ @@ -5092,12 +5100,13 @@ class Parameter(Variable): class ParamBase(core.VarBase): """ - ParamBase is derived from VarBase( Which is the Variable in Dygraph Mode ). A ParamBase is a persistable - VarBase, and will be updated by optimizers after each iteration. + ParamBase is derived from Tensor (which is the concept in Dygraph Mode). + A ParamBase is a persistable Tensor, and will be updated by optimizers + after each iteration. The training of a neural network is essentially the updating of its ParamBase. - Relative to a general Variable, a ParamBase has several its own + Relative to a general Tensor, a ParamBase has several of its own member variables: Args: @@ -5186,11 +5195,8 @@ class ParamBase(core.VarBase): # - data: [...]
paddle.enable_static() """ - tensor = self.value().get_tensor() - if tensor._is_initialized(): - return 'Parameter: %s\n%s' % (self.name, str(tensor)) - else: - return 'Parameter: %s, not initialized' % (self.name) + return "Parameter containing:\n {}\n - stop_gradient: {}".format( + super(ParamBase, self).__str__(), self.stop_gradient) __repr__ = __str__ diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 7657485f09b65e3395bc1e87d942c53aae2b3e8a..9a2b1108e1ee9d3dc46b01d39ea6acb997c4d5eb 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -12049,8 +12049,8 @@ def logical_and(x, y, out=None, name=None): paddle.disable_static() x_data = np.array([True, True, False, False], dtype=np.bool) y_data = np.array([True, False, True, False], dtype=np.bool) - x = paddle.to_variable(x_data) - y = paddle.to_variable(y_data) + x = paddle.to_tensor(x_data) + y = paddle.to_tensor(y_data) res = paddle.logical_and(x, y) print(res.numpy()) # [True False False False] """ diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 7e565ca31b219366b7ab83267b46f32e5812d983..80b94704c388824901312b5d577cb5cfd0d0c75b 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest from paddle.fluid.framework import default_main_program, Program, convert_np_dtype_to_dtype_, in_dygraph_mode +import paddle import paddle.fluid as fluid import paddle.fluid.layers as layers import paddle.fluid.core as core @@ -28,6 +29,74 @@ class TestVarBase(unittest.TestCase): self.dtype = np.float32 self.array = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + def test_to_tensor(self): + def _test_place(place): + with fluid.dygraph.guard(): + x = paddle.to_tensor( + 1, dtype='float32', place=place, stop_gradient=False) + self.assertTrue(np.array_equal(x.numpy(), [1.])) + self.assertEqual(x.dtype, core.VarDesc.VarType.FP32) + self.assertEqual(x.shape, [1]) + self.assertEqual(x.stop_gradient, False) + self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR) + + x = paddle.to_tensor( + (1, 2), dtype='float32', place=place, stop_gradient=False) + x = paddle.to_tensor( + [1, 2], dtype='float32', place=place, stop_gradient=False) + self.assertTrue(np.array_equal(x.numpy(), [1., 2.])) + self.assertEqual(x.dtype, core.VarDesc.VarType.FP32) + self.assertEqual(x.grad, None) + self.assertEqual(x.shape, [2]) + self.assertEqual(x.stop_gradient, False) + self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR) + + x = paddle.to_tensor( + self.array, + dtype='float32', + place=place, + stop_gradient=False) + self.assertTrue(np.array_equal(x.numpy(), self.array)) + self.assertEqual(x.dtype, core.VarDesc.VarType.FP32) + self.assertEqual(x.shape, self.shape) + self.assertEqual(x.stop_gradient, False) + self.assertEqual(x.type, core.VarDesc.VarType.LOD_TENSOR) + + y = paddle.to_tensor(x) + y = paddle.to_tensor(y, dtype='float64', place=place) + self.assertTrue(np.array_equal(y.numpy(), self.array)) + self.assertEqual(y.dtype, core.VarDesc.VarType.FP64) + self.assertEqual(y.shape, self.shape) + self.assertEqual(y.stop_gradient, True) + self.assertEqual(y.type, core.VarDesc.VarType.LOD_TENSOR) + z = x + y + self.assertTrue(np.array_equal(z.numpy(), 2 * self.array)) + + x = paddle.to_tensor( + [1 + 2j, 1 - 2j], dtype='complex64', place=place) + y = paddle.to_tensor(x) + 
self.assertTrue(np.array_equal(x.numpy(), [1 + 2j, 1 - 2j])) + self.assertEqual(y.dtype, 'complex64') + self.assertEqual(y.shape, [2]) + self.assertEqual(y.real.stop_gradient, True) + self.assertEqual(y.real.type, core.VarDesc.VarType.LOD_TENSOR) + + with self.assertRaises(TypeError): + paddle.to_tensor('test') + with self.assertRaises(TypeError): + paddle.to_tensor(1, dtype='test') + with self.assertRaises(ValueError): + paddle.to_tensor([[1], [2, 3]]) + with self.assertRaises(ValueError): + paddle.to_tensor([[1], [2, 3]], place='test') + with self.assertRaises(ValueError): + paddle.to_tensor([[1], [2, 3]], place=1) + + _test_place(core.CPUPlace()) + if core.is_compiled_with_cuda(): + _test_place(core.CUDAPinnedPlace()) + _test_place(core.CUDAPlace(0)) + def test_to_variable(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array, name="abc") @@ -76,7 +145,7 @@ class TestVarBase(unittest.TestCase): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) - self.assertEqual(var.name, 'generated_var_0') + self.assertEqual(var.name, 'generated_tensor_0') var.name = 'test' self.assertEqual(var.name, 'test') diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index 20f1b453a0cd37aaf0888991a3f20c9e68c438d0..215546293a4065da87b996826b3937361a4ca54e 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -32,12 +32,14 @@ from . import random from .random import manual_seed from ..fluid.framework import Variable #DEFINE_ALIAS +from ..fluid.framework import ComplexVariable #DEFINE_ALIAS from ..fluid.param_attr import ParamAttr #DEFINE_ALIAS from ..fluid.layers.tensor import create_global_var #DEFINE_ALIAS from ..fluid.layers.tensor import create_parameter #DEFINE_ALIAS from ..fluid.core import CPUPlace #DEFINE_ALIAS from ..fluid.core import CUDAPlace #DEFINE_ALIAS from ..fluid.core import CUDAPinnedPlace #DEFINE_ALIAS +from ..fluid.core import VarBase #DEFINE_ALIAS from paddle.fluid import core #DEFINE_ALIAS from ..fluid.dygraph.base import no_grad #DEFINE_ALIAS diff --git a/python/paddle/incubate/complex/tensor/math.py b/python/paddle/incubate/complex/tensor/math.py index 52fdbcbc82be291f356067258789c876fede8f16..231cbd918281216eae5112c42544c8dfececd9d4 100644 --- a/python/paddle/incubate/complex/tensor/math.py +++ b/python/paddle/incubate/complex/tensor/math.py @@ -262,7 +262,7 @@ def trace(x, offset=0, axis1=0, axis2=1, name=None): case1 = np.random.randn(3, 10, 10).astype('float64') + 1j * np.random.randn(3, 10, 10).astype('float64') paddle.disable_static() - case1 = paddle.to_variable(case1) + case1 = paddle.to_tensor(case1) data1 = paddle.complex.trace(case1, offset=1, axis1=1, axis2=2) # data1.shape = [3] """ complex_variable_exists([x], "trace") diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 75ba7d2114a2b11c664f2062616c168369acf6bd..676b122b37003abeaeacaa3f88005dc89355e310 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -461,7 +461,7 @@ def softmax(x, axis=-1, name=None): [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [6.0, 7.0, 8.0, 9.0]]], 'float32') - x = paddle.to_variable(x) + x = paddle.to_tensor(x) out = F.softmax(x) # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426], # [0.0320586 , 0.08714432, 0.23688282, 0.64391426], diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 
fd418300fa3451d7f7d540be88f76a07f0cc0f7a..992613f2f0aaa7079d83112445ee45c6a28a4348 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -283,7 +283,7 @@ class LeakyReLU(layers.Layer): paddle.disable_static() lrelu = paddle.nn.LeakyReLU() - x = paddle.to_variable(np.array([-2, 0, 1], 'float32')) + x = paddle.to_tensor(np.array([-2, 0, 1], 'float32')) out = lrelu(x) # [-0.02, 0., 1.] """ diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index cddb365e4359f4a9b8f7401894c3cb5aaa92ff57..aa0d8c408899aa06f6f0e98816c5831aad84d732 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -22,9 +22,7 @@ from __future__ import print_function from .random import randperm from .attribute import rank #DEFINE_ALIAS from .attribute import shape #DEFINE_ALIAS -from .creation import create_tensor #DEFINE_ALIAS -# from .creation import create_lod_tensor #DEFINE_ALIAS -# from .creation import create_random_int_lodtensor #DEFINE_ALIAS +from .creation import to_tensor #DEFINE_ALIAS from .creation import crop_tensor #DEFINE_ALIAS from .creation import diag #DEFINE_ALIAS from .creation import eye #DEFINE_ALIAS @@ -179,6 +177,5 @@ from .stat import mean #DEFINE_ALIAS from .stat import reduce_mean #DEFINE_ALIAS from .stat import std #DEFINE_ALIAS from .stat import var #DEFINE_ALIAS -# from .tensor import Tensor #DEFINE_ALIAS # from .tensor import LoDTensor #DEFINE_ALIAS # from .tensor import LoDTensorArray #DEFINE_ALIAS diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 0875fb4c219a0876eab7595d654ed144aedaeac7..55bf1344014c3de963a001e95c940100f49a8300 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -13,7 +13,12 @@ # limitations under the License. from __future__ import print_function +import numpy as np + from ..fluid.framework import Variable +from ..fluid.framework import unique_name +from ..fluid.framework import _current_expected_place +from ..fluid.framework import dygraph_only from ..fluid.initializer import Constant from ..fluid.layers import core from ..fluid.layer_helper import LayerHelper @@ -21,20 +26,16 @@ from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtyp from ..fluid.framework import convert_np_dtype_to_dtype_, in_dygraph_mode, _varbase_creator, device_guard, OpProtoHolder from ..fluid.layers import fill_constant from paddle.common_ops_import import * -import paddle # TODO: define functions to get create a tensor from ..fluid.layers import crop_tensor #DEFINE_ALIAS from ..fluid.layers import diag #DEFINE_ALIAS from ..fluid.layers import fill_constant #DEFINE_ALIAS -from ..fluid.layers import create_tensor #DEFINE_ALIAS from ..fluid.layers import linspace #DEFINE_ALIAS import paddle __all__ = [ - 'create_tensor', - # 'create_lod_tensor', - # 'create_random_int_lodtensor', + 'to_tensor', 'crop_tensor', 'diag', 'fill_constant', @@ -54,6 +55,170 @@ __all__ = [ ] + +@dygraph_only +def to_tensor(data, dtype=None, place=None, stop_gradient=True): + """ + Constructs a ``paddle.Tensor`` or ``paddle.ComplexTensor`` from ``data``, + which can be scalar, tuple, list, numpy\.ndarray, paddle\.Tensor, paddle\.ComplexTensor. + + If the ``data`` is already a tensor, and ``dtype`` or ``place`` doesn't change, no copy + will be performed and the original tensor will be returned; otherwise, a new tensor will be constructed + and returned.
Similarly, if ``data`` is a numpy\.ndarray with the same ``dtype`` + and the current place is CPU, no copy will be performed. + + The ``ComplexTensor`` is a unique type of paddle. If x is ``ComplexTensor``, then + ``x.real`` is the real part, and ``x.imag`` is the imaginary part. + + Args: + data(scalar|tuple|list|ndarray|Tensor|ComplexTensor): Initial data for the tensor. + Can be a scalar, list, tuple, numpy\.ndarray, paddle\.Tensor, paddle\.ComplexTensor. + dtype(str, optional): The desired data type of returned tensor. Can be 'bool' , 'float16' , + 'float32' , 'float64' , 'int8' , 'int16' , 'int32' , 'int64' , 'uint8'. And + 'complex64' , 'complex128' only for ComplexTensor. + Default: None, infers data type from ``data``. + place(CPUPlace|CUDAPinnedPlace|CUDAPlace, optional): The place to allocate Tensor. Can be + CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. + stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True. + + Returns: + Tensor: A Tensor or ComplexTensor constructed from ``data``. + + Raises: + TypeError: If the data type of ``data`` is not scalar, list, tuple, numpy.ndarray, paddle.Tensor, paddle.ComplexTensor + ValueError: If ``data`` is tuple|list, it can't contain nested tuple|list with different lengths, such as: [[1, 2], [3, 4, 5]] + TypeError: If ``dtype`` is not bool, float16, float32, float64, int8, int16, int32, int64, uint8, complex64, complex128 + ValueError: If ``place`` is not paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace + + Examples: + + .. code-block:: python + + import paddle + import numpy as np + paddle.enable_imperative() + + type(paddle.to_tensor(1)) + # <class 'paddle.Tensor'> + + paddle.to_tensor(1) + # Tensor: generated_tensor_0 + # - place: CUDAPlace(0) # allocate on global default place CPU:0 + # - shape: [1] + # - layout: NCHW + # - dtype: int64_t + # - data: [1] + + x = paddle.to_tensor(1) + paddle.to_tensor(x, dtype='int32', place=paddle.CPUPlace()) # A new tensor will be constructed due to different dtype or place + # Tensor: generated_tensor_01 + # - place: CPUPlace + # - shape: [1] + # - layout: NCHW + # - dtype: int + # - data: [1] + + paddle.to_tensor((1.1, 2.2), place=paddle.CUDAPinnedPlace()) + # Tensor: generated_tensor_1 + # - place: CUDAPinnedPlace + # - shape: [2] + # - layout: NCHW + # - dtype: double + # - data: [1.1 2.2] + + paddle.to_tensor([[0.1, 0.2], [0.3, 0.4]], place=paddle.CUDAPlace(0), stop_gradient=False) + # Tensor: generated_tensor_2 + # - place: CUDAPlace(0) + # - shape: [2, 2] + # - layout: NCHW + # - dtype: double + # - data: [0.1 0.2 0.3 0.4] + + type(paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64')) + # <class 'paddle.ComplexTensor'> + + paddle.to_tensor([[1+1j, 2], [3+2j, 4]], dtype='complex64') + # ComplexTensor[real]: generated_tensor_0.real + # - place: CUDAPlace(0) + # - shape: [2, 2] + # - layout: NCHW + # - dtype: float + # - data: [1 2 3 4] + # ComplexTensor[imag]: generated_tensor_0.imag + # - place: CUDAPlace(0) + # - shape: [2, 2] + # - layout: NCHW + # - dtype: float + # - data: [1 0 2 0] + """ + + if place is None: + place = _current_expected_place() + elif not isinstance(place, + (core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace)): + raise ValueError( + "'place' must be any of paddle.CPUPlace, paddle.CUDAPinnedPlace, paddle.CUDAPlace" + ) + + # TODO(zhouwei): Support allocate tensor on any other specified card + if isinstance(place, core.CUDAPlace) and isinstance( + _current_expected_place(), core.CUDAPlace) and place._get_device_id() != 
_current_expected_place()._get_device_id(): + place = _current_expected_place() + + if not isinstance(data, np.ndarray): + if np.isscalar(data) and not isinstance(data, str): + data = np.array([data]) + elif isinstance(data, (list, tuple)): + data = np.array(data) + if data.dtype == np.object: + raise ValueError( + "\n\tFaild to convert input data to a regular ndarray :\n\t - Usually " + "this means the input data contains nested lists with different lengths. " + ) + elif isinstance(data, paddle.Tensor): + data.stop_gradient = stop_gradient + if not data.place._equals(place): + data = data._copy_to(place, False) + if dtype: + if convert_dtype(dtype) != convert_dtype(data.dtype): + return data.astype(convert_dtype(dtype)) + return data + elif isinstance(data, paddle.ComplexTensor): + return data + else: + raise TypeError( + "Can't constructs a 'paddle.Tensor' with data type {}, data type must be scalar|list|tuple|numpy.ndarray|paddle.Tensor|paddle.ComplexTensor". + format(type(data))) + + if dtype: + dtype = convert_dtype(dtype) + if dtype != data.dtype: + data = data.astype(dtype) + + if not np.iscomplexobj(data): + return paddle.Tensor( + value=data, + place=place, + persistable=False, + zero_copy=True, + stop_gradient=stop_gradient) + else: + name = unique_name.generate('generated_tensor') + real_tensor = paddle.Tensor( + value=data.real, + place=place, + zero_copy=True, + name=name + ".real", + stop_gradient=stop_gradient) + imag_tensor = paddle.Tensor( + value=data.imag, + place=place, + zero_copy=True, + name=name + ".imag", + stop_gradient=stop_gradient) + return paddle.ComplexTensor(real_tensor, imag_tensor) + + def full_like(x, fill_value, dtype=None, name=None): """ :alias_main: paddle.full_like @@ -201,7 +366,7 @@ def ones_like(x, dtype=None, name=None): paddle.disable_static() - x = paddle.to_variable(np.array([1,2,3], dtype='float32')) + x = paddle.to_tensor(np.array([1,2,3], dtype='float32')) out1 = paddle.zeros_like(x) # [1., 1., 1.] out2 = paddle.zeros_like(x, dtype='int32') # [1, 1, 1] @@ -291,7 +456,7 @@ def zeros_like(x, dtype=None, name=None): paddle.disable_static() - x = paddle.to_variable(np.array([1,2,3], dtype='float32')) + x = paddle.to_tensor(np.array([1,2,3], dtype='float32')) out1 = paddle.zeros_like(x) # [0., 0., 0.] out2 = paddle.zeros_like(x, dtype='int32') # [0, 0, 0] @@ -471,7 +636,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None): out3 = paddle.arange(4.999, dtype='float32') # [0., 1., 2., 3., 4.] - start_var = paddle.to_variable(np.array([3])) + start_var = paddle.to_tensor(np.array([3])) out4 = paddle.arange(start_var, 7) # [3, 4, 5, 6] @@ -713,8 +878,8 @@ def meshgrid(*args, **kwargs): input_3 = np.random.randint(0, 100, [100, ]).astype('int32') input_4 = np.random.randint(0, 100, [200, ]).astype('int32') - tensor_3 = paddle.to_variable(input_3) - tensor_4 = paddle.to_variable(input_4) + tensor_3 = paddle.to_tensor(input_3) + tensor_4 = paddle.to_tensor(input_4) grid_x, grid_y = paddle.tensor.meshgrid(tensor_3, tensor_4) #the shape of grid_x is (100, 200)
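Below is a quick usage sketch, not part of the patch, that pulls together the user-facing pieces added above: paddle.to_tensor, the stop_gradient keyword, the new place property, and the CUDAPinnedPlace copy path. It assumes a build of this branch where paddle.enable_imperative() is available; only names that appear in the diff are used.

    import numpy as np
    import paddle

    paddle.enable_imperative()

    # to_tensor() replaces create_tensor/to_variable for imperative mode; the new
    # stop_gradient kwarg is forwarded to VarBase through the pybind __init__ overloads.
    x = paddle.to_tensor(np.array([[0.1, 0.2], [0.3, 0.4]], dtype='float32'),
                         stop_gradient=False)
    print(type(x))          # the monkey patch renames core.VarBase to paddle.Tensor
    print(x.place)          # place is now exposed as a read-only property
    print(x.stop_gradient)  # False

    # The new CPU <-> CUDAPinned branches in TensorCopy/TensorCopySync back this call;
    # it is only meaningful on a CUDA build.
    if paddle.fluid.core.is_compiled_with_cuda():
        pinned = x._copy_to(paddle.fluid.core.CUDAPinnedPlace(), True)
        print(pinned.place)  # CUDAPinnedPlace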
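A second sketch, again illustrative only, for the complex path: complex input is routed through the ComplexTensor branch of to_tensor, which builds paired "<name>.real" / "<name>.imag" tensors as shown in the creation.py hunk.

    import numpy as np
    import paddle

    paddle.enable_imperative()

    # complex64/complex128 are now accepted by convert_dtype(), so this does not raise.
    z = paddle.to_tensor(np.array([1.0 + 2.0j, 0.2]), dtype='complex128')
    print(type(z))                    # paddle.ComplexTensor (alias of ComplexVariable)
    print(z.real.name, z.imag.name)   # generated_tensor_N.real / generated_tensor_N.imag
    print(z.numpy())                  # [1. +2.j 0.2+0.j]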