diff --git a/paddle/fluid/framework/tensor.h b/paddle/fluid/framework/tensor.h
index 494a02878f1a2c1fc94a50777d3b4b8676b99e8e..4b1ae041fc4cada165e792a37250a0c1a3de27b0 100644
--- a/paddle/fluid/framework/tensor.h
+++ b/paddle/fluid/framework/tensor.h
@@ -81,6 +81,7 @@ class TensorInplaceVersion {
   bool IsUnique() const { return inplace_version_ == 0; }
   void Bump() { ++inplace_version_; }
   uint32_t CurrentVersion() const { return inplace_version_; }
+  void SetInplaceVersionToZero() { inplace_version_ = 0; }

  private:
   uint32_t inplace_version_;
diff --git a/paddle/fluid/framework/variable.h b/paddle/fluid/framework/variable.h
index 792a2accd41d67e76d56dfdc058e4128018614e7..f8ad990a668ce62348b5cc06b68d2d4ee21a26a1 100644
--- a/paddle/fluid/framework/variable.h
+++ b/paddle/fluid/framework/variable.h
@@ -75,6 +75,7 @@ class Variable {
   framework::TensorInplaceVersion* InplaceVersionCounter();

  public:
+  void SetInplaceVersionToZero();
   uint32_t CurrentInplaceVersion();
   void BumpInplaceVersion();

@@ -134,6 +135,12 @@ inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
   return version_counter_ptr;
 }

+inline void Variable::SetInplaceVersionToZero() {
+  auto inplace_version_counter = this->InplaceVersionCounter();
+  if (inplace_version_counter)
+    inplace_version_counter->SetInplaceVersionToZero();
+}
+
 inline uint32_t Variable::CurrentInplaceVersion() {
   auto version_counter_ptr = InplaceVersionCounter();
   if (version_counter_ptr) {
diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h
index 9fbbe7d06f8ad826b5a9ff3581bb8c1f42ce2bb6..c257191a546e439cedee0d2075549a45a3467423 100644
--- a/paddle/fluid/imperative/variable_wrapper.h
+++ b/paddle/fluid/imperative/variable_wrapper.h
@@ -209,13 +209,23 @@ class VariableWrapper {

   uint32_t InplaceVersionSnapshot() const { return inplace_version_snapshot_; }

-  void ResetInplaceVersion() {
-    auto new_version = var_.CurrentInplaceVersion();
+  void ResetInplaceVersion(bool set_to_zero = false) {
+    if (!set_to_zero) {
+      auto new_version = var_.CurrentInplaceVersion();

-    VLOG(6) << "The wrapper version of VariableWrapper '" << name_
-            << "' will be updated from " << inplace_version_snapshot_ << "to "
-            << new_version;
-    inplace_version_snapshot_ = new_version;
+      VLOG(6) << "The wrapper version of VariableWrapper '" << name_
+              << "' will be updated from " << inplace_version_snapshot_ << "to "
+              << new_version;
+      inplace_version_snapshot_ = new_version;
+
+    } else {
+      // Reset Snapshot & InplaceVersion to zero
+      inplace_version_snapshot_ = 0;
+      auto var = this->MutableVar();
+      if (var) {
+        var->SetInplaceVersionToZero();
+      }
+    }
   }

   bool hasCacheKey(const paddle::framework::OpKernelType& key) {
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 2c850f0ca84d5f4a79f023646e9370e7a382a160..dc97d98e8c47fcadacd082a4278ae35adc04047d 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -1538,7 +1538,7 @@ void BindImperative(py::module *m_ptr) {
              self.MutableGradVarBase()->SetType(type);
            })
       .def("_reset_grad_inplace_version",
-           [](imperative::VarBase &self) {
+           [](imperative::VarBase &self, bool set_to_zero) {
             /*
             *** This interfaceis a complete hack ***
             reset_grad_inplace_version removes all inplace related records to
@@ -1550,15 +1550,20 @@ void BindImperative(py::module *m_ptr) {
             Make sure you fully understand what you're doing before make use of
             this interface, and prepare for the worst.
             */
+            py::gil_scoped_release release;
+
             if (self.HasGradVar()) {
               auto grad_var = self.GradVarBase();
               auto var_wrapper = grad_var->SharedVar();
-              if (var_wrapper) var_wrapper->ResetInplaceVersion();
+              if (var_wrapper) {
+                var_wrapper->ResetInplaceVersion(set_to_zero);
+              }
             }
           })
      .def("_grad_ivar",
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
+
             if (grad_var && grad_var->Var().IsInitialized()) {
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
@@ -1567,6 +1572,7 @@ void BindImperative(py::module *m_ptr) {
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();
+
               if (tensor->IsInitialized()) {
                 return grad_var;
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
index 329dc9eaa4e575cccb57b7aacf2214feab2f41f0..37b85751149f71336ba431e2876e48eeaa85496d 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/sharding_stage2.py
@@ -177,7 +177,7 @@ class ShardingStage2(nn.Layer):
         for param in self._trainable_params:
             if param.name in self._param_grads and param.grad is not None:
                 param.grad.scale_(scale=self._world_size_scaling)
-                param._reset_grad_inplace_version()
+                param._reset_grad_inplace_version(True)

     def _init_internal_storage(self, needs_fresh):
         """
@@ -283,7 +283,7 @@ class ShardingStage2(nn.Layer):
                 self._grad_reduced[index] = False
                 if not self._accumulate_grads:
                     param.grad.scale_(scale=self._world_size_scaling)
-                    param._reset_grad_inplace_version()
+                    param._reset_grad_inplace_version(True)

                 # Clear the gradient that does not belong to the current rank through the callback function
                 def cleanup():
diff --git a/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py b/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
index d9634f4997d80eaa6192d3edb8c580683aff192b..fee5bb8f47f260e8f4675c19fd194ec80a9dd4a6 100644
--- a/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
+++ b/python/paddle/fluid/tests/unittests/test_reset_grad_inplace_version.py
@@ -20,12 +20,13 @@ import unittest
 paddle.set_device('cpu')


-def clear_grad(w, a):
+# Test 1
+def clear_grad_test_0(w, a):
     @paddle.no_grad()
     def warp(*_):
         assert w.grad is not None
         _C_ops.scale_(w.grad, 'scale', 0.5)
-        w._reset_grad_inplace_version()
+        w._reset_grad_inplace_version(True)

     return warp

@@ -35,7 +36,7 @@ class TestInplaceAndClearGradient(unittest.TestCase):
         input_data = np.ones([1, 1])
         w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)

-        _clear_grad = clear_grad(w, a="1")
+        _clear_grad = clear_grad_test_0(w, a="1")
         w._register_backward_hook(_clear_grad)
         for i in range(2):
             print(" Step: ", i)
@@ -45,5 +46,60 @@ class TestInplaceAndClearGradient(unittest.TestCase):
         assert w.grad[0] == 0.15


+# Test 2
+class Counter:
+    def __init__(self):
+        self.num_calls = 0
+        self.step = 0
+
+
+def clear_grad_test_1(w, c):
+    @paddle.no_grad()
+    def warp(*_):
+        assert w.grad is not None
+        if c.step == 1:
+            w.grad.scale_(scale=0.5)
+            w._reset_grad_inplace_version(True)
+
+        c.num_calls += 1
+
+    return warp
+
+
+class TestInplaceClearGradAccumulation(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+        c = Counter()
+
+        _clear_grad = clear_grad_test_1(w, c)
+        w._register_backward_hook(_clear_grad)
+
+        for c.step in range(5):
+            out0 = _C_ops.scale(w, 'scale', 0.1)
+            out = _C_ops.matmul_v2(out0, w, 'trans_x', False, 'trans_y', False)
+
+            out.backward()
+
+            if c.step == 1:
+                w.clear_gradient(False)
+
+        assert c.num_calls == 1
+        c.num_calls = 0
+
+
+class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+        out = _C_ops.scale(w, 'scale', 0.1)
+        out.backward()
+
+        w.grad.scale_(scale=0.5)
+        w._reset_grad_inplace_version(False)
+
+        assert w.grad._inplace_version() == 1
+
+
 if __name__ == '__main__':
     unittest.main()
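
For reference, the behaviour introduced by this patch can be exercised from Python roughly as in the sketch below. This is a minimal illustration, not part of the patch: it mirrors the updated test_reset_grad_inplace_version.py, reuses only calls that appear in this diff (_C_ops.scale, Tensor.grad.scale_, _reset_grad_inplace_version, _inplace_version), and assumes a Paddle build that already contains this change. The final assertion states the expected effect of set_to_zero=True (both the wrapper snapshot and the tensor's inplace counter are reset to zero), which is what the sharding stage2 code relies on after scaling gradients in place.

    # Minimal usage sketch (assumes a Paddle build that includes this patch).
    import numpy as np
    import paddle
    from paddle import _C_ops

    paddle.set_device('cpu')

    w = paddle.to_tensor(np.ones([1, 1]), 'float32', stop_gradient=False)
    out = _C_ops.scale(w, 'scale', 0.1)
    out.backward()

    # An in-place op on the gradient bumps its inplace version counter.
    w.grad.scale_(scale=0.5)
    assert w.grad._inplace_version() == 1

    # set_to_zero=False keeps the old behaviour: only the wrapper's snapshot
    # is refreshed; the tensor counter itself is left untouched.
    w._reset_grad_inplace_version(False)
    assert w.grad._inplace_version() == 1

    # set_to_zero=True additionally resets the tensor counter to zero, so
    # later gradient accumulation no longer sees the in-place modification
    # (expected result under this patch).
    w._reset_grad_inplace_version(True)
    assert w.grad._inplace_version() == 0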