From cf5860215c19caaea09d298514085bc24aad0439 Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Tue, 7 Dec 2021 17:25:03 +0800
Subject: [PATCH] Bug fix for reset grad inplace version (#37811)

* Debug

* Fixed issue with reset_grad_inplace_version when used with clear_gradient & cross-batch accumulation

* Rearranged interfaces

* Fixed ci issues
---
 paddle/fluid/framework/tensor.h               |  1 +
 paddle/fluid/framework/variable.h             |  7 +++
 paddle/fluid/imperative/variable_wrapper.h    | 22 +++++--
 paddle/fluid/pybind/imperative.cc             | 10 ++-
 .../meta_parallel/sharding/sharding_stage2.py |  4 +-
 .../test_reset_grad_inplace_version.py        | 62 ++++++++++++++++++-
 6 files changed, 93 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h
index 494a02878f1..4b1ae041fc4 100644
--- a/paddle/fluid/framework/tensor.h
+++ b/paddle/fluid/framework/tensor.h
@@ -81,6 +81,7 @@ class TensorInplaceVersion {
   bool IsUnique() const { return inplace_version_ == 0; }
   void Bump() { ++inplace_version_; }
   uint32_t CurrentVersion() const { return inplace_version_; }
+  void SetInplaceVersionToZero() { inplace_version_ = 0; }
 
  private:
   uint32_t inplace_version_;
diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index 792a2accd41..f8ad990a668 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -75,6 +75,7 @@ class Variable {
   framework::TensorInplaceVersion* InplaceVersionCounter();
 
  public:
+  void SetInplaceVersionToZero();
   uint32_t CurrentInplaceVersion();
   void BumpInplaceVersion();
 
@@ -134,6 +135,12 @@ inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
   return version_counter_ptr;
 }
 
+inline void Variable::SetInplaceVersionToZero() {
+  auto inplace_version_counter = this->InplaceVersionCounter();
+  if (inplace_version_counter)
+    inplace_version_counter->SetInplaceVersionToZero();
+}
+
 inline uint32_t Variable::CurrentInplaceVersion() {
   auto version_counter_ptr = InplaceVersionCounter();
   if (version_counter_ptr) {
diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h
index 9fbbe7d06f8..c257191a546 100644
--- a/paddle/fluid/imperative/variable_wrapper.h
+++ b/paddle/fluid/imperative/variable_wrapper.h
@@ -209,13 +209,23 @@ class VariableWrapper {
 
   uint32_t InplaceVersionSnapshot() const { return inplace_version_snapshot_; }
 
-  void ResetInplaceVersion() {
-    auto new_version = var_.CurrentInplaceVersion();
+  void ResetInplaceVersion(bool set_to_zero = false) {
+    if (!set_to_zero) {
+      auto new_version = var_.CurrentInplaceVersion();
 
-    VLOG(6) << "The wrapper version of VariableWrapper '" << name_
-            << "' will be updated from " << inplace_version_snapshot_ << "to "
-            << new_version;
-    inplace_version_snapshot_ = new_version;
+      VLOG(6) << "The wrapper version of VariableWrapper '" << name_
+              << "' will be updated from " << inplace_version_snapshot_ << "to "
+              << new_version;
+      inplace_version_snapshot_ = new_version;
+
+    } else {
+      // Reset Snapshot & InplaceVersion to zero
+      inplace_version_snapshot_ = 0;
+      auto var = this->MutableVar();
+      if (var) {
+        var->SetInplaceVersionToZero();
+      }
+    }
   }
 
   bool hasCacheKey(const paddle::framework::OpKernelType& key) {
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 2c850f0ca84..dc97d98e8c4 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -1538,7 +1538,7 @@ void BindImperative(py::module *m_ptr) {
             self.MutableGradVarBase()->SetType(type);
           })
      .def("_reset_grad_inplace_version",
-          [](imperative::VarBase &self) {
+          [](imperative::VarBase &self, bool set_to_zero) {
            /*
            *** This interface is a complete hack ***
            reset_grad_inplace_version removes all inplace related records to
@@ -1545,27 +1545,33 @@ void BindImperative(py::module *m_ptr) {
            Grad VarBase/VariableWrapper,
            the essential purpose of which is to let you use inplace operations
            as if using its non-inplaced version,
            which of course will cause unexpected consequences if not used with
            care.
            Make sure you fully understand what you're doing before making use
            of this interface, and prepare for the worst.
            */
+           py::gil_scoped_release release;
+
            if (self.HasGradVar()) {
              auto grad_var = self.GradVarBase();
              auto var_wrapper = grad_var->SharedVar();
-             if (var_wrapper) var_wrapper->ResetInplaceVersion();
+             if (var_wrapper) {
+               var_wrapper->ResetInplaceVersion(set_to_zero);
+             }
            }
          })
      .def("_grad_ivar",
          [](const imperative::VarBase &self) {
            auto &grad_var = self.GradVarBase();
+
            if (grad_var && grad_var->Var().IsInitialized()) {
              auto *tensor =
                  grad_var->MutableVar()->IsType<framework::LoDTensor>()
                      ? grad_var->MutableVar()->GetMutable<framework::LoDTensor>()
                      : grad_var->MutableVar()
                            ->GetMutable<framework::SelectedRows>()
                            ->mutable_value();
+
              if (tensor->IsInitialized()) {
                return grad_var;
              }
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
index 329dc9eaa4e..37b85751149 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
@@ -177,7 +177,7 @@ class ShardingStage2(nn.Layer):
         for param in self._trainable_params:
             if param.name in self._param_grads and param.grad is not None:
                 param.grad.scale_(scale=self._world_size_scaling)
-                param._reset_grad_inplace_version()
+                param._reset_grad_inplace_version(True)
 
     def _init_internal_storage(self, needs_fresh):
         """
@@ -283,7 +283,7 @@ class ShardingStage2(nn.Layer):
                 self._grad_reduced[index] = False
                 if not self._accumulate_grads:
                     param.grad.scale_(scale=self._world_size_scaling)
-                    param._reset_grad_inplace_version()
+                    param._reset_grad_inplace_version(True)
 
                 # Clear the gradient that does not belong to the current rank through the callback function
                 def cleanup():
diff --git a/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py b/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
index d9634f4997d..fee5bb8f47f 100644
--- a/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
+++ b/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
@@ -20,12 +20,13 @@ import unittest
 paddle.set_device('cpu')
 
 
-def clear_grad(w, a):
+# Test 1
+def clear_grad_test_0(w, a):
     @paddle.no_grad()
     def warp(*_):
         assert w.grad is not None
         _C_ops.scale_(w.grad, 'scale', 0.5)
-        w._reset_grad_inplace_version()
+        w._reset_grad_inplace_version(True)
     return warp
 
 
@@ -35,7 +36,7 @@ class TestInplaceAndClearGradient(unittest.TestCase):
         input_data = np.ones([1, 1])
         w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
 
-        _clear_grad = clear_grad(w, a="1")
+        _clear_grad = clear_grad_test_0(w, a="1")
         w._register_backward_hook(_clear_grad)
         for i in range(2):
             print(" Step: ", i)
@@ -45,5 +46,60 @@
             assert w.grad[0] == 0.15
 
 
+# Test 2
+class Counter:
+    def __init__(self):
+        self.num_calls = 0
+        self.step = 0
+
+
+def clear_grad_test_1(w, c):
+    @paddle.no_grad()
+    def warp(*_):
+        assert w.grad is not None
+        if c.step == 1:
+            w.grad.scale_(scale=0.5)
+            w._reset_grad_inplace_version(True)
+
+        c.num_calls += 1
+
+    return warp
+
+
+class TestInplaceClearGradAccumulation(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+        c = Counter()
+
+        _clear_grad = clear_grad_test_1(w, c)
+        w._register_backward_hook(_clear_grad)
+
+        for c.step in range(5):
+            out0 = _C_ops.scale(w, 'scale', 0.1)
+            out = _C_ops.matmul_v2(out0, w, 'trans_x', False, 'trans_y', False)
+
+            out.backward()
+
+            if c.step == 1:
+                w.clear_gradient(False)
+
+        assert c.num_calls == 1
+        c.num_calls = 0
+
+
+class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+        out = _C_ops.scale(w, 'scale', 0.1)
+        out.backward()
+
+        w.grad.scale_(scale=0.5)
+        w._reset_grad_inplace_version(False)
+
+        assert w.grad._inplace_version() == 1
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
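
Usage note: after this patch, _reset_grad_inplace_version(set_to_zero) has two behaviours. With set_to_zero=False the gradient's VariableWrapper only re-snapshots the current inplace version (the pre-existing behaviour, exercised by TestInplaceClearGradAccumulationAlt); with set_to_zero=True both the snapshot and the tensor's TensorInplaceVersion counter are reset to zero, which is what cross-batch accumulation combined with clear_gradient(False) needs. The minimal sketch below distills that second pattern from TestInplaceClearGradAccumulation above; it is illustrative only, relies on the same private helpers (_register_backward_hook, _reset_grad_inplace_version), and assumes "from paddle import _C_ops" as in the test file, whose import block is not part of this patch.

import numpy as np
import paddle
from paddle import _C_ops  # assumed import, mirroring the unit test

paddle.set_device('cpu')
w = paddle.to_tensor(np.ones([1, 1]), 'float32', stop_gradient=False)


@paddle.no_grad()
def rescale_and_reset(*_):
    # Backward hook: rescale the freshly accumulated gradient in place, then
    # drop all inplace records (snapshot and counter go back to zero) so that
    # clear_gradient(False) plus further accumulation does not trigger a
    # spurious inplace-version mismatch.
    w.grad.scale_(scale=0.5)
    w._reset_grad_inplace_version(True)


w._register_backward_hook(rescale_and_reset)

for step in range(3):
    out = _C_ops.matmul_v2(
        _C_ops.scale(w, 'scale', 0.1), w, 'trans_x', False, 'trans_y', False)
    out.backward()
    if step == 0:
        # Zero the gradient but keep its allocation; accumulation then
        # continues across the remaining iterations.
        w.clear_gradient(False)

print(w.grad)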