Unverified · Commit cf586021 · Authored by: Zhanlue Yang · Committed by: GitHub

Bug fix for reset grad inplace version (#37811)

* Debug

* Fixed issue with reset_grad_inplace_version when used with clear_gradient & cross-batch accumulation

* Rearranged interfaces

* Fixed ci issues
Parent 723cbe9f
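For context, a minimal sketch of the usage pattern this fix targets, put together from the tests added below in this PR (the backward hook, the in-place grad edit, and the new `set_to_zero` argument all come from the diff; the exact imports are assumed to match the added test file):

```python
import numpy as np
import paddle
from paddle import _C_ops

paddle.set_device('cpu')


def make_clear_grad_hook(w):
    # Backward hook that edits the gradient in place and then wipes the
    # inplace-version records so the edited grad can be reused across batches.
    @paddle.no_grad()
    def hook(*_):
        assert w.grad is not None
        _C_ops.scale_(w.grad, 'scale', 0.5)
        # True -> also reset the grad tensor's inplace version counter to zero,
        # which is the new behaviour added by this PR.
        w._reset_grad_inplace_version(True)

    return hook


w = paddle.to_tensor(np.ones([1, 1]), 'float32', stop_gradient=False)
w._register_backward_hook(make_clear_grad_hook(w))

for step in range(2):
    out = _C_ops.scale(w, 'scale', 0.1)
    out.backward()
    # Combined with clear_gradient() and cross-batch accumulation, the old
    # behaviour left a stale version snapshot behind; see the tests below.
```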
@@ -81,6 +81,7 @@ class TensorInplaceVersion {
   bool IsUnique() const { return inplace_version_ == 0; }
   void Bump() { ++inplace_version_; }
   uint32_t CurrentVersion() const { return inplace_version_; }
+  void SetInplaceVersionToZero() { inplace_version_ = 0; }

  private:
   uint32_t inplace_version_;
......
@@ -75,6 +75,7 @@ class Variable {
   framework::TensorInplaceVersion* InplaceVersionCounter();

  public:
+  void SetInplaceVersionToZero();
   uint32_t CurrentInplaceVersion();
   void BumpInplaceVersion();

@@ -134,6 +135,12 @@ inline framework::TensorInplaceVersion* Variable::InplaceVersionCounter() {
   return version_counter_ptr;
 }

+inline void Variable::SetInplaceVersionToZero() {
+  auto inplace_version_counter = this->InplaceVersionCounter();
+  if (inplace_version_counter)
+    inplace_version_counter->SetInplaceVersionToZero();
+}
+
 inline uint32_t Variable::CurrentInplaceVersion() {
   auto version_counter_ptr = InplaceVersionCounter();
   if (version_counter_ptr) {
......
@@ -209,13 +209,23 @@ class VariableWrapper {
   uint32_t InplaceVersionSnapshot() const { return inplace_version_snapshot_; }

-  void ResetInplaceVersion() {
-    auto new_version = var_.CurrentInplaceVersion();
-
-    VLOG(6) << "The wrapper version of VariableWrapper '" << name_
-            << "' will be updated from " << inplace_version_snapshot_ << "to "
-            << new_version;
-    inplace_version_snapshot_ = new_version;
+  void ResetInplaceVersion(bool set_to_zero = false) {
+    if (!set_to_zero) {
+      auto new_version = var_.CurrentInplaceVersion();
+
+      VLOG(6) << "The wrapper version of VariableWrapper '" << name_
+              << "' will be updated from " << inplace_version_snapshot_ << "to "
+              << new_version;
+      inplace_version_snapshot_ = new_version;
+    } else {
+      // Reset Snapshot & InplaceVersion to zero
+      inplace_version_snapshot_ = 0;
+      auto var = this->MutableVar();
+      if (var) {
+        var->SetInplaceVersionToZero();
+      }
+    }
   }

   bool hasCacheKey(const paddle::framework::OpKernelType& key) {
......
@@ -1538,7 +1538,7 @@ void BindImperative(py::module *m_ptr) {
             self.MutableGradVarBase()->SetType(type);
           })
      .def("_reset_grad_inplace_version",
-          [](imperative::VarBase &self) {
+          [](imperative::VarBase &self, bool set_to_zero) {
            /*
            *** This interface is a complete hack ***
            reset_grad_inplace_version removes all inplace related records to
@@ -1550,15 +1550,20 @@ void BindImperative(py::module *m_ptr) {
            Make sure you fully understand what you're doing before make use of
            this interface, and prepare for the worst.
            */
+           py::gil_scoped_release release;
+
            if (self.HasGradVar()) {
              auto grad_var = self.GradVarBase();
              auto var_wrapper = grad_var->SharedVar();
-             if (var_wrapper) var_wrapper->ResetInplaceVersion();
+             if (var_wrapper) {
+               var_wrapper->ResetInplaceVersion(set_to_zero);
+             }
            }
          })
      .def("_grad_ivar",
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
             if (grad_var && grad_var->Var().IsInitialized()) {
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
@@ -1567,6 +1572,7 @@ void BindImperative(py::module *m_ptr) {
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();

               if (tensor->IsInitialized()) {
                 return grad_var;
               }
......
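For readers coming from the Python side, a short, hedged illustration of the two modes the rebound method now supports, mirroring `ResetInplaceVersion(bool)` above and the `TestInplaceClearGradAccumulationAlt` test below (imports assumed to match the added test file):

```python
import numpy as np
import paddle
from paddle import _C_ops

paddle.set_device('cpu')

w = paddle.to_tensor(np.ones([1, 1]), 'float32', stop_gradient=False)
out = _C_ops.scale(w, 'scale', 0.1)
out.backward()

w.grad.scale_(scale=0.5)               # in-place edit bumps the grad's inplace version to 1
w._reset_grad_inplace_version(False)   # only re-syncs the wrapper's version snapshot
assert w.grad._inplace_version() == 1  # the tensor's own counter is left untouched

w._reset_grad_inplace_version(True)    # snapshot and tensor counter are both reset
assert w.grad._inplace_version() == 0  # via Variable::SetInplaceVersionToZero()
```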
@@ -177,7 +177,7 @@ class ShardingStage2(nn.Layer):
         for param in self._trainable_params:
             if param.name in self._param_grads and param.grad is not None:
                 param.grad.scale_(scale=self._world_size_scaling)
-                param._reset_grad_inplace_version()
+                param._reset_grad_inplace_version(True)

     def _init_internal_storage(self, needs_fresh):
         """
@@ -283,7 +283,7 @@ class ShardingStage2(nn.Layer):
                 self._grad_reduced[index] = False
                 if not self._accumulate_grads:
                     param.grad.scale_(scale=self._world_size_scaling)
-                    param._reset_grad_inplace_version()
+                    param._reset_grad_inplace_version(True)

                 # Clear the gradient that does not belong to the current rank through the callback function
                 def cleanup():
......
@@ -20,12 +20,13 @@ import unittest

 paddle.set_device('cpu')


-def clear_grad(w, a):
+# Test 1
+def clear_grad_test_0(w, a):
     @paddle.no_grad()
     def warp(*_):
         assert w.grad is not None
         _C_ops.scale_(w.grad, 'scale', 0.5)
-        w._reset_grad_inplace_version()
+        w._reset_grad_inplace_version(True)

     return warp

@@ -35,7 +36,7 @@ class TestInplaceAndClearGradient(unittest.TestCase):
         input_data = np.ones([1, 1])
         w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)

-        _clear_grad = clear_grad(w, a="1")
+        _clear_grad = clear_grad_test_0(w, a="1")
         w._register_backward_hook(_clear_grad)
         for i in range(2):
             print(" Step: ", i)
@@ -45,5 +46,60 @@ class TestInplaceAndClearGradient(unittest.TestCase):
         assert w.grad[0] == 0.15


+# Test 2
+class Counter:
+    def __init__(self):
+        self.num_calls = 0
+        self.step = 0
+
+
+def clear_grad_test_1(w, c):
+    @paddle.no_grad()
+    def warp(*_):
+        assert w.grad is not None
+        if c.step == 1:
+            w.grad.scale_(scale=0.5)
+            w._reset_grad_inplace_version(True)
+        c.num_calls += 1
+
+    return warp
+
+
+class TestInplaceClearGradAccumulation(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+        c = Counter()
+
+        _clear_grad = clear_grad_test_1(w, c)
+        w._register_backward_hook(_clear_grad)
+
+        for c.step in range(5):
+            out0 = _C_ops.scale(w, 'scale', 0.1)
+            out = _C_ops.matmul_v2(out0, w, 'trans_x', False, 'trans_y', False)
+            out.backward()
+
+            if c.step == 1:
+                w.clear_gradient(False)
+
+            assert c.num_calls == 1
+            c.num_calls = 0
+
+
+class TestInplaceClearGradAccumulationAlt(unittest.TestCase):
+    def test(self):
+        input_data = np.ones([1, 1])
+        w = paddle.to_tensor(input_data, 'float32', stop_gradient=False)
+        out = _C_ops.scale(w, 'scale', 0.1)
+        out.backward()
+
+        w.grad.scale_(scale=0.5)
+        w._reset_grad_inplace_version(False)
+
+        assert w.grad._inplace_version() == 1
+
+
 if __name__ == '__main__':
     unittest.main()