Unverified commit 7ed81711, authored by Zhanlue Yang, committed by GitHub

Modified EagerUtils interfaces (#39200)

Parent: eb1f9439
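In short, this commit renames `EagerUtils::multi_autograd_meta` into an overload of `EagerUtils::autograd_meta` and switches the vector forms of `GradNodeBase::SetGradInMeta`/`SetGradOutMeta` from `const std::vector<AutogradMeta*>&` to `std::vector<AutogradMeta*>*`. A minimal sketch of the resulting call pattern; the include paths, tensor setup, and `grad_node` construction are illustrative assumptions, not part of the commit:

    #include "paddle/fluid/eager/grad_node_info.h"  // GradNodeBase (path assumed)
    #include "paddle/fluid/eager/utils.h"           // EagerUtils (path assumed)

    // Forward inputs/outputs whose autograd info should be wired into the grad node
    // (construction of the tensors and of grad_node is assumed to happen elsewhere).
    std::vector<egr::EagerTensor> ins(1), outs(1);

    // The vector overload now shares the autograd_meta name (was multi_autograd_meta).
    std::vector<egr::AutogradMeta*> in_metas = egr::EagerUtils::autograd_meta(&ins);
    std::vector<egr::AutogradMeta*> out_metas = egr::EagerUtils::autograd_meta(&outs);

    // The vector variants of SetGradInMeta/SetGradOutMeta now take a pointer.
    grad_node->SetGradInMeta(&out_metas, /*slot_rank=*/0);   // grad inputs come from fwd outputs
    grad_node->SetGradOutMeta(&in_metas, /*slot_rank=*/0);   // grad outputs come from fwd inputs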
@@ -1014,7 +1014,7 @@ static std::string GenerateGradNodeCreationContent(
     if (output.duplicable()) {
       const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
           " std::vector<egr::AutogradMeta*> %s = "
-          "egr::EagerUtils::multi_autograd_meta(&%s);\n";
+          "egr::EagerUtils::autograd_meta(&%s);\n";
       get_autograd_meta_str += paddle::string::Sprintf(
           GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
@@ -1107,7 +1107,7 @@ static std::string GenerateGradNodeCreationContent(
       size_t input_position = fwd_inputs_name_pos_map.at(input_name);
       const char* SET_GRAD_OUT_META_TEMPLATE =
-          " grad_node->SetGradOutMeta(%s, %d);\n";
+          " grad_node->SetGradOutMeta(&%s, %d);\n";
       grad_node_creation_str += paddle::string::Sprintf(
           SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_position);
@@ -1138,6 +1138,11 @@ static std::string GenerateGradNodeCreationContent(
       grad_node_creation_str +=
           paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
+      const char* SET_GRAD_IN_META_TEMPLATE =
+          " grad_node->SetGradInMeta(&%s, %d);\n";
+      grad_node_creation_str += paddle::string::Sprintf(
+          SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
     } else {
       pass_stop_gradient_args += ", " + output_autograd_name;
       const char* SET_OUT_RANK_TEMPLATE =
@@ -1149,12 +1154,12 @@ static std::string GenerateGradNodeCreationContent(
           " egr::EagerUtils::SetHistory(%s, grad_node);\n";
       grad_node_creation_str +=
           paddle::string::Sprintf(SET_HISTORY_TEMPLATE, output_autograd_name);
-    }
       const char* SET_GRAD_IN_META_TEMPLATE =
           " grad_node->SetGradInMeta(%s, %d);\n";
       grad_node_creation_str += paddle::string::Sprintf(
           SET_GRAD_IN_META_TEMPLATE, output_autograd_name, output_position);
+    }

     VLOG(6) << "Generated Call RetainGradForTensor";
     const char* RETAIN_GRAD_TEMPLATE =
......
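For reference, with these templates the generator now emits statements along the following lines in the generated dygraph forward functions. The `p_autograd_` prefix and the `X`/`Out` names are illustrative assumptions; only the shape of each statement comes from the templates above:

    std::vector<egr::AutogradMeta*> p_autograd_Out = egr::EagerUtils::autograd_meta(&Out);
    grad_node->SetGradOutMeta(&p_autograd_X, 0);   // from SET_GRAD_OUT_META_TEMPLATE
    grad_node->SetGradInMeta(&p_autograd_Out, 0);  // from the duplicable-branch SET_GRAD_IN_META_TEMPLATE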
@@ -90,9 +90,9 @@ const std::vector<GradSlotMeta>& GradNodeBase::OutputMeta() const {
   return bwd_out_meta_;
 }

-void GradNodeBase::SetGradInMeta(const std::vector<AutogradMeta*>& fwd_out,
+void GradNodeBase::SetGradInMeta(std::vector<AutogradMeta*>* fwd_out,
                                  size_t slot_rank) {
-  size_t slot_size = fwd_out.size();
+  size_t slot_size = fwd_out->size();
   PADDLE_ENFORCE_LE(
       slot_rank, (bwd_in_meta_.size() - 1),
       paddle::platform::errors::InvalidArgument(
@@ -108,15 +108,15 @@ void GradNodeBase::SetGradInMeta(const std::vector<AutogradMeta*>& fwd_out,
   // Init stop gradient vector before use to avoid push back
   meta.Init(slot_size);
   for (size_t i = 0; i < slot_size; i++) {
-    PADDLE_ENFORCE_NOT_NULL(fwd_out[i],
+    PADDLE_ENFORCE_NOT_NULL((*fwd_out)[i],
                             paddle::platform::errors::PreconditionNotMet(
                                 "Bwd_in_meta should only be called while "
                                 "autograd_meta is not null. If you got this "
                                 "error, it indicates bugs in framework."));
-    if (fwd_out[i]->StopGradient()) {
+    if ((*fwd_out)[i]->StopGradient()) {
       // Set Stop Gradient only when its true or non-initialized autograd_meta,
       // since all default value is false.
-      meta.SetStopGradient(i, fwd_out[i]->StopGradient());
+      meta.SetStopGradient(i, (*fwd_out)[i]->StopGradient());
     }
   }
 }
@@ -140,9 +140,9 @@ void GradNodeBase::SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank) {
   meta.SetStopGradient(0, fwd_out->StopGradient());
 }

-void GradNodeBase::SetGradOutMeta(const std::vector<AutogradMeta*>& fwd_in,
+void GradNodeBase::SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in,
                                   size_t slot_rank) {
-  size_t slot_size = fwd_in.size();
+  size_t slot_size = fwd_in->size();
   PADDLE_ENFORCE_LE(
       slot_rank, (bwd_out_meta_.size() - 1),
       paddle::platform::errors::InvalidArgument(
@@ -158,14 +158,14 @@ void GradNodeBase::SetGradOutMeta(const std::vector<AutogradMeta*>& fwd_in,
   // Init stop gradient vector before use to avoid push back
   meta.Init(slot_size);
   for (size_t i = 0; i < slot_size; i++) {
-    if (!fwd_in[i]) {
+    if (!(*fwd_in)[i]) {
       meta.SetStopGradient(i, true);
       continue;
     }
-    if (fwd_in[i]->StopGradient()) {
+    if ((*fwd_in)[i]->StopGradient()) {
       // Set Stop Gradient only when its true or non-initialized autograd_meta,
       // since all default value is false.
-      meta.SetStopGradient(i, fwd_in[i]->StopGradient());
+      meta.SetStopGradient(i, (*fwd_in)[i]->StopGradient());
     }
   }
 }
......
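Since `fwd_out`/`fwd_in` are now pointers, elements are read through `(*fwd_in)[i]`. Binding a local reference once is an equivalent way to express the same access; this is a sketch of the idiom mirroring the null-tolerant SetGradOutMeta loop above, not code from the commit:

    const std::vector<AutogradMeta*>& fwd = *fwd_in;  // dereference the pointer once
    for (size_t i = 0; i < fwd.size(); i++) {
      if (!fwd[i]) {
        meta.SetStopGradient(i, true);
        continue;
      }
      if (fwd[i]->StopGradient()) {
        meta.SetStopGradient(i, fwd[i]->StopGradient());
      }
    }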
@@ -121,12 +121,10 @@ class GradNodeBase {
    * Set bwd ins and outs info with forward vars
    * **/

-  void SetGradInMeta(const std::vector<AutogradMeta*>& fwd_out,
-                     size_t slot_rank);
+  void SetGradInMeta(std::vector<AutogradMeta*>* fwd_out, size_t slot_rank);
   void SetGradInMeta(AutogradMeta* fwd_out, size_t slot_rank);

-  void SetGradOutMeta(const std::vector<AutogradMeta*>& fwd_in,
-                      size_t slot_rank);
+  void SetGradOutMeta(std::vector<AutogradMeta*>* fwd_in, size_t slot_rank);
   void SetGradOutMeta(AutogradMeta* fwd_in, size_t slot_rank);

   /**
......
@@ -75,9 +75,9 @@ TEST(GradNodeInfo, GradNodeBase) {
   VLOG(6) << "Test Set Meta and Get Meta";
   auto_grad1->SetStopGradient(true);
-  grad_test_node0->SetGradInMeta(metas, 0);
+  grad_test_node0->SetGradInMeta(&metas, 0);
   grad_test_node0->SetGradInMeta(auto_grad1.get(), 1);
-  grad_test_node0->SetGradOutMeta(metas, 0);
+  grad_test_node0->SetGradOutMeta(&metas, 0);
   grad_test_node0->SetGradOutMeta(auto_grad1.get(), 1);
   CHECK_EQ(grad_test_node0->InputMeta()[0].Size(), 1);
   CHECK_EQ(grad_test_node0->InputMeta()[1].Size(), 1);
@@ -149,7 +149,7 @@ TEST(GradNodeInfo, Edge) {
   auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
   std::vector<egr::AutogradMeta*> metas = {auto_grad1.get()};
   // Uninitialized AutogradMeta indicates
-  mt_grad_node->SetGradInMeta(metas, 0);
+  mt_grad_node->SetGradInMeta(&metas, 0);
   CHECK(grad_node->InputMeta()[0].IsStopGradient(0) == true);
   VLOG(6) << "Test Get/Set Edge Rank Info";
   CHECK_EQ(edge2.GetEdgeRankInfo().first, size_t(1));
......
@@ -51,7 +51,6 @@ TEST(EagerUtils, AutoGradMeta) {
   // unsafe_autograd_meta()
   // autograd_meta()
-  // multi_autograd_meta()
   AutogradMeta* autograd_meta0 = EagerUtils::autograd_meta(&et0);
   AutogradMeta* autograd_meta1 = EagerUtils::autograd_meta(&et1);
@@ -59,8 +58,7 @@ TEST(EagerUtils, AutoGradMeta) {
       EagerUtils::unsafe_autograd_meta(et0);
   CHECK_NOTNULL(unsafe_autograd_meta_after);

-  std::vector<AutogradMeta*> autograd_metas =
-      EagerUtils::multi_autograd_meta(&ets);
+  std::vector<AutogradMeta*> autograd_metas = EagerUtils::autograd_meta(&ets);

   std::vector<AutogradMeta*> unsafe_autograd_metas =
       EagerUtils::unsafe_autograd_meta(ets);
   CHECK_NOTNULL(unsafe_autograd_metas[0]);
......
@@ -79,12 +79,12 @@ std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
   return metas;
 }

-std::vector<AutogradMeta*> EagerUtils::multi_autograd_meta(
+std::vector<AutogradMeta*> EagerUtils::autograd_meta(
     std::vector<egr::EagerTensor>* targets) {
   std::vector<AutogradMeta*> ret;
   ret.reserve(targets->size());

-  // for multi_autograd_meta we can tolerent it has nullptr.
+  // for autograd_meta we can tolerent it has nullptr.
   for (auto& t : (*targets)) {
     auto* p_autograd_meta = autograd_meta(&t);
     ret.push_back(static_cast<AutogradMeta*>(p_autograd_meta));
......
@@ -94,7 +94,7 @@ class EagerUtils {
    * **/
   static AutogradMeta* autograd_meta(egr::EagerTensor* target);

-  static std::vector<AutogradMeta*> multi_autograd_meta(
+  static std::vector<AutogradMeta*> autograd_meta(
       std::vector<egr::EagerTensor>* targets);

   static std::pair<size_t, size_t> OutRankInfo(const egr::EagerTensor& target);
......
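With both declarations in place, overload resolution picks between the single-tensor and list forms by argument type, so existing `autograd_meta(&tensor)` call sites are unaffected by the rename. A small usage sketch; the tensor construction here is an assumption for illustration:

    egr::EagerTensor et;                   // single tensor
    std::vector<egr::EagerTensor> ets(2);  // tensor list
    egr::AutogradMeta* m = egr::EagerUtils::autograd_meta(&et);                // single-tensor overload
    std::vector<egr::AutogradMeta*> ms = egr::EagerUtils::autograd_meta(&ets); // vector overload (was multi_autograd_meta)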