Unverified commit dcb91fd7 authored by Zhanlue Yang, committed by GitHub

Added interface reset_grad_inplace_version (#37573)

reset_inplace_version removes all inplace-related records from a VarBase/VariableWrapper. Its essential purpose is to let you use an inplace operation as if you had used its non-inplace version, which of course can cause unexpected consequences if not used with care.

This is essentially a hack interface to satisfy one specific request.
Parent 80b7c96c
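For context, here is a minimal sketch of the intended usage pattern, distilled from the test added in this commit (the APIs shown — _C_ops.scale_, _register_backward_hook, _reset_grad_inplace_version — all appear in that test; this is not an official recipe):

import paddle
from paddle import _C_ops

# A leaf tensor whose gradient will be modified inplace from a backward hook.
w = paddle.to_tensor([[1.0]], 'float32', stop_gradient=False)


@paddle.no_grad()
def hook(*_):
    # Scale the accumulated gradient inplace ...
    _C_ops.scale_(w.grad, 'scale', 0.5)
    # ... then drop the inplace records so later backward passes behave as if
    # the gradient had never been written inplace.
    w._reset_grad_inplace_version()


w._register_backward_hook(hook)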
...
@@ -365,9 +365,10 @@ class TracedGradOp {
       }
     }
-    VariableWrapper new_var_wrapper = *var_wrapper.get();
-    new_var_wrapper.ResetInplaceVersion();
-    return std::make_shared<VariableWrapper>(new_var_wrapper);
+    auto new_var_wrapper =
+        std::make_shared<VariableWrapper>(*var_wrapper.get());
+    new_var_wrapper->ResetInplaceVersion();
+    return new_var_wrapper;
   }

 private:
...
...
@@ -1548,6 +1548,25 @@ void BindImperative(py::module *m_ptr) {
            [](imperative::VarBase &self, framework::proto::VarType::Type type) {
              self.MutableGradVarBase()->SetType(type);
            })
+      .def("_reset_grad_inplace_version",
+           [](imperative::VarBase &self) {
+             /*
+              *** This interface is a complete hack ***
+              reset_grad_inplace_version removes all inplace-related records
+              from the grad VarBase/VariableWrapper,
+              the essential purpose of which is to let you use inplace
+              operations as if using their non-inplace versions,
+              which of course will cause unexpected consequences if not used
+              with care.
+              Make sure you fully understand what you're doing before making
+              use of this interface, and prepare for the worst.
+              */
+             if (self.HasGradVar()) {
+               auto grad_var = self.GradVarBase();
+               auto var_wrapper = grad_var->SharedVar();
+               if (var_wrapper) var_wrapper->ResetInplaceVersion();
+             }
+           })
       .def("_grad_ivar",
            [](const imperative::VarBase &self) {
              auto &grad_var = self.GradVarBase();
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle import _C_ops
from paddle.fluid import framework
import unittest
paddle.set_device('cpu')


def clear_grad(w, a):
    # Returns a backward hook that halves w.grad inplace and then drops the
    # inplace records on the grad var via the new interface.
    @paddle.no_grad()
    def warp(*_):
        assert w.grad is not None
        _C_ops.scale_(w.grad, 'scale', 0.5)
        w._reset_grad_inplace_version()

    return warp


class TestInplaceAndClearGradient(unittest.TestCase):
    def test(self):
        input_data = np.ones([1, 1])
        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)

        _clear_grad = clear_grad(w, a="1")
        w._register_backward_hook(_clear_grad)
        for i in range(2):
            print(" Step: ", i)
            out0 = _C_ops.scale(w, 'scale', 0.1)
            out = _C_ops.matmul_v2(out0, w, 'trans_x', False, 'trans_y', False)
            out.backward()
        assert w.grad[0] == 0.15


if __name__ == '__main__':
    unittest.main()