diff --git a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc
index dcf06bffebef314840ccb72c30b22bb65e5a80b2..db3d2cf519c6ddc892e0502dfcee6914d3e594a8 100644
--- a/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc
+++ b/paddle/fluid/eager/tests/task_tests/eager_utils_test.cc
@@ -16,6 +16,7 @@
 
 #include "gtest/gtest.h"
 
+#include "paddle/fluid/eager/accumulation/accumulation_node.h"
 #include "paddle/fluid/eager/eager_tensor.h"
 #include "paddle/fluid/eager/grad_node_info.h"
 #include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h"
@@ -224,4 +225,26 @@ TEST(EagerUtils, CreateVars) {
   CHECK(outs[0]->Var().IsInitialized() == false);
 }
 
+TEST(EagerUtils, GetGradAccumulationNode) {
+  VLOG(6) << "Check GetGradAccumulationNode";
+  paddle::experimental::Tensor t0("test_tensor");
+  ASSERT_EQ(egr::EagerUtils::GetGradAccumulationNode(t0), nullptr);
+  auto autograd_ptr0 = egr::EagerUtils::autograd_meta(&t0);
+  autograd_ptr0->SetStopGradient(true);
+  ASSERT_EQ(egr::EagerUtils::GetGradAccumulationNode(t0), nullptr);
+  autograd_ptr0->SetStopGradient(false);
+  auto res = std::dynamic_pointer_cast<egr::GradNodeAccumulation>(
+      egr::EagerUtils::GetGradAccumulationNode(t0));
+  ASSERT_TRUE(res != nullptr);
+  auto res2 = egr::EagerUtils::GetGradAccumulationNode(t0);
+  ASSERT_EQ(res2.get(), res.get());
+  autograd_ptr0->SetStopGradient(true);
+  auto res3 = egr::EagerUtils::GetGradAccumulationNode(t0);
+  ASSERT_EQ(res3, nullptr);
+  autograd_ptr0->SetStopGradient(false);
+  autograd_ptr0->SetGradNode(
+      std::make_shared<eager_test::GradTestNode>(1, 2.0, 3));
+  ASSERT_ANY_THROW(egr::EagerUtils::GetGradAccumulationNode(t0));
+}
+
 }  // namespace egr
diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc
index ec2ac2ee2a6e4c8416aa19fb104ba7b75560ac8e..a8c27e86b877ae7483e3c52c87d19308b9a48907 100644
--- a/paddle/fluid/eager/utils.cc
+++ b/paddle/fluid/eager/utils.cc
@@ -21,6 +21,7 @@
 #include "paddle/pten/common/layout.h"
 #include "paddle/pten/core/tensor_meta.h"
 
+#include "paddle/fluid/eager/accumulation/accumulation_node.h"
 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/pten_utils.h"
 #include "paddle/fluid/framework/variable.h"
@@ -303,4 +304,41 @@ void EagerUtils::CheckAndRetainGrad(
   }
 }
 
+std::shared_ptr<GradNodeBase> EagerUtils::GetGradAccumulationNode(
+    const paddle::experimental::Tensor& tensor) {
+  auto* autograd_ptr = nullable_autograd_meta(tensor);
+  if (!autograd_ptr) {
+    return nullptr;
+  }
+  auto node_ptr = autograd_ptr->GetMutableGradNode();
+  if (node_ptr && node_ptr.get()) {
+    if (!autograd_ptr->StopGradient()) {
+      auto accumulation_ptr =
+          std::dynamic_pointer_cast<GradNodeAccumulation>(node_ptr);
+      if (accumulation_ptr) {
+        return accumulation_ptr;
+      } else {
+        // Current GradNode is not an egr::GradNodeAccumulation
+        PADDLE_THROW(paddle::platform::errors::Fatal(
+            "GetGradAccumulationNode should only be called on leaf tensor, but "
+            "target tensor: %s has GradNode which is not a "
+            "GradNodeAccumulation, and this should not happen unless target "
+            "tensor is modified by some ops and calling set history for it.",
+            tensor.name()));
+      }
+    } else {
+      // Current Tensor does not have grad since its stop_gradient is true
+      return nullptr;
+    }
+  } else {
+    if (!autograd_ptr->StopGradient()) {
+      VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name();
+      autograd_ptr->SetGradNode(std::make_shared<GradNodeAccumulation>());
+      return autograd_ptr->GetMutableGradNode();
+    } else {
+      return nullptr;
+    }
+  }
+}
+
 }  // namespace egr
diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h
index b6540b7e0178e03e3f4d63432e624b0619f9f04c..11c728e4c6c9bdd3e3ee60fb474200ff5ae20afc 100644
--- a/paddle/fluid/eager/utils.h
+++ b/paddle/fluid/eager/utils.h
@@ -189,6 +189,8 @@ class EagerUtils {
   static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
   static void CheckAndRetainGrad(
       const std::vector<paddle::experimental::Tensor>& tensors);
+  static std::shared_ptr<GradNodeBase> GetGradAccumulationNode(
+      const paddle::experimental::Tensor& tensor);
 };
 
 }  // namespace egr
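
For reviewers, the sketch below illustrates how the new `EagerUtils::GetGradAccumulationNode` helper is expected to be used, mirroring the behavior exercised by the added unit test. It is a minimal illustration, not part of this change: the function name `GradAccumulationNodeSketch` and the exact include set are assumptions, and it relies only on the eager-mode APIs (`autograd_meta`, `SetStopGradient`) already used in the test above.

```cpp
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/utils.h"

// Minimal sketch (hypothetical): expected behavior of
// EagerUtils::GetGradAccumulationNode for a leaf tensor.
void GradAccumulationNodeSketch() {
  paddle::experimental::Tensor t("demo_tensor");

  // Without autograd meta, the helper returns nullptr.
  auto none = egr::EagerUtils::GetGradAccumulationNode(t);  // == nullptr

  // With autograd meta and stop_gradient == false, the first call lazily
  // attaches a GradNodeAccumulation; subsequent calls return the same node.
  auto* meta = egr::EagerUtils::autograd_meta(&t);
  meta->SetStopGradient(false);
  auto first = egr::EagerUtils::GetGradAccumulationNode(t);
  auto second = egr::EagerUtils::GetGradAccumulationNode(t);
  // first.get() == second.get()

  // Setting stop_gradient back to true makes the helper report no
  // accumulation node.
  meta->SetStopGradient(true);
  auto stopped = egr::EagerUtils::GetGradAccumulationNode(t);  // == nullptr
}
```

The design choice reflected here is that accumulation nodes are only meaningful for leaf tensors that require gradients: a tensor whose GradNode is not a `GradNodeAccumulation` is treated as an error (the `PADDLE_THROW` branch), and a stop-gradient tensor simply yields `nullptr`.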