From e74e1a226ca9b20275d77aa1af069d29cfee2af0 Mon Sep 17 00:00:00 2001 From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com> Date: Wed, 9 Dec 2020 20:38:33 +0800 Subject: [PATCH] support deepcopy for Layer/Tensor/Paramerbase (#29387) * support deepcopy for Layer/Tensor/Paramerbase * fix some code --- paddle/fluid/imperative/layer.cc | 30 +++++++++ paddle/fluid/imperative/layer.h | 2 + paddle/fluid/pybind/imperative.cc | 8 +++ python/paddle/fluid/dygraph/layers.py | 28 ++++++--- .../fluid/dygraph/varbase_patch_methods.py | 35 ++++++++++- python/paddle/fluid/framework.py | 31 +++++++++ .../tests/unittests/test_imperative_basic.py | 38 ++++++----- .../fluid/tests/unittests/test_parameter.py | 34 +++++++++- .../fluid/tests/unittests/test_var_base.py | 63 +++++++++++++++++++ 9 files changed, 242 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 6f490c3c2be..94f2f722df0 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -282,6 +282,36 @@ std::shared_ptr VarBase::NewVarBase(const platform::Place& dst_place, } } +void VarBase::CopyFrom(const VarBase& src, const bool blocking) { + if (SharedVar()->IsEmpty()) { + VLOG(3) << "deep copy Variable from " << src.Name() << " to " << Name(); + SetPersistable(src.Persistable()); + SetDataType(src.DataType()); + SetType(src.Type()); + SetOverridedStopGradient(src.OverridedStopGradient()); + if (!src.SharedVar()->IsEmpty()) { + const platform::Place& place = src.Place(); + if (src.Var().IsType()) { + auto& src_tensor = src.Var().Get(); + auto* dst_tensor = MutableVar()->GetMutable(); + dst_tensor->set_lod(src_tensor.lod()); + framework::TensorCopy(src_tensor, place, dst_tensor); + } else if (src.Var().IsType()) { + auto& src_selected_rows = src.Var().Get(); + auto* dst_selected_rows = + MutableVar()->GetMutable(); + dst_selected_rows->set_height(src_selected_rows.height()); + dst_selected_rows->set_rows(src_selected_rows.rows()); + framework::TensorCopy(src_selected_rows.value(), place, + dst_selected_rows->mutable_value()); + } + if (blocking) { + platform::DeviceContextPool::Instance().Get(place)->Wait(); + } + } + } +} + void VarBase::BumpInplaceVersion() { PADDLE_ENFORCE_EQ( Var().IsInitialized(), true, diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 1a974ab346e..5e4767994dc 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -208,6 +208,8 @@ class VarBase { std::shared_ptr NewVarBase(const platform::Place& dst_place, const bool blocking) const; + void CopyFrom(const imperative::VarBase& src, bool blocking); + void BumpInplaceVersion(); private: diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7a48ffa82a4..ec59eacef14 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -526,6 +526,13 @@ void BindImperative(py::module *m_ptr) { py::class_>( m, "VarBase", R"DOC()DOC") .def_static("_alive_vars", &imperative::VarBase::AliveVarNames) + .def("__init__", + [](imperative::VarBase &self) { + std::string name = + imperative::GetCurrentTracer()->GenerateUniqueName( + "generated_tensor"); + new (&self) imperative::VarBase(name); + }) .def("__init__", [](imperative::VarBase &self, framework::proto::VarType::Type dtype, const std::vector &dims, const py::handle &name, @@ -1023,6 +1030,7 @@ void BindImperative(py::module *m_ptr) { y = x.cuda(1) print(y.place) # CUDAPlace(1) )DOC") + .def("copy_", &imperative::VarBase::CopyFrom) .def("_copy_to", [](const imperative::VarBase &self, const platform::CPUPlace &place, bool blocking) { return self.NewVarBase(place, blocking); }, diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index fe60c24ff36..ad3a20869ce 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -21,6 +21,7 @@ import re import copy import weakref import warnings +from copy import deepcopy from . import parallel_helper from .. import unique_name @@ -1010,15 +1011,26 @@ class Layer(core.Layer): self._parameters[name] = parameter return parameter + def __getstate__(self): + return self.__dict__ + + def __setstate__(self, state): + self.__dict__.update(state) + def __getattr__(self, name): - if name in self._parameters: - return self._parameters[name] - elif name in self._sub_layers: - return self._sub_layers[name] - elif name in self._buffers: - return self._buffers[name] - else: - return object.__getattribute__(self, name) + if '_parameters' in self.__dict__: + _parameters = self.__dict__['_parameters'] + if name in self._parameters: + return self._parameters[name] + if '_sub_layers' in self.__dict__: + _sub_layers = self.__dict__['_sub_layers'] + if name in self._sub_layers: + return self._sub_layers[name] + if '_buffers' in self.__dict__: + _buffers = self.__dict__['_buffers'] + if name in _buffers: + return _buffers[name] + return object.__getattribute__(self, name) def __setattr__(self, name, value): def _remove_if_exist(*dicts): diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 6a59e33285c..7b0a3453b13 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -18,6 +18,7 @@ import numpy as np import paddle from .. import framework from .. import core +from .. import unique_name from ..framework import Variable, Parameter, ParamBase from .base import switch_to_static_graph from .math_op_patch import monkey_patch_math_varbase @@ -263,6 +264,37 @@ def monkey_patch_varbase(): from paddle.tensor.to_string import to_string return to_string(self) + def __deepcopy__(self, memo): + """ + Deep copy Tensor, it will always performs Tensor copy. + + Examples: + .. code-block:: python + + import paddle + import copy + x = paddle.to_tensor(2.) + y = copy.deepcopy(x) + + print(x) + # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, + # [2.]) + + print(y) + # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=True, + # [2.]) + + """ + if not self.is_leaf: + raise RuntimeError( + "Only Leaf Tensor support the deepcopy at the moment, non-Leaf Tensors contains graph information that does't support deepcopy" + ) + new_varbase = core.VarBase() + new_varbase.name = self.name + unique_name.generate("_deepcopy") + memo[id(self)] = new_varbase + new_varbase.copy_(self, True) + return new_varbase + @property def block(self): return framework.default_main_program().global_block() @@ -283,7 +315,8 @@ def monkey_patch_varbase(): ("block", block), ("backward", backward), ("clear_grad", clear_grad), ("inplace_version", inplace_version), ("grad", grad), ("gradient", gradient), ("__str__", __str__), ("__repr__", __str__), - ("__module__", "paddle"), ("__name__", "Tensor")): + ("__deepcopy__", __deepcopy__), ("__module__", "paddle"), + ("__name__", "Tensor")): setattr(core.VarBase, method_name, method) # patch math methods for varbase diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 7be4c0b28c1..6f1a5e61777 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -23,6 +23,7 @@ import os import re import traceback import six +import copy import numpy as np import subprocess @@ -5274,6 +5275,36 @@ class ParamBase(core.VarBase): return "Parameter containing:\n{tensor}".format( tensor=super(ParamBase, self).__str__()) + def __deepcopy__(self, memo): + """ + Deep copy parameter, it will always performs Tensor copy. + + Examples: + .. code-block:: python + + import paddle + import copy + linear = paddle.nn.Linear(1, 3) + linear_copy = copy.deepcopy(linear) + + print(linear.weight) + # Parameter containing: + # Tensor(shape=[1, 3], dtype=float32, place=CPUPlace, stop_gradient=False, + # [[-0.30929261, -0.90929240, -1.07851017]]) + + print(linear_copy.weight) + # Parameter containing: + # Tensor(shape=[1, 3], dtype=float32, place=CPUPlace, stop_gradient=False, + # [[-0.30929261, -0.90929240, -1.07851017]]) + + """ + state = copy.deepcopy(self.__dict__, memo) + state["name"] = self.name + unique_name.generate("_deepcopy") + new_param = ParamBase(self.shape, self.dtype, **state) + memo[id(self)] = new_param + new_param.copy_(self, True) + return new_param + __repr__ = __str__ diff --git a/python/paddle/fluid/tests/unittests/test_imperative_basic.py b/python/paddle/fluid/tests/unittests/test_imperative_basic.py index e33e7247d02..cb48013902a 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_basic.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_basic.py @@ -287,7 +287,6 @@ class TestImperative(unittest.TestCase): with paddle.no_grad(): self.assertTrue(l1.weight.stop_gradient is False) tmp = l1.weight * 2 - print(tmp) self.assertTrue(tmp.stop_gradient) x = fluid.dygraph.to_variable(data) y = l0(x) + tmp @@ -485,15 +484,15 @@ class TestImperative(unittest.TestCase): for i in range(10): y = paddle.pow(x, 4.0) y.backward() - print(x.grad) self.assertEqual(x.grad, (i + 1) * 500) x.clear_gradient() self.assertEqual(x.grad, 0.) - for i in range(5): + for i in range(10): y = paddle.pow(x, 4.0) y.backward() - print(x.grad) self.assertEqual(x.grad, (i + 1) * 500) + x.clear_grad() + self.assertEqual(x.grad, 0.) def test_simple_net(sort_sum_gradient): fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient}) @@ -504,9 +503,18 @@ class TestImperative(unittest.TestCase): def fun(x, y, z): loss1 = x * x * y loss2 = x * z + loss1.backward(retain_graph=True) + loss2.backward(retain_graph=True) + self.assertTrue(np.array_equal(x.grad, [23.])) + self.assertTrue(np.array_equal(y.grad, [25.])) + self.assertTrue(np.array_equal(z.grad, [5.])) + x.clear_grad() + y.clear_grad() + z.clear_grad() + dx = paddle.grad([loss1], x, create_graph=True)[0] - # loss = x*x*y + x*z + 2*x*y loss = loss1 + loss2 + dx + # loss = x*x*y + x*z + 2*x*y return loss loss = fun(x, y, z) @@ -539,12 +547,12 @@ class TestImperative(unittest.TestCase): # generate the gradient of each step mlp2 = MLP(input_size=input_size) - expected_weight1_grad = np.zeros(mlp2._linear1.weight.shape) - expected_bias1_grad = np.zeros(mlp2._linear1.bias.shape) - expected_weight2_grad = np.zeros(mlp2._linear2.weight.shape) - expected_bias2_grad = np.zeros(mlp2._linear2.bias.shape) + expected_weight1_grad = 0. + expected_bias1_grad = 0. + expected_weight2_grad = 0. + expected_bias2_grad = 0. - for batch_id in range(24): + for batch_id in range(100): x = paddle.uniform([10, input_size]) detach_x = x.detach() clear_loss = mlp2(detach_x) @@ -571,12 +579,12 @@ class TestImperative(unittest.TestCase): mlp2.clear_gradients() self.assertTrue(np.array_equal(clear_loss.grad, [1])) - if ((batch_id + 1) % 8) == 0: + if ((batch_id + 1) % 10) == 0: mlp1.clear_gradients() - expected_weight1_grad = np.zeros(mlp2._linear1.weight.shape) - expected_bias1_grad = np.zeros(mlp2._linear1.bias.shape) - expected_weight2_grad = np.zeros(mlp2._linear2.weight.shape) - expected_bias2_grad = np.zeros(mlp2._linear2.bias.shape) + expected_weight1_grad = 0. + expected_bias1_grad = 0. + expected_weight2_grad = 0. + expected_bias2_grad = 0. with fluid.dygraph.guard(): test_single_api(False) diff --git a/python/paddle/fluid/tests/unittests/test_parameter.py b/python/paddle/fluid/tests/unittests/test_parameter.py index 05c19776a37..46e211f4729 100644 --- a/python/paddle/fluid/tests/unittests/test_parameter.py +++ b/python/paddle/fluid/tests/unittests/test_parameter.py @@ -15,6 +15,9 @@ from __future__ import print_function import unittest +import copy +import paddle +from paddle.fluid.dygraph import guard from paddle.fluid.framework import default_main_program import paddle.fluid.core as core from paddle.fluid.executor import Executor @@ -26,7 +29,7 @@ main_program = default_main_program() class ParameterChecks(unittest.TestCase): - def check_param(self): + def check_parameter(self): shape = [784, 100] val = 1.0625 b = main_program.global_block() @@ -46,6 +49,28 @@ class ParameterChecks(unittest.TestCase): p = io.get_parameter_value_by_name('fc.w', exe, main_program) self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val)) + def check_parambase(self): + with guard(): + linear = paddle.nn.Linear(10, 10) + param = linear.weight + + memo = {} + param_copy = copy.deepcopy(param, memo) + self.assertEqual(param_copy.shape, param.shape) + self.assertEqual(param_copy.type, param.type) + self.assertEqual(param_copy.dtype, param.dtype) + self.assertEqual(str(param_copy.place), str(param.place)) + self.assertTrue(np.array_equal(param_copy.numpy(), param.numpy())) + self.assertEqual(param_copy.optimize_attr, param.optimize_attr) + self.assertEqual(param_copy.regularizer, param.regularizer) + self.assertEqual(param_copy.do_model_average, + param.do_model_average) + self.assertEqual(param_copy.need_clip, param.need_clip) + self.assertEqual(param_copy.is_distributed, param.is_distributed) + + pram_copy2 = copy.deepcopy(param, memo) + self.assertEqual(id(param_copy), id(pram_copy2)) + def check_exceptions(self): b = main_program.global_block() with self.assertRaises(ValueError): @@ -63,8 +88,11 @@ class ParameterChecks(unittest.TestCase): class TestParameter(ParameterChecks): - def test_param(self): - self.check_param() + def _test_parameter(self): + self.check_parameter() + + def test_parambase(self): + self.check_parambase() def test_exceptions(self): self.check_exceptions() diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 6d74505bc1f..06009e4ba8b 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import numpy as np import six +import copy import paddle import paddle.fluid as fluid @@ -264,6 +265,68 @@ class TestVarBase(unittest.TestCase): var.stop_gradient = False self.assertEqual(var.stop_gradient, False) + def test_deep_copy(self): + with fluid.dygraph.guard(): + empty_var = core.VarBase() + empty_var_copy = copy.deepcopy(empty_var) + self.assertEqual(empty_var.stop_gradient, + empty_var_copy.stop_gradient) + self.assertEqual(empty_var.persistable, empty_var_copy.persistable) + self.assertEqual(empty_var.type, empty_var_copy.type) + self.assertEqual(empty_var.dtype, empty_var_copy.dtype) + + x = paddle.to_tensor([2.], stop_gradient=False) + y = paddle.to_tensor([3.], stop_gradient=False) + z = x * y + memo = {} + x_copy = copy.deepcopy(x, memo) + y_copy = copy.deepcopy(y, memo) + + self.assertEqual(x_copy.stop_gradient, y_copy.stop_gradient) + self.assertEqual(x_copy.persistable, y_copy.persistable) + self.assertEqual(x_copy.type, y_copy.type) + self.assertEqual(x_copy.dtype, y_copy.dtype) + self.assertTrue(np.array_equal(x.numpy(), x_copy.numpy())) + self.assertTrue(np.array_equal(y.numpy(), y_copy.numpy())) + + self.assertNotEqual(id(x), id(x_copy)) + x_copy[:] = 5. + self.assertTrue(np.array_equal(x_copy.numpy(), [5.])) + self.assertTrue(np.array_equal(x.numpy(), [2.])) + + with self.assertRaises(RuntimeError): + copy.deepcopy(z) + + x_copy2 = copy.deepcopy(x, memo) + y_copy2 = copy.deepcopy(y, memo) + self.assertEqual(id(x_copy), id(x_copy2)) + self.assertEqual(id(y_copy), id(y_copy2)) + + # test copy selected rows + x = core.VarBase(core.VarDesc.VarType.FP32, [3, 100], + "selected_rows", + core.VarDesc.VarType.SELECTED_ROWS, True) + selected_rows = x.value().get_selected_rows() + selected_rows.get_tensor().set( + np.random.rand(3, 100), core.CPUPlace()) + selected_rows.set_height(10) + selected_rows.set_rows([3, 5, 7]) + x_copy = copy.deepcopy(x) + + self.assertEqual(x_copy.stop_gradient, x.stop_gradient) + self.assertEqual(x_copy.persistable, x.persistable) + self.assertEqual(x_copy.type, x.type) + self.assertEqual(x_copy.dtype, x.dtype) + + copy_selected_rows = x_copy.value().get_selected_rows() + self.assertEqual(copy_selected_rows.height(), + selected_rows.height()) + self.assertEqual(copy_selected_rows.rows(), selected_rows.rows()) + self.assertTrue( + np.array_equal( + np.array(copy_selected_rows.get_tensor()), + np.array(selected_rows.get_tensor()))) + # test some patched methods def test_set_value(self): with fluid.dygraph.guard(): -- GitLab