From 0eb03ed7fe27d37da2b6a4e0a2591b937c199a2f Mon Sep 17 00:00:00 2001 From: Zhou Wei <1183042833@qq.com> Date: Thu, 23 Dec 2021 13:39:15 +0800 Subject: [PATCH] add new API: paddle.clone;Tensor.element_size;nn.utils.parameters_to_vector (#38020) * add new API: paddle.clone;Tensor.element_size;nn.utils.parameters_to_vector * fix comment --- CMakeLists.txt | 2 + paddle/fluid/framework/var_desc.cc | 5 + paddle/fluid/framework/var_desc.h | 2 + paddle/fluid/imperative/layer.h | 4 +- paddle/fluid/pybind/imperative.cc | 79 ++++++++---- paddle/fluid/pybind/protobuf.cc | 2 + python/paddle/__init__.py | 2 + python/paddle/fluid/framework.py | 27 ++++ .../fluid/tests/unittests/test_assign_op.py | 25 ++++ .../fluid/tests/unittests/test_parameter.py | 45 ++++--- .../fluid/tests/unittests/test_var_base.py | 35 +++++ .../fluid/tests/unittests/test_variable.py | 29 +++++ python/paddle/nn/utils/__init__.py | 3 +- .../paddle/nn/utils/transform_parameters.py | 122 ++++++++++++++++++ python/paddle/tensor/creation.py | 33 ++++- 15 files changed, 374 insertions(+), 41 deletions(-) create mode 100644 python/paddle/nn/utils/transform_parameters.py diff --git a/CMakeLists.txt b/CMakeLists.txt index f122dbb9cfc..43bd4e0fcf8 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -126,6 +126,8 @@ if(WIN32) endforeach(flag_var) endif() + # NOTE(zhouwei): msvc max/min macro conflict with std::min/max, define NOMINMAX globally + add_definitions("-DNOMINMAX") # windows build turn off warnings, use parallel compiling. foreach(flag_var CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc index 41fe9fbbc03..0a24efd003b 100644 --- a/paddle/fluid/framework/var_desc.cc +++ b/paddle/fluid/framework/var_desc.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/var_desc.h" #include "glog/logging.h" +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { @@ -116,6 +117,10 @@ proto::VarType::Type VarDesc::GetDataType() const { return tensor_desc().data_type(); } +size_t VarDesc::ElementSize() const { + return framework::SizeOfType(GetDataType()); +} + std::vector VarDesc::GetDataTypes() const { std::vector descs = tensor_descs(); std::vector res; diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h index a6f56ad4458..afe420dd25f 100644 --- a/paddle/fluid/framework/var_desc.h +++ b/paddle/fluid/framework/var_desc.h @@ -96,6 +96,8 @@ class VarDesc { proto::VarType::Type GetDataType() const; + size_t ElementSize() const; + std::vector GetDataTypes() const; void SetLoDLevel(int32_t lod_level); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 892c864027d..199d62bff1f 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -25,6 +25,7 @@ #include #include +#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/framework/var_type.h" @@ -37,7 +38,6 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" #include "paddle/pten/include/core.h" - namespace paddle { namespace framework { class Variable; @@ -212,6 +212,8 @@ class VarBase { framework::proto::VarType::Type DataType() const { return var_->DataType(); } + size_t ElementSize() const { return framework::SizeOfType(var_->DataType()); } + void SetForwardDataType(framework::proto::VarType::Type data_type) { var_->SetForwardDataType(data_type); } diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index a99f6761ee3..e981de44c5a 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -2013,6 +2013,29 @@ void BindImperative(py::module *m_ptr) { auto *t = self->MutableVar()->GetMutable(); return t->numel(); }) + .def("element_size", &imperative::VarBase::ElementSize, R"DOC( + Returns the size in bytes of an element in the Tensor. + + Examples: + .. code-block:: python + + import paddle + + x = paddle.to_tensor(1, dtype='bool') + x.element_size() # 1 + + x = paddle.to_tensor(1, dtype='float16') + x.element_size() # 2 + + x = paddle.to_tensor(1, dtype='float32') + x.element_size() # 4 + + x = paddle.to_tensor(1, dtype='float64') + x.element_size() # 8 + + x = paddle.to_tensor(1, dtype='complex128') + x.element_size() # 16 + )DOC") .def_property("name", &imperative::VarBase::Name, &imperative::VarBase::SetName) .def_property("stop_gradient", @@ -2020,28 +2043,40 @@ void BindImperative(py::module *m_ptr) { &imperative::VarBase::SetOverridedStopGradient) .def_property("persistable", &imperative::VarBase::Persistable, &imperative::VarBase::SetPersistable) - .def_property_readonly( - "shape", - [](imperative::VarBase &self) { - if (self.Var().IsType()) { - return framework::vectorize( - self.Var().Get().dims()); - } else if (self.Var().IsType()) { - return framework::vectorize( - self.Var().Get().value().dims()); - } else if (self.Var().IsType()) { - return std::vector{static_cast( - self.Var().Get().size())}; - } else if (self.Var().IsType()) { - return std::vector{ - static_cast(self.Var().Get().size())}; - } else { - VLOG(2) << "It is meaningless to get shape of " - "variable type " - << GetTypeName(self); - return std::vector(); - } - }) + .def_property_readonly("shape", + [](imperative::VarBase &self) { + if (self.Var().IsType()) { + return framework::vectorize( + self.Var() + .Get() + .dims()); + } else if (self.Var() + .IsType< + framework::SelectedRows>()) { + return framework::vectorize( + self.Var() + .Get() + .value() + .dims()); + } else if (self.Var() + .IsType()) { + return std::vector{static_cast( + self.Var() + .Get() + .size())}; + } else if (self.Var() + .IsType()) { + return std::vector{static_cast( + self.Var() + .Get() + .size())}; + } else { + VLOG(2) << "It is meaningless to get shape of " + "variable type " + << GetTypeName(self); + return std::vector(); + } + }) .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, R"DOC( Whether a Tensor is leaf Tensor. diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 9e5e391920b..44a8a54c8c1 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -179,6 +179,8 @@ void BindVarDsec(pybind11::module *m) { pybind11::return_value_policy::reference) .def("dtype", &pd::VarDesc::GetDataType, pybind11::return_value_policy::reference) + .def("element_size", &pd::VarDesc::ElementSize, + pybind11::return_value_policy::reference) .def("dtypes", &pd::VarDesc::GetDataTypes, pybind11::return_value_policy::reference) .def("lod_level", &pd::VarDesc::GetLoDLevel) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 3e808262d5d..e0e33d3805e 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -91,6 +91,7 @@ from .tensor.creation import empty # noqa: F401 from .tensor.creation import empty_like # noqa: F401 from .tensor.creation import assign # noqa: F401 from .tensor.creation import complex # noqa: F401 +from .tensor.creation import clone # noqa: F401 from .tensor.linalg import matmul # noqa: F401 from .tensor.linalg import dot # noqa: F401 from .tensor.linalg import norm # noqa: F401 @@ -587,4 +588,5 @@ __all__ = [ # noqa 'fmin', 'moveaxis', 'repeat_interleave', + 'clone', ] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index fe041ded8ec..dd83fc58e00 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1396,6 +1396,33 @@ class Variable(object): __repr__ = __str__ + def element_size(self): + """ + Returns the size in bytes of an element in the Tensor. + + Examples: + .. code-block:: python + + import paddle + paddle.enable_static() + + x = paddle.static.data(name='x1', shape=[3, 2], dtype='bool') + x.element_size() # 1 + + x = paddle.static.data(name='x2', shape=[3, 2], dtype='int16') + x.element_size() # 2 + + x = paddle.static.data(name='x3', shape=[3, 2], dtype='float16') + x.element_size() # 2 + + x = paddle.static.data(name='x4', shape=[3, 2], dtype='float32') + x.element_size() # 4 + + x = paddle.static.data(name='x5', shape=[3, 2], dtype='float64') + x.element_size() # 8 + """ + return self.desc.element_size() + @property def stop_gradient(self): """ diff --git a/python/paddle/fluid/tests/unittests/test_assign_op.py b/python/paddle/fluid/tests/unittests/test_assign_op.py index 7513d8810e6..3dbd9311a71 100644 --- a/python/paddle/fluid/tests/unittests/test_assign_op.py +++ b/python/paddle/fluid/tests/unittests/test_assign_op.py @@ -169,6 +169,31 @@ class TestAssignOApi(unittest.TestCase): self.assertTrue(np.allclose(result3.numpy(), np.array([1]))) paddle.enable_static() + def test_clone(self): + paddle.disable_static() + x = paddle.ones([2]) + x.stop_gradient = False + clone_x = paddle.clone(x) + + y = clone_x**3 + y.backward() + + self.assertTrue(np.array_equal(x, [1, 1]), True) + self.assertTrue(np.array_equal(clone_x.grad.numpy(), [3, 3]), True) + self.assertTrue(np.array_equal(x.grad.numpy(), [3, 3]), True) + paddle.enable_static() + + with program_guard(Program(), Program()): + x_np = np.random.randn(2, 3).astype('float32') + x = paddle.static.data("X", shape=[2, 3]) + clone_x = paddle.clone(x) + exe = paddle.static.Executor() + y_np = exe.run(paddle.static.default_main_program(), + feed={'X': x_np}, + fetch_list=[clone_x])[0] + + self.assertTrue(np.array_equal(y_np, x_np), True) + class TestAssignOpErrorApi(unittest.TestCase): def test_errors(self): diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py index 46e211f4729..85ba69cd438 100644 --- a/python/paddle/fluid/tests/unittests/test_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_parameter.py @@ -18,18 +18,19 @@ import unittest import copy import paddle from paddle.fluid.dygraph import guard -from paddle.fluid.framework import default_main_program +from paddle.fluid.framework import default_main_program, Variable import paddle.fluid.core as core from paddle.fluid.executor import Executor import paddle.fluid.io as io from paddle.fluid.initializer import ConstantInitializer import numpy as np +paddle.enable_static() main_program = default_main_program() class ParameterChecks(unittest.TestCase): - def check_parameter(self): + def test_parameter(self): shape = [784, 100] val = 1.0625 b = main_program.global_block() @@ -43,13 +44,13 @@ class ParameterChecks(unittest.TestCase): self.assertEqual((784, 100), param.shape) self.assertEqual(core.VarDesc.VarType.FP32, param.dtype) self.assertEqual(0, param.block.idx) - exe = Executor(core.CPUPlace()) + exe = Executor(paddle.CPUPlace()) p = exe.run(main_program, fetch_list=[param])[0] - self.assertTrue(np.allclose(p, np.ones(shape) * val)) + self.assertTrue(np.array_equal(p, np.ones(shape) * val)) p = io.get_parameter_value_by_name('fc.w', exe, main_program) - self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) + self.assertTrue(np.array_equal(p, np.ones(shape) * val)) - def check_parambase(self): + def test_parambase(self): with guard(): linear = paddle.nn.Linear(10, 10) param = linear.weight @@ -71,7 +72,7 @@ class ParameterChecks(unittest.TestCase): pram_copy2 = copy.deepcopy(param, memo) self.assertEqual(id(param_copy), id(pram_copy2)) - def check_exceptions(self): + def test_exception(self): b = main_program.global_block() with self.assertRaises(ValueError): b.create_parameter( @@ -86,16 +87,30 @@ class ParameterChecks(unittest.TestCase): b.create_parameter( name='test', shape=[-1], dtype='float32', initializer=None) + def test_parambase_to_vector(self): + with guard(): + initializer = paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(3.)) + linear1 = paddle.nn.Linear(10, 15, initializer) -class TestParameter(ParameterChecks): - def _test_parameter(self): - self.check_parameter() - - def test_parambase(self): - self.check_parambase() + vec = paddle.nn.utils.parameters_to_vector(linear1.parameters()) + self.assertEqual(linear1.weight.shape, [10, 15]) + self.assertEqual(linear1.bias.shape, [15]) + self.assertTrue(isinstance(vec, Variable)) + self.assertTrue(vec.shape, [165]) - def test_exceptions(self): - self.check_exceptions() + linear2 = paddle.nn.Linear(10, 15) + paddle.nn.utils.vector_to_parameters(vec, linear2.parameters()) + self.assertEqual(linear2.weight.shape, [10, 15]) + self.assertEqual(linear2.bias.shape, [15]) + self.assertTrue( + np.array_equal(linear1.weight.numpy(), linear2.weight.numpy()), + True) + self.assertTrue( + np.array_equal(linear1.bias.numpy(), linear2.bias.numpy()), + True) + self.assertTrue(linear2.weight.is_leaf, True) + self.assertTrue(linear2.bias.is_leaf, True) if __name__ == '__main__': diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index ab6e8003833..c4c4edbbb93 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -497,6 +497,41 @@ class TestVarBase(unittest.TestCase): var = fluid.dygraph.to_variable(self.array) self.assertTrue(isinstance(str(var), str)) + def test_element_size(self): + with fluid.dygraph.guard(): + x = paddle.to_tensor(1, dtype='bool') + self.assertEqual(x.element_size(), 1) + + x = paddle.to_tensor(1, dtype='float16') + self.assertEqual(x.element_size(), 2) + + x = paddle.to_tensor(1, dtype='float32') + self.assertEqual(x.element_size(), 4) + + x = paddle.to_tensor(1, dtype='float64') + self.assertEqual(x.element_size(), 8) + + x = paddle.to_tensor(1, dtype='int8') + self.assertEqual(x.element_size(), 1) + + x = paddle.to_tensor(1, dtype='int16') + self.assertEqual(x.element_size(), 2) + + x = paddle.to_tensor(1, dtype='int32') + self.assertEqual(x.element_size(), 4) + + x = paddle.to_tensor(1, dtype='int64') + self.assertEqual(x.element_size(), 8) + + x = paddle.to_tensor(1, dtype='uint8') + self.assertEqual(x.element_size(), 1) + + x = paddle.to_tensor(1, dtype='complex64') + self.assertEqual(x.element_size(), 8) + + x = paddle.to_tensor(1, dtype='complex128') + self.assertEqual(x.element_size(), 16) + def test_backward(self): with fluid.dygraph.guard(): var = fluid.dygraph.to_variable(self.array) diff --git a/python/paddle/fluid/tests/unittests/test_variable.py b/python/paddle/fluid/tests/unittests/test_variable.py index 2eb3ecf7104..5ba54daa0d4 100644 --- a/python/paddle/fluid/tests/unittests/test_variable.py +++ b/python/paddle/fluid/tests/unittests/test_variable.py @@ -63,6 +63,35 @@ class TestVariable(unittest.TestCase): self.assertRaises(ValueError, lambda: b.create_var(name="fc.w", shape=(24, 100))) + def test_element_size(self): + with fluid.program_guard(Program(), Program()): + x = paddle.static.data(name='x1', shape=[2], dtype='bool') + self.assertEqual(x.element_size(), 1) + + x = paddle.static.data(name='x2', shape=[2], dtype='float16') + self.assertEqual(x.element_size(), 2) + + x = paddle.static.data(name='x3', shape=[2], dtype='float32') + self.assertEqual(x.element_size(), 4) + + x = paddle.static.data(name='x4', shape=[2], dtype='float64') + self.assertEqual(x.element_size(), 8) + + x = paddle.static.data(name='x5', shape=[2], dtype='int8') + self.assertEqual(x.element_size(), 1) + + x = paddle.static.data(name='x6', shape=[2], dtype='int16') + self.assertEqual(x.element_size(), 2) + + x = paddle.static.data(name='x7', shape=[2], dtype='int32') + self.assertEqual(x.element_size(), 4) + + x = paddle.static.data(name='x8', shape=[2], dtype='int64') + self.assertEqual(x.element_size(), 8) + + x = paddle.static.data(name='x9', shape=[2], dtype='uint8') + self.assertEqual(x.element_size(), 1) + def test_step_scopes(self): prog = Program() b = prog.current_block() diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py index b6801cfe320..8f9b55d15ca 100644 --- a/python/paddle/nn/utils/__init__.py +++ b/python/paddle/nn/utils/__init__.py @@ -14,7 +14,8 @@ from .spectral_norm_hook import spectral_norm from .weight_norm_hook import weight_norm, remove_weight_norm # noqa: F401 +from .transform_parameters import parameters_to_vector, vector_to_parameters # noqa: F401 __all__ = [ #noqa - 'weight_norm', 'remove_weight_norm', 'spectral_norm' + 'weight_norm', 'remove_weight_norm', 'spectral_norm', 'parameters_to_vector', 'vector_to_parameters' ] diff --git a/python/paddle/nn/utils/transform_parameters.py b/python/paddle/nn/utils/transform_parameters.py new file mode 100644 index 00000000000..ea7067cb950 --- /dev/null +++ b/python/paddle/nn/utils/transform_parameters.py @@ -0,0 +1,122 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import reduce + +import paddle +from paddle.fluid.framework import dygraph_only, _dygraph_tracer, _varbase_creator +from paddle import _C_ops + + +#input==output, inplace strategy of reshape has no cost almostly +def _inplace_reshape_dygraph(x, shape): + x_shape = _varbase_creator(dtype=x.dtype) + _dygraph_tracer().trace_op( + type="reshape2", + inputs={'X': x}, + outputs={'Out': x, + 'XShape': x_shape}, + attrs={'shape': shape}, + stop_gradient=True) + + +@dygraph_only +def parameters_to_vector(parameters, name=None): + """ + Flatten parameters to a 1-D Tensor. + + Args: + parameters(Iterable[Tensor]): Iterable Tensors that are trainable parameters of a Layer. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A 1-D Tensor, which represents the parameters of a Layer. + + + Examples: + .. code-block:: python + + import paddle + linear = paddle.nn.Linear(10, 15) + + paddle.nn.utils.parameters_to_vector(linear.parameters()) + # 1-D Tensor: [165] + + """ + dtype = parameters[0].dtype + origin_shapes = [] + for param in parameters: + origin_shapes.append(param.shape) + _inplace_reshape_dygraph(param, [-1]) + + out = _varbase_creator(dtype=dtype) + _dygraph_tracer().trace_op( + type='concat', + inputs={'X': parameters}, + outputs={'Out': [out]}, + attrs={'axis': 0}, + stop_gradient=True) + for i, param in enumerate(parameters): + _inplace_reshape_dygraph(param, origin_shapes[i]) + return out + + +@dygraph_only +def vector_to_parameters(vec, parameters, name=None): + """ + Transform a Tensor with 1-D shape to the parameters. + + Args: + vec (Tensor): A Tensor with 1-D shape, which represents the parameters of a Layer. + parameters (Iterable[Tensor]): Iterable Tensors that are trainable parameters of a Layer. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Examples: + .. code-block:: python + + import paddle + weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(3.)) + linear1 = paddle.nn.Linear(10, 15, weight_attr) + + vec = paddle.nn.utils.parameters_to_vector(linear1.parameters()) + + linear2 = paddle.nn.Linear(10, 15) + # copy weight of linear1 to linear2 + paddle.nn.utils.vector_to_parameters(vec, linear2.parameters()) + # weight: Tensor(shape=[10, 15], dtype=float32, place=CUDAPlace(0), stop_gradient=False, + # [[3. , ..., 3. ], + # [..., ..., ...], + # [3. , ..., 3. ]]) + """ + origin_shapes = [] + sections = [] + for param in parameters: + shape = param.shape + origin_shapes.append(shape) + numel = reduce(lambda x, y: x * y, shape) + sections.append(numel) + + _dygraph_tracer().trace_op( + type='split', + inputs={'X': [vec]}, + outputs={'Out': parameters}, + attrs={'axis': 0, + 'sections': sections}, + stop_gradient=True) + + for i, param in enumerate(parameters): + _inplace_reshape_dygraph(param, origin_shapes[i]) + return diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 8a376884063..facec0975b6 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1158,8 +1158,7 @@ def empty_like(x, dtype=None, name=None): def assign(x, output=None): """ - - + The OP copies the :attr:`x` to the :attr:`output`. Parameters: @@ -1192,6 +1191,36 @@ def assign(x, output=None): return tensor.assign(x, output) +def clone(x, name=None): + """ + Returns a copy of input Tensor. It will always have a Tensor copy. + + In addition, This function is derivable, so gradients will flow back from the output to input. + + Parameters: + x (Tensor): The input Tensor. + name(str, optional): The default value is None. Normally there is no need for user to set this + property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: A Tensor copied from ``input`` . + + Examples: + .. code-block:: python + + import paddle + + x = paddle.ones([2]) + x.stop_gradient = False + clone_x = paddle.clone(x) + + y = clone_x**3 + y.backward() + print(clone_x.grad) # [3] + print(x.grad) # [3] + """ + return x.clone() + + #NOTE(zhiqiu): not public def _memcpy(input, place=None, output=None): """ -- GitLab