diff --git a/paddle/fluid/imperative/engine.cc b/paddle/fluid/imperative/engine.cc
index 3a41bafbfc4c81d0fba3f07db23b3e7f2b670f79..877e6ceb6a4cfa5305c88c44066d1deacf69ac88 100644
--- a/paddle/fluid/imperative/engine.cc
+++ b/paddle/fluid/imperative/engine.cc
@@ -44,8 +44,9 @@ void BasicEngine::Init(VarBase* var, const detail::BackwardStrategy& strategy) {
   const std::vector<OpBase*> ops = var->GradVarBase()->GradOps();
   var->ClearGradOps();
 
-  if (ops.empty()) {
-    VLOG(3) << "Skip auto grad since there is no grad op for var: "
+  if (ops.empty() || var->OverridedStopGradient()) {
+    VLOG(3) << "Skip auto grad since there is no grad op for var or loss is "
+               "stop_gradient=True: "
             << var->Name();
     return;
   } else {
diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc
index 509415a367284d0e92f8d45c011695ad727bc8ec..873164fc28773c144f5f97c1498b732b7b0800e4 100644
--- a/paddle/fluid/imperative/gradient_accumulator.cc
+++ b/paddle/fluid/imperative/gradient_accumulator.cc
@@ -116,11 +116,21 @@ void EagerGradientAccumulator::Add(std::shared_ptr<VarBase> var,
   } else {
     if (!var_->Var().IsInitialized() ||
         !var_->Var().Get<framework::LoDTensor>().IsInitialized()) {
-      VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
+      VLOG(6) << "Set StopGradient Grad: " << var_->Name() << " as zero ";
+
       auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
-      tensor->mutable_data(place, var->DataType());
-      operators::math::set_constant(*dev_ctx, tensor, 0.0);
+      if (!var_->Var().IsInitialized()) {
+        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
+        VLOG(6) << "Dims of " << var_->Name() << " is set as: "
+                << var->Var().Get<framework::LoDTensor>().dims();
+        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
+        tensor->mutable_data(place, var->DataType());
+        operators::math::set_constant(*dev_ctx, tensor, 0.0);
+      } else {
+        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
+        tensor->mutable_data(place, var->DataType());
+        operators::math::set_constant(*dev_ctx, tensor, 0.0);
+      }
     }
   }
   ++cur_cnt_;
@@ -162,9 +172,18 @@
         !var_->Var().Get<framework::LoDTensor>().IsInitialized()) {
       VLOG(6) << "Set StopGradient Grad: " << var->Name() << " as zero";
       auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
-      auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
-      tensor->mutable_data(place, var->DataType());
-      operators::math::set_constant(*dev_ctx, tensor, 0.0);
+      if (!var_->Var().IsInitialized()) {
+        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
+        VLOG(6) << "Dims of " << var_->Name() << " is set as: "
+                << var->Var().Get<framework::LoDTensor>().dims();
+        tensor->Resize(var->Var().Get<framework::LoDTensor>().dims());
+        tensor->mutable_data(place, var->DataType());
+        operators::math::set_constant(*dev_ctx, tensor, 0.0);
+      } else {
+        auto* tensor = var_->MutableVar()->GetMutable<framework::LoDTensor>();
+        tensor->mutable_data(place, var->DataType());
+        operators::math::set_constant(*dev_ctx, tensor, 0.0);
+      }
     }
     // looks like tmp_grad_vars will not have any member but just in case
     tmp_grad_vars_.clear();
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py
index ac849e1cfb856b426f3088b10a06a0afb237568e..b84b2ac50a8b877b9526ef60575d6175b4fcaf8b 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_prune.py
@@ -241,6 +241,72 @@ class TestImperativeAutoPrune(unittest.TestCase):
             self.assertTrue((fc._w.gradient() == 0).all())
             self.assertTrue((out1.gradient() == 0).all())
 
+    def test_auto_prune8(self):
+        with fluid.dygraph.guard():
+            value0 = np.arange(26).reshape(2, 13).astype("float32")
+            value1 = np.arange(6).reshape(2, 3).astype("float32")
+            value2 = np.arange(10).reshape(2, 5).astype("float32")
+            fc = fluid.FC("fc1", size=5, dtype="float32")
+            fc2 = fluid.FC("fc2", size=3, dtype="float32")
+            a = fluid.dygraph.to_variable(value0)
+            b = fluid.dygraph.to_variable(value1)
+            c = fluid.dygraph.to_variable(value2)
+            out1 = fc(a)
+            fc_origin = fc._w.numpy()
+            out2 = fc2(out1)
+            fc2_origin = fc2._w.numpy()
+            fc2._w.stop_gradient = True
+            out2.backward()
+            optimizer = fluid.optimizer.SGD(learning_rate=0.003)
+            optimizer.minimize(out2)
+            self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy()))
+            self.assertFalse(np.array_equal(fc_origin, fc._w.numpy()))
+
+    def test_auto_prune9(self):
+        with fluid.dygraph.guard():
+            value0 = np.arange(26).reshape(2, 13).astype("float32")
+            value1 = np.arange(6).reshape(2, 3).astype("float32")
+            value2 = np.arange(10).reshape(2, 5).astype("float32")
+            fc = fluid.FC("fc1", size=5, dtype="float32")
+            fc2 = fluid.FC("fc2", size=3, dtype="float32")
+            a = fluid.dygraph.to_variable(value0)
+            b = fluid.dygraph.to_variable(value1)
+            c = fluid.dygraph.to_variable(value2)
+            out1 = fc(a)
+            fc_origin = fc._w.numpy()
+            out2 = fc2(out1)
+            fc2_origin = fc2._w.numpy()
+            out2.stop_gradient = True
+            out2.backward()
+            optimizer = fluid.optimizer.SGD(learning_rate=0.003)
+            optimizer.minimize(out2)
+            self.assertTrue(np.array_equal(fc2_origin, fc2._w.numpy()))
+            self.assertTrue(np.array_equal(fc_origin, fc._w.numpy()))
+            try:
+                fc2._w.gradient()
+            except ValueError as e:
+                assert type(e) == ValueError
+
+    def test_auto_prune10(self):
+        with fluid.dygraph.guard():
+            value0 = np.arange(26).reshape(2, 13).astype("float32")
+            value1 = np.arange(6).reshape(2, 3).astype("float32")
+            value2 = np.arange(10).reshape(2, 5).astype("float32")
+            fc = fluid.FC("fc1", size=5, dtype="float32")
+            fc2 = fluid.FC("fc2", size=3, dtype="float32")
+            a = fluid.dygraph.to_variable(value0)
+            b = fluid.dygraph.to_variable(value1)
+            c = fluid.dygraph.to_variable(value2)
+            out1 = fc(a)
+            out2 = fc2(b)
+            out1.stop_gradient = True
+            out = fluid.layers.concat(input=[out1, out2, c], axis=1)
+            backward_strategy = fluid.dygraph.BackwardStrategy()
+            backward_strategy.sort_sum_gradient = True
+            out.backward(backward_strategy)
+            self.assertTrue((fc._w.gradient() == 0).all())
+            self.assertTrue((out1.gradient() == 0).all())
+
     def test_auto_prune_with_optimizer(self):
         vocab_size = 100
         size = 20
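The snippet below is a minimal sketch (not part of the patch) of the user-visible behavior the new tests exercise: after the engine.cc change, backward() on a loss that is itself marked stop_gradient returns early, so no gradient is ever produced for the layer's weight and a subsequent optimizer step leaves it untouched. It assumes the same 1.x dygraph API used in test_auto_prune9 above (fluid.dygraph.guard, fluid.FC, fluid.dygraph.to_variable); the layer name and learning rate are illustrative only.

import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(
        np.arange(26).reshape(2, 13).astype("float32"))
    fc = fluid.FC("fc", size=5, dtype="float32")
    out = fc(x)
    w_before = fc._w.numpy()

    # With the engine.cc change, backward() on a stop_gradient loss is skipped.
    out.stop_gradient = True
    out.backward()

    # No gradients were accumulated, so minimize() has nothing to apply
    # and the weight stays exactly as it was.
    optimizer = fluid.optimizer.SGD(learning_rate=0.003)
    optimizer.minimize(out)
    assert np.array_equal(w_before, fc._w.numpy())

    # Mirroring test_auto_prune9: asking for a gradient that was never
    # created is expected to raise ValueError.
    try:
        fc._w.gradient()
    except ValueError:
        pass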