未验证 提交 7ed81711 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Modified EagerUtils interfaces (#39200)

上级 eb1f9439
......@@ -1014,7 +1014,7 @@ static std::string GenerateGradNodeCreationContent(
if (output.duplicable()) {
const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::multi_autograd_meta(&%s);\n";
"egr::EagerUtils::autograd_meta(&%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
......@@ -1107,7 +1107,7 @@ static std::string GenerateGradNodeCreationContent(
size_t input_position = fwd_inputs_name_pos_map.at(input_name);
const char* SET_GRAD_OUT_META_TEMPLATE =
" grad_node->SetGradOutMeta(%s, %d);\n";
" grad_node->SetGradOutMeta(&%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_position);
......@@ -1138,6 +1138,11 @@ static std::string GenerateGradNodeCreationContent(
grad_node_creation_str +=
paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(&%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
} else {
pass_stop_gradient_args += ", " + output_autograd_name;
const char* SET_OUT_RANK_TEMPLATE =
......@@ -1149,12 +1154,12 @@ static std::string GenerateGradNodeCreationContent(
" egr::EagerUtils::SetHistory(%s, grad_node);\n";
grad_node_creation_str +=
paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
}
const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
const char* SET_GRAD_IN_META_TEMPLATE =
" grad_node->SetGradInMeta(%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
}
VLOG(6) << "Generated Call RetainGradForTensor";
const char* RETAIN_GRAD_TEMPLATE =
......
......@@ -90,9 +90,9 @@ const std::vector<GradSlotMeta>& GradNodeBase::OutputMeta() const {
return bwd_out_meta_;
}
void GradNodeBase::SetGradInMeta(const std::vector<AutogradMeta*>& fwd_out,
void GradNodeBase::SetGradInMeta(std::vector<AutogradMeta*>* fwd_out,
size_t slot_rank) {
size_t slot_size = fwd_out.size();
size_t slot_size = fwd_out->size();
PADDLE_ENFORCE_LE(
slot_rank, (bwd_in_meta_.size() - 1),
paddle::platform::errors::InvalidArgument(
......@@ -108,15 +108,15 @@ void GradNodeBase::SetGradInMeta(const std::vector<AutogradMeta*>& fwd_out,
// Init stop gradient vector before use to avoid push back
meta.Init(slot_size);
for (size_t i = 0; i < slot_size; i++) {
PADDLE_ENFORCE_NOT_NULL(fwd_out[i],
PADDLE_ENFORCE_NOT_NULL((*fwd_out)[i],
paddle::platform::errors::PreconditionNotMet(
"Bwd_in_meta should only be called while "
"autograd_meta is not null. If you got this "
"error, it indicates bugs in framework."));
if (fwd_out[i]->StopGradient()) {
if ((*fwd_out)[i]->StopGradient()) {
// Set Stop Gradient only when its true or non-initialized autograd_meta,
// since all default value is false.
meta.SetStopGradient(i, fwd_out[i]->StopGradient());
meta.SetStopGradient(i, (*fwd_out)[i]->StopGradient());
}
}
}
......@@ -140,9 +140,9 @@ void GradNodeBase::SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank) {
meta.SetStopGradient(0, fwd_out->StopGradient());
}
void GradNodeBase::SetGradOutMeta(const std::vector<AutogradMeta*>& fwd_in,
void GradNodeBase::SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in,
size_t slot_rank) {
size_t slot_size = fwd_in.size();
size_t slot_size = fwd_in->size();
PADDLE_ENFORCE_LE(
slot_rank, (bwd_out_meta_.size() - 1),
paddle::platform::errors::InvalidArgument(
......@@ -158,14 +158,14 @@ void GradNodeBase::SetGradOutMeta(const std::vector<AutogradMeta*>& fwd_in,
// Init stop gradient vector before use to avoid push back
meta.Init(slot_size);
for (size_t i = 0; i < slot_size; i++) {
if (!fwd_in[i]) {
if (!(*fwd_in)[i]) {
meta.SetStopGradient(i, true);
continue;
}
if (fwd_in[i]->StopGradient()) {
if ((*fwd_in)[i]->StopGradient()) {
// Set Stop Gradient only when its true or non-initialized autograd_meta,
// since all default value is false.
meta.SetStopGradient(i, fwd_in[i]->StopGradient());
meta.SetStopGradient(i, (*fwd_in)[i]->StopGradient());
}
}
}
......
......@@ -121,12 +121,10 @@ class GradNodeBase {
* Set bwd ins and outs info with forward vars
* **/
void SetGradInMeta(const std::vector<AutogradMeta*>& fwd_out,
size_t slot_rank);
void SetGradInMeta(std::vector<AutogradMeta*>* fwd_out, size_t slot_rank);
void SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank);
void SetGradOutMeta(const std::vector<AutogradMeta*>& fwd_in,
size_t slot_rank);
void SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in, size_t slot_rank);
void SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank);
/**
......
......@@ -75,9 +75,9 @@ TEST(GradNodeInfo, GradNodeBase) {
VLOG(6) << "Test Set Meta and Get Meta";
auto_grad1->SetStopGradient(true);
grad_test_node0->SetGradInMeta(metas, 0);
grad_test_node0->SetGradInMeta(&metas, 0);
grad_test_node0->SetGradInMeta(auto_grad1.get(), 1);
grad_test_node0->SetGradOutMeta(metas, 0);
grad_test_node0->SetGradOutMeta(&metas, 0);
grad_test_node0->SetGradOutMeta(auto_grad1.get(), 1);
CHECK_EQ(grad_test_node0->InputMeta()[0].Size(), 1);
CHECK_EQ(grad_test_node0->InputMeta()[1].Size(), 1);
......@@ -149,7 +149,7 @@ TEST(GradNodeInfo, Edge) {
auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
std::vector<egr::AutogradMeta*> metas = {auto_grad1.get()};
// Uninitialized AutogradMeta indicates
mt_grad_node->SetGradInMeta(metas, 0);
mt_grad_node->SetGradInMeta(&metas, 0);
CHECK(grad_node->InputMeta()[0].IsStopGradient(0) == true);
VLOG(6) << "Test Get/Set Edge Rank Info";
CHECK_EQ(edge2.GetEdgeRankInfo().first, size_t(1));
......
......@@ -51,7 +51,6 @@ TEST(EagerUtils, AutoGradMeta) {
// unsafe_autograd_meta()
// autograd_meta()
// multi_autograd_meta()
AutogradMeta* autograd_meta0 = EagerUtils::autograd_meta(&et0);
AutogradMeta* autograd_meta1 = EagerUtils::autograd_meta(&et1);
......@@ -59,8 +58,7 @@ TEST(EagerUtils, AutoGradMeta) {
EagerUtils::unsafe_autograd_meta(et0);
CHECK_NOTNULL(unsafe_autograd_meta_after);
std::vector<AutogradMeta*> autograd_metas =
EagerUtils::multi_autograd_meta(&ets);
std::vector<AutogradMeta*> autograd_metas = EagerUtils::autograd_meta(&ets);
std::vector<AutogradMeta*> unsafe_autograd_metas =
EagerUtils::unsafe_autograd_meta(ets);
CHECK_NOTNULL(unsafe_autograd_metas[0]);
......
......@@ -79,12 +79,12 @@ std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
return metas;
}
std::vector<AutogradMeta*> EagerUtils::multi_autograd_meta(
std::vector<AutogradMeta*> EagerUtils::autograd_meta(
std::vector<egr::EagerTensor>* targets) {
std::vector<AutogradMeta*> ret;
ret.reserve(targets->size());
// for multi_autograd_meta we can tolerent it has nullptr.
// for autograd_meta we can tolerent it has nullptr.
for (auto& t : (*targets)) {
auto* p_autograd_meta = autograd_meta(&t);
ret.push_back(static_cast<AutogradMeta*>(p_autograd_meta));
......
......@@ -94,7 +94,7 @@ class EagerUtils {
* **/
static AutogradMeta* autograd_meta(egr::EagerTensor* target);
static std::vector<AutogradMeta*> multi_autograd_meta(
static std::vector<AutogradMeta*> autograd_meta(
std::vector<egr::EagerTensor>* targets);
static std::pair<size_t, size_t> OutRankInfo(const egr::EagerTensor& target);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册