From dcb91fd784f25d826651d56288222fd8746cd36b Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Fri, 26 Nov 2021 15:24:49 +0800
Subject: [PATCH] Added interface reset_grad_inplace_version (#37573)

reset_grad_inplace_version removes all inplace-related records from a
VarBase/VariableWrapper. Its essential purpose is to let you use inplace
operations as if you were using their non-inplace versions, which will of
course cause unexpected consequences if not used with care. This is
essentially a hack interface added to satisfy one specific request.
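A minimal usage sketch (illustrative only: it mirrors the unit test added
below and assumes dygraph mode on CPU):

    import paddle
    from paddle import _C_ops

    paddle.set_device('cpu')
    w = paddle.to_tensor([[1.0]], stop_gradient=False)

    (w * 2).sum().backward()             # populates w.grad
    _C_ops.scale_(w.grad, 'scale', 0.5)  # in-place edit bumps the grad's
                                         # inplace version counter
    w._reset_grad_inplace_version()      # wipe the inplace records so later
                                         # backward passes see a "clean" grad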
---
 paddle/fluid/imperative/dygraph_grad_maker.h |  7 +--
 paddle/fluid/pybind/imperative.cc            | 19 +++++++
 .../test_reset_grad_inplace_version.py       | 49 +++++++++++++++++++
 3 files changed, 72 insertions(+), 3 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py

diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h
index f5a4b0fa315..50e5f6d31d0 100644
--- a/paddle/fluid/imperative/dygraph_grad_maker.h
+++ b/paddle/fluid/imperative/dygraph_grad_maker.h
@@ -365,9 +365,10 @@ class TracedGradOp {
       }
     }
 
-    VariableWrapper new_var_wrapper = *var_wrapper.get();
-    new_var_wrapper.ResetInplaceVersion();
-    return std::make_shared<VariableWrapper>(new_var_wrapper);
+    auto new_var_wrapper =
+        std::make_shared<VariableWrapper>(*var_wrapper.get());
+    new_var_wrapper->ResetInplaceVersion();
+    return new_var_wrapper;
   }
 
  private:

diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 85a9507050f..5ff0e58d858 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -1548,6 +1548,25 @@ void BindImperative(py::module *m_ptr) {
           [](imperative::VarBase &self, framework::proto::VarType::Type type) {
             self.MutableGradVarBase()->SetType(type);
           })
+      .def("_reset_grad_inplace_version",
+           [](imperative::VarBase &self) {
+             /*
+             *** This interface is a complete hack ***
+             reset_grad_inplace_version removes all inplace-related records
+             from the Grad VarBase/VariableWrapper,
+             the essential purpose of which is to let you use inplace
+             operations as if using their non-inplace versions,
+             which of course will cause unexpected consequences if not used
+             with care.
+             Make sure you fully understand what you're doing before making
+             use of this interface, and prepare for the worst.
+             */
+             if (self.HasGradVar()) {
+               auto grad_var = self.GradVarBase();
+               auto var_wrapper = grad_var->SharedVar();
+               if (var_wrapper) var_wrapper->ResetInplaceVersion();
+             }
+           })
       .def("_grad_ivar",
            [](const imperative::VarBase &self) {
              auto &grad_var = self.GradVarBase();

diff --git a/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py b/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
new file mode 100644
index 00000000000..d9634f4997d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
@@ -0,0 +1,49 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle import _C_ops
+from paddle.fluid import framework
+import unittest
+paddle.set_device('cpu')
+
+
+def clear_grad(w, a):
+    @paddle.no_grad()
+    def warp(*_):
+        assert w.grad is not None
+        _C_ops.scale_(w.grad, 'scale', 0.5)
+        w._reset_grad_inplace_version()
+
+    return warp
+
+
+class TestInplaceAndClearGradient(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+
+        _clear_grad = clear_grad(w, a="1")
+        w._register_backward_hook(_clear_grad)
+        for i in range(2):
+            print(" Step: ", i)
+            out0 = _C_ops.scale(w, 'scale', 0.1)
+            out = _C_ops.matmul_v2(out0, w, 'trans_x', False, 'trans_y', False)
+            out.backward()
+        assert w.grad[0] == 0.15
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab