Unverified commit ae92da87, authored by Jiabin Yang, committed by GitHub

Support GetGradAccumulator for reducer (#39537)
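This adds EagerUtils::GetGradAccumulationNode, which looks up (or lazily creates) the GradNodeAccumulation attached to a leaf tensor, so a data-parallel reducer can find the accumulator for each trainable parameter. A minimal usage sketch follows, assuming the diff below; the RegisterForReduce wrapper and the hook-registration step are illustrative assumptions, not code from this commit:

// Sketch: fetching a parameter's grad accumulator, as a reducer might.
// RegisterForReduce is a hypothetical caller, not part of this commit.
#include <memory>
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/utils.h"

void RegisterForReduce(const paddle::experimental::Tensor& param) {
  // Returns nullptr when param has no autograd meta or stop_gradient is
  // true; throws if param is a non-leaf tensor with some other GradNode.
  auto node = egr::EagerUtils::GetGradAccumulationNode(param);
  if (node) {
    auto accumulator =
        std::static_pointer_cast<egr::GradNodeAccumulation>(node);
    // ... register a gradient hook on `accumulator` here ...
  }
}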

Parent 831fd86e
@@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/eager_tensor.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h"
@@ -224,4 +225,26 @@ TEST(EagerUtils, CreateVars) {
  CHECK(outs[0]->Var().IsInitialized() == false);
}

TEST(EagerUtils, GetGradAccumulationNode) {
  VLOG(6) << "Check GetGradAccumulationNode";
  paddle::experimental::Tensor t0("test_tensor");
  ASSERT_EQ(egr::EagerUtils::GetGradAccumulationNode(t0), nullptr);
  auto autograd_ptr0 = egr::EagerUtils::autograd_meta(&t0);
  autograd_ptr0->SetStopGradient(true);
  ASSERT_EQ(egr::EagerUtils::GetGradAccumulationNode(t0), nullptr);
  autograd_ptr0->SetStopGradient(false);
  auto res = std::dynamic_pointer_cast<egr::GradNodeAccumulation>(
      egr::EagerUtils::GetGradAccumulationNode(t0));
  ASSERT_TRUE(res != nullptr);
  auto res2 = egr::EagerUtils::GetGradAccumulationNode(t0);
  ASSERT_EQ(res2.get(), res.get());
  autograd_ptr0->SetStopGradient(true);
  auto res3 = egr::EagerUtils::GetGradAccumulationNode(t0);
  ASSERT_EQ(res3, nullptr);
  autograd_ptr0->SetStopGradient(false);
  autograd_ptr0->SetGradNode(
      std::make_shared<eager_test::GradTestNode>(1, 2.0, 3));
  ASSERT_ANY_THROW(egr::EagerUtils::GetGradAccumulationNode(t0));
}
} // namespace egr
@@ -21,6 +21,7 @@
#include "paddle/pten/common/layout.h"
#include "paddle/pten/core/tensor_meta.h"
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/fluid/framework/variable.h"
@@ -303,4 +304,41 @@ void EagerUtils::CheckAndRetainGrad(
}
}
std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
    const paddle::experimental::Tensor& tensor) {
  auto* autograd_ptr = nullable_autograd_meta(tensor);
  if (!autograd_ptr) {
    return nullptr;
  }
  auto node_ptr = autograd_ptr->GetMutableGradNode();
  if (node_ptr) {
    if (!autograd_ptr->StopGradient()) {
      auto accumulation_ptr =
          std::dynamic_pointer_cast<GradNodeAccumulation>(node_ptr);
      if (accumulation_ptr) {
        return accumulation_ptr;
      } else {
        // The current GradNode is not an egr::GradNodeAccumulation.
        PADDLE_THROW(paddle::platform::errors::Fatal(
            "GetGradAccumulationNode should only be called on a leaf tensor, "
            "but target tensor: %s has a GradNode which is not a "
            "GradNodeAccumulation. This should not happen unless the target "
            "tensor was modified by some ops and its history was set "
            "manually.",
            tensor.name()));
      }
    } else {
      // The current tensor has no grad since its stop_gradient is true.
      return nullptr;
    }
  } else {
    if (!autograd_ptr->StopGradient()) {
      VLOG(6) << "Add GradNodeAccumulation for tensor: " << tensor.name();
      autograd_ptr->SetGradNode(std::make_shared<egr::GradNodeAccumulation>());
      return autograd_ptr->GetMutableGradNode();
    } else {
      return nullptr;
    }
  }
}
} // namespace egr
@@ -189,6 +189,8 @@ class EagerUtils {
  static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor);
  static void CheckAndRetainGrad(
      const std::vector<paddle::experimental::Tensor>& tensors);

  static std::shared_ptr<egr::GradNodeBase> GetGradAccumulationNode(
      const paddle::experimental::Tensor& tensor);
};
} // namespace egr