Unverified commit 63f5c2d4, authored by Weilong Wu, committed by GitHub

[Bug fixes] Add default arg to enhance varbase ClearGradient func (#36837)

* Add default arg to enhance varbase ClearGradient func

* Removed default arg, use a Flag to enhance varbase ClearGradient func

* Renamed Flags to FLAGS_real_release

* Use default arg to enhance varbase ClearGradient func and expose two func to set/get gradient isEmpty

* Removed DECLARE_bool statement

* Polished Code
Parent 7a0cc0a9
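In user-facing terms, clear_gradient() keeps its old behaviour (the gradient tensor stays allocated and is zero-filled), while the new clear_gradient(set_to_zero=False) releases the gradient's storage for real. A minimal dygraph sketch of the two modes, mirroring the test added at the bottom of this diff:

```python
import paddle

linear = paddle.nn.Linear(2, 3)

# Default: the gradient buffer is kept and filled with zeros.
linear(paddle.uniform([10, 2])).backward()
linear.weight.clear_gradient()  # same as clear_gradient(set_to_zero=True)

# New: the gradient buffer is actually released.
linear(paddle.uniform([10, 2])).backward()
linear.weight.clear_gradient(False)
```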
@@ -28,7 +28,6 @@
#endif
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace imperative {
@@ -186,7 +185,7 @@ size_t VarBase::GradOpNum() const {
  return grad_node_ ? grad_node_->size() : 0;
}

-void VarBase::ClearGradient() {
+void VarBase::ClearGradient(bool set_to_zero) {
  VLOG(4) << "ClearGradient " << Name();
  if (grad_var_) {
    if (grad_var_->Var().IsType<framework::SelectedRows>()) {
@@ -204,9 +203,13 @@ void VarBase::ClearGradient() {
      auto* grad_t =
          grad_var_->MutableVar()->GetMutable<framework::LoDTensor>();
      if (grad_t->IsInitialized()) {
-        auto* dev_ctx =
-            platform::DeviceContextPool::Instance().Get(grad_t->place());
-        operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+        if (set_to_zero) {
+          auto* dev_ctx =
+              platform::DeviceContextPool::Instance().Get(grad_t->place());
+          operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+        } else {
+          grad_t->clear();
+        }
#ifdef PADDLE_WITH_MKLDNN
        if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
#endif
@@ -219,6 +222,28 @@ void VarBase::ClearGradient() {
  }
}

+void VarBase::_GradientSetEmpty(bool is_empty) {
+  VLOG(4) << "Set gradient " << Name() << " is_empty:" << is_empty;
+  if (grad_var_) {
+    auto share_var = grad_var_->SharedVar();
+    if (share_var) {
+      share_var->SetIsEmpty(is_empty);
+    }
+  }
+}
+
+bool VarBase::_IsGradientSetEmpty() {
+  bool res = true;
+  if (grad_var_) {
+    auto share_var = grad_var_->SharedVar();
+    if (share_var) {
+      res = share_var->is_empty_;
+      VLOG(4) << "Check gradient " << Name() << " is empty:" << res;
+    }
+  }
+  return res;
+}

std::shared_ptr<VarBase> VarBase::NewVarBase(const platform::Place& dst_place,
                                             const bool blocking) const {
  PADDLE_ENFORCE_EQ(
......
@@ -222,7 +222,10 @@ class VarBase {
  const platform::Place Place() const { return var_->Place(); }

-  void ClearGradient();
+  void ClearGradient(bool set_to_zero = true);
+  void _GradientSetEmpty(bool is_empty = true);
+  bool _IsGradientSetEmpty();

  std::shared_ptr<VarBase> NewVarBase(const platform::Place& dst_place,
                                      const bool blocking) const;
......
@@ -1480,7 +1480,8 @@ void BindImperative(py::module *m_ptr) {
        # one of the variables needed for gradient computation has been modified by an inplace operation.
        )DOC")
-      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(
+      .def("clear_gradient", &imperative::VarBase::ClearGradient,
+           py::arg("set_to_zero") = true, R"DOC(
        Only for Tensors that have a gradient. Normally we use this for Parameters, since other temporary Tensors don't have a gradient.
@@ -1500,6 +1501,9 @@ void BindImperative(py::module *m_ptr) {
                linear.weight.clear_gradient()
                print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
        )DOC")
+      .def("_gradient_set_empty", &imperative::VarBase::_GradientSetEmpty,
+           py::arg("set_is_empty") = true)
+      .def("_is_gradient_set_empty", &imperative::VarBase::_IsGradientSetEmpty)
      .def("clone",
           [](std::shared_ptr<imperative::VarBase> &self) {
             const auto &tensor = self->Var().Get<framework::LoDTensor>();
......
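The two underscore-prefixed bindings are low-level bookkeeping helpers rather than part of the documented API. A rough sketch of how they pair with clear_gradient; the is_empty behaviour here is assumed from the test case below:

```python
import paddle

linear = paddle.nn.Linear(2, 3)
linear(paddle.uniform([10, 2])).backward()

# Really release the gradient storage.
linear.weight.clear_gradient(False)

# The test below expects the gradient to be flagged as empty after the
# real release; flip the flag back before checking or reusing the slot.
print(linear.weight._is_gradient_set_empty())  # expected: True at this point
linear.weight._gradient_set_empty(False)
print(linear.weight._is_gradient_set_empty())  # now False
```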
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid as fluid
import paddle
from paddle.fluid.wrapped_decorator import wrap_decorator
import unittest
from unittest import TestCase
import numpy as np


def _dygraph_guard_(func):
    def __impl__(*args, **kwargs):
        if fluid.in_dygraph_mode():
            return func(*args, **kwargs)
        else:
            with fluid.dygraph.guard():
                return func(*args, **kwargs)

    return __impl__


dygraph_guard = wrap_decorator(_dygraph_guard_)


class TestDygraphClearGradient(TestCase):
    def setUp(self):
        self.input_shape = [10, 2]

    @dygraph_guard
    def test_tensor_method_clear_gradient_case1(self):
        input = paddle.uniform(self.input_shape)
        linear = paddle.nn.Linear(2, 3)
        out = linear(input)
        out.backward()
        linear.weight.clear_gradient()

        # actual result
        gradient_actual = linear.weight.grad
        # expected result
        gradient_expected = np.zeros([2, 3]).astype('float64')
        self.assertTrue(np.allclose(gradient_actual.numpy(), gradient_expected))

    @dygraph_guard
    def test_tensor_method_clear_gradient_case2(self):
        input = paddle.uniform(self.input_shape)
        linear = paddle.nn.Linear(2, 3)
        out = linear(input)
        out.backward()
        # the default arg set_to_zero is True,
        # so passing False really releases the gradient
        linear.weight.clear_gradient(False)

        # before ._gradient_set_empty(False),
        # ._is_gradient_set_empty() should return True
        self.assertTrue(linear.weight._is_gradient_set_empty())
        # reset, because ClearGradient will call SetIsEmpty(True),
        # which is not what we want to check here
        linear.weight._gradient_set_empty(False)
        # after ._gradient_set_empty(False),
        # ._is_gradient_set_empty() should return False
        self.assertFalse(linear.weight._is_gradient_set_empty())

        # actual result
        gradient_actual = linear.weight.grad
        # expected result: the gradient storage has been released
        self.assertTrue(np.empty(gradient_actual))

if __name__ == '__main__':
    unittest.main()