未验证 提交 cf586021 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Buf fix for reset grad inplace version (#37811)

* Debug

* Fixed issue with reset_grad_inplace_version when used with clear_gradient & cross-batch accumulation

* Rearranged interfaces

* Fixed ci issues
上级 723cbe9f
......@@ -81,6 +81,7 @@ class TensorInplaceVersion {
bool IsUnique() const { return inplace_version_ == 0; }
void Bump() { ++inplace_version_; }
uint32_t CurrentVersion() const { return inplace_version_; }
void SetInplaceVersionToZero() { inplace_version_ = 0; }
private:
uint32_t inplace_version_;
......
......@@ -75,6 +75,7 @@ class Variable {
framework::TensorInplaceVersion* InplaceVersionCounter();
public:
void SetInplaceVersionToZero();
uint32_t CurrentInplaceVersion();
void BumpInplaceVersion();
......@@ -134,6 +135,12 @@ inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
return version_counter_ptr;
}
inline void Variable::SetInplaceVersionToZero() {
auto inplace_version_counter = this->InplaceVersionCounter();
if (inplace_version_counter)
inplace_version_counter->SetInplaceVersionToZero();
}
inline uint32_t Variable::CurrentInplaceVersion() {
auto version_counter_ptr = InplaceVersionCounter();
if (version_counter_ptr) {
......
......@@ -209,13 +209,23 @@ class VariableWrapper {
uint32_t InplaceVersionSnapshot() const { return inplace_version_snapshot_; }
void ResetInplaceVersion() {
void ResetInplaceVersion(bool set_to_zero = false) {
if (!set_to_zero) {
auto new_version = var_.CurrentInplaceVersion();
VLOG(6) << "The wrapper version of VariableWrapper '" << name_
<< "' will be updated from " << inplace_version_snapshot_ << "to "
<< new_version;
inplace_version_snapshot_ = new_version;
} else {
// Reset Snapshot & InplaceVersion to zero
inplace_version_snapshot_ = 0;
auto var = this->MutableVar();
if (var) {
var->SetInplaceVersionToZero();
}
}
}
bool hasCacheKey(const paddle::framework::OpKernelType& key) {
......
......@@ -1538,7 +1538,7 @@ void BindImperative(py::module *m_ptr) {
self.MutableGradVarBase()->SetType(type);
})
.def("_reset_grad_inplace_version",
[](imperative::VarBase &self) {
[](imperative::VarBase &self, bool set_to_zero) {
/*
*** This interfaceis a complete hack ***
reset_grad_inplace_version removes all inplace related records to
......@@ -1550,15 +1550,20 @@ void BindImperative(py::module *m_ptr) {
Make sure you fully understand what you're doing before make use of
this interface, and prepare for the worst.
*/
py::gil_scoped_release release;
if (self.HasGradVar()) {
auto grad_var = self.GradVarBase();
auto var_wrapper = grad_var->SharedVar();
if (var_wrapper) var_wrapper->ResetInplaceVersion();
if (var_wrapper) {
var_wrapper->ResetInplaceVersion(set_to_zero);
}
}
})
.def("_grad_ivar",
[](const imperative::VarBase &self) {
auto &grad_var = self.GradVarBase();
if (grad_var && grad_var->Var().IsInitialized()) {
auto *tensor =
grad_var->MutableVar()->IsType<framework::LoDTensor>()
......@@ -1567,6 +1572,7 @@ void BindImperative(py::module *m_ptr) {
: grad_var->MutableVar()
->GetMutable<framework::SelectedRows>()
->mutable_value();
if (tensor->IsInitialized()) {
return grad_var;
}
......
......@@ -177,7 +177,7 @@ class ShardingStage2(nn.Layer):
for param in self._trainable_params:
if param.name in self._param_grads and param.grad is not None:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version()
param._reset_grad_inplace_version(True)
def _init_internal_storage(self, needs_fresh):
"""
......@@ -283,7 +283,7 @@ class ShardingStage2(nn.Layer):
self._grad_reduced[index] = False
if not self._accumulate_grads:
param.grad.scale_(scale=self._world_size_scaling)
param._reset_grad_inplace_version()
param._reset_grad_inplace_version(True)
# Clear the gradient that does not belong to the current rank through the callback function
def cleanup():
......
......@@ -20,12 +20,13 @@ import unittest
paddle.set_device('cpu')
def clear_grad(w, a):
# Test 1
def clear_grad_test_0(w, a):
@paddle.no_grad()
def warp(*_):
assert w.grad is not None
_C_ops.scale_(w.grad, 'scale', 0.5)
w._reset_grad_inplace_version()
w._reset_grad_inplace_version(True)
return warp
......@@ -35,7 +36,7 @@ class TestInplaceAndClearGradient(unittest.TestCase):
input_data = np.ones([1, 1])
w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
_clear_grad = clear_grad(w, a="1")
_clear_grad = clear_grad_test_0(w, a="1")
w._register_backward_hook(_clear_grad)
for i in range(2):
print(" Step: ", i)
......@@ -45,5 +46,60 @@ class TestInplaceAndClearGradient(unittest.TestCase):
assert w.grad[0] == 0.15
# Test 2
class Counter:
def __init__(self):
self.num_calls = 0
self.step = 0
def clear_grad_test_1(w, c):
@paddle.no_grad()
def warp(*_):
assert w.grad is not None
if c.step == 1:
w.grad.scale_(scale=0.5)
w._reset_grad_inplace_version(True)
c.num_calls += 1
return warp
class TestInplaceClearGradAccumulation(unittest.TestCase):
def test(self):
input_data = np.ones([1, 1])
w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
c = Counter()
_clear_grad = clear_grad_test_1(w, c)
w._register_backward_hook(_clear_grad)
for c.step in range(5):
out0 = _C_ops.scale(w, 'scale', 0.1)
out = _C_ops.matmul_v2(out0, w, 'trans_x', False, 'trans_y', False)
out.backward()
if c.step == 1:
w.clear_gradient(False)
assert c.num_calls == 1
c.num_calls = 0
class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
def test(self):
input_data = np.ones([1, 1])
w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
out = _C_ops.scale(w, 'scale', 0.1)
out.backward()
w.grad.scale_(scale=0.5)
w._reset_grad_inplace_version(False)
assert w.grad._inplace_version() == 1
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册