未验证 提交 f2043bd1 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Handled Dispensable Inputs/Outputs in Eager AutoCodeGen (#37959)

* Rearranged Eager AutoCodeGen directory structure

* Removed USE_OP in Eager AutoCodeGen

* Enabled generation for Operators without Grad/Inputs/Outputs

* Resolved operators without input

* Fixed merge conflicts

* Enabled Eager AutoCodeGen for 10+ more operators

* Refactored Eager AutoCodeGen with more organized helper objects

* Enabled Eager AutoCodeGen for operators with multiple OpBases

* Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument

* Handled Dispensable Inputs/Outputs in Eager AutoCodeGen
上级 b95c9cf2
...@@ -646,96 +646,6 @@ static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto, ...@@ -646,96 +646,6 @@ static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto,
} }
} }
static void PurifyGradOpProto(
const proto::OpProto& op_proto,
std::map<std::string, std::string>* grad_outs_slotname_map,
std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
std::map<std::string, std::string>* grad_ins_grad_slotname_map,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins,
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs) {
// Op Name
const std::string op_name = op_proto.type();
// Handle dispensable inputs
for (const proto::OpProto::Var& input : op_proto.inputs()) {
std::string input_name = input.name();
// Delete dispensable tensor unless specified in op_ins_map
if (input.dispensable()) {
if (!op_ins_map.count(op_name) ||
!op_ins_map[op_name].count(input_name)) {
VLOG(6) << "Removing Dispensable Input: " << input_name;
// grad_outs_slotname_map
auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
for (const auto& iter : *grad_outs_slotname_map) {
const std::string& grad_output_name = iter.first;
const std::string& matched_input_name = iter.second;
if (matched_input_name == input_name) {
grad_outs_slotname_map_purified.erase(grad_output_name);
PADDLE_ENFORCE(
grad_outs->count(grad_output_name) > 0,
paddle::platform::errors::Fatal(
"Unable to find gradient output name in grad_outs."));
// grad_outs
grad_outs->erase(grad_output_name);
}
}
*grad_outs_slotname_map = grad_outs_slotname_map_purified;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if (grad_ins_fwd_slotname_map->count(input_name))
grad_ins_fwd_slotname_map->erase(input_name);
// grad_ins: output as tensorwrapper
if (grad_ins->count(input_name)) grad_ins->erase(input_name);
}
}
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
std::string output_name = output.name();
// Delete dispensable tensor unless specified in op_outs_map
if (output.dispensable()) {
if (!op_outs_map.count(op_name) ||
!op_outs_map[op_name].count(output_name)) {
VLOG(6) << "Removing Dispensable Output: " << output_name;
// grad_ins_grad_slotname_map
auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map;
for (const auto& iter : *grad_ins_grad_slotname_map) {
const std::string& grad_input_name = iter.first;
const std::string& matched_output_name = iter.second;
if (matched_output_name == output_name) {
grad_ins_grad_slotname_map_purified.erase(grad_input_name);
PADDLE_ENFORCE(
grad_ins->count(grad_input_name) > 0,
paddle::platform::errors::Fatal(
"Unable to find gradient input name in grad_ins."));
// grad_ins
grad_ins->erase(grad_input_name);
}
}
*grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if (grad_ins_fwd_slotname_map->count(output_name))
grad_ins_fwd_slotname_map->erase(output_name);
// grad_ins: output as tensorwrapper
if (grad_ins->count(output_name)) grad_ins->erase(output_name);
}
}
}
}
/* -------------------------------- */ /* -------------------------------- */
/* --------- Collect Info --------- */ /* --------- Collect Info --------- */
/* -------------------------------- */ /* -------------------------------- */
...@@ -980,6 +890,13 @@ static std::string GenerateGradNodeCreationContent( ...@@ -980,6 +890,13 @@ static std::string GenerateGradNodeCreationContent(
get_autograd_meta_str += paddle::string::Sprintf( get_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name); GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
} else if (input.dispensable()) {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
} else { } else {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE = const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta& %s = " " egr::AutogradMeta& %s = "
...@@ -1068,6 +985,24 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1068,6 +985,24 @@ static std::string GenerateGradNodeCreationContent(
for (const proto::OpProto::Var& input : in_vars) { for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name(); const std::string& input_name = input.name();
const std::string& input_autograd_name = "p_autograd_" + input_name; const std::string& input_autograd_name = "p_autograd_" + input_name;
if (input.dispensable() && !input.duplicable()) {
compute_require_grad_args += ", " + input_autograd_name;
size_t input_position = fwd_inputs_name_pos_map.at(input_name);
const char* SET_GRAD_OUT_META_TEMPLATE =
" if(%s) grad_node->SetGradOutMeta(*%s, %d);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_GRAD_OUT_META_TEMPLATE, input_autograd_name, input_autograd_name,
input_position);
const char* ADD_EDGES_TEMPLATE =
" if(%s) grad_node->AddEdges(*%s, %d);\n";
grad_node_creation_str +=
paddle::string::Sprintf(ADD_EDGES_TEMPLATE, input_autograd_name,
input_autograd_name, input_position);
} else {
compute_require_grad_args += ", &" + input_autograd_name; compute_require_grad_args += ", &" + input_autograd_name;
size_t input_position = fwd_inputs_name_pos_map.at(input_name); size_t input_position = fwd_inputs_name_pos_map.at(input_name);
...@@ -1080,6 +1015,7 @@ static std::string GenerateGradNodeCreationContent( ...@@ -1080,6 +1015,7 @@ static std::string GenerateGradNodeCreationContent(
grad_node_creation_str += paddle::string::Sprintf( grad_node_creation_str += paddle::string::Sprintf(
ADD_EDGES_TEMPLATE, input_autograd_name, input_position); ADD_EDGES_TEMPLATE, input_autograd_name, input_position);
} }
}
// [GradOpNode] SetGradInMeta // [GradOpNode] SetGradInMeta
// [AutogradMeta] SetOutRank // [AutogradMeta] SetOutRank
...@@ -1188,6 +1124,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1188,6 +1124,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
for (const proto::OpProto::Var& input : in_vars) { for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name(); const std::string& input_name = input.name();
size_t input_position = fwd_inputs_name_pos_map.at(input_name); size_t input_position = fwd_inputs_name_pos_map.at(input_name);
if (input.duplicable()) { if (input.duplicable()) {
const char* FWD_INS_ARG_TEMPLATE = const char* FWD_INS_ARG_TEMPLATE =
"const std::vector<egr::EagerTensor>& %s"; "const std::vector<egr::EagerTensor>& %s";
...@@ -1198,6 +1135,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1198,6 +1135,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
input_args_str_list[input_position] = input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
} }
if (input.dispensable()) continue;
const char* FWD_INS_CONTENT_TEMPLATE = const char* FWD_INS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(%s) },"; "{ \"%s\", egr::EagerUtils::SyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE, ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
...@@ -1222,6 +1162,26 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -1222,6 +1162,26 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
generated_function_body += ins_map_str; generated_function_body += ins_map_str;
generated_function_body += "\n"; generated_function_body += "\n";
// Handle Dispensable Inputs
for (const proto::OpProto::Var& input : in_vars) {
const std::string& input_name = input.name();
if (input.dispensable()) {
if (input.duplicable()) {
const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.size() > 0) "
"ins[\"%s\"] = egr::EagerUtils::SyncToVars(%s)\n;";
generated_function_body += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
} else {
const char* FWD_INS_CONTENT_TEMPLATE =
" if(%s.initialized()) "
"ins[\"%s\"] = egr::EagerUtils::SyncToVars(%s)\n;";
generated_function_body += paddle::string::Sprintf(
FWD_INS_CONTENT_TEMPLATE, input_name, input_name, input_name);
}
}
}
VLOG(6) << "Generated Ins Map"; VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map // [Generation] Get Outs Map
......
...@@ -53,6 +53,12 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap( ...@@ -53,6 +53,12 @@ std::unordered_map<GradNodeBase*, int> getInDegreeMap(
for (const auto& edge_list : edges) { for (const auto& edge_list : edges) {
for (const Edge& edge : edge_list) { for (const Edge& edge : edge_list) {
GradNodeBase* next_node = edge.GetMutableGradNode().get(); GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is leaf tensor with no
// AccumulationNode attached
// Or it could also originated from dispensable inputs
if (!next_node) continue;
// Update in_degree // Update in_degree
if (!node_in_degree_map.count(next_node)) if (!node_in_degree_map.count(next_node))
node_in_degree_map[next_node] = 0; node_in_degree_map[next_node] = 0;
...@@ -91,11 +97,6 @@ void RunBackward(const std::vector<egr::EagerTensor>& tensors, ...@@ -91,11 +97,6 @@ void RunBackward(const std::vector<egr::EagerTensor>& tensors,
// Get target GradNodeBase from target tensors // Get target GradNodeBase from target tensors
GradNodeBase* grad_node = auto_grad_meta->GetMutableGradNode().get(); GradNodeBase* grad_node = auto_grad_meta->GetMutableGradNode().get();
PADDLE_ENFORCE(grad_node,
paddle::platform::errors::Fatal(
"Detected null grad_node."
"Grad Node is nullptr for grad input tensor %d",
i));
// Prepare GradTensorHolder // Prepare GradTensorHolder
if (!node_input_buffers_dict.count(grad_node)) { if (!node_input_buffers_dict.count(grad_node)) {
VLOG(6) << "Create Value for grad input tensor " << i; VLOG(6) << "Create Value for grad input tensor " << i;
...@@ -186,6 +187,11 @@ void RunBackward(const std::vector<egr::EagerTensor>& tensors, ...@@ -186,6 +187,11 @@ void RunBackward(const std::vector<egr::EagerTensor>& tensors,
} }
GradNodeBase* next_node = edge.GetMutableGradNode().get(); GradNodeBase* next_node = edge.GetMutableGradNode().get();
// Next node could be nullptr if it is leaf tensor with no
// AccumulationNode attached
// Or it could also originated from dispensable inputs
if (!next_node) continue;
if (!node_input_buffers_dict.count(next_node)) { if (!node_input_buffers_dict.count(next_node)) {
node_input_buffers_dict[next_node] = node_input_buffers_dict[next_node] =
std::make_unique<GradTensorHolder>(next_node->InputMeta()); std::make_unique<GradTensorHolder>(next_node->InputMeta());
......
...@@ -56,6 +56,14 @@ std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta( ...@@ -56,6 +56,14 @@ std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta(
return metas; return metas;
} }
AutogradMeta* EagerUtils::nullable_autograd_meta(
const egr::EagerTensor& target) {
auto* p_autograd_meta = target.get_autograd_meta();
if (!p_autograd_meta) return nullptr;
return static_cast<AutogradMeta*>(p_autograd_meta);
}
std::vector<AutogradMeta*> EagerUtils::multi_autograd_meta( std::vector<AutogradMeta*> EagerUtils::multi_autograd_meta(
std::vector<egr::EagerTensor>* targets) { std::vector<egr::EagerTensor>* targets) {
std::vector<AutogradMeta*> ret; std::vector<AutogradMeta*> ret;
......
...@@ -56,6 +56,9 @@ class ComputeRequireGradIter : public IterHelper<AutogradMeta*> { ...@@ -56,6 +56,9 @@ class ComputeRequireGradIter : public IterHelper<AutogradMeta*> {
private: private:
void visit(AutogradMeta* element) override { void visit(AutogradMeta* element) override {
// Dispensable Tensors feeds in nullptr autograd_meta
if (!element) return;
bool stop_gradient = element->StopGradient(); bool stop_gradient = element->StopGradient();
if (!stop_gradient) require_grad_ = true; if (!stop_gradient) require_grad_ = true;
} }
...@@ -112,6 +115,7 @@ class EagerUtils { ...@@ -112,6 +115,7 @@ class EagerUtils {
static void SetOutRankWithSlot(AutogradMeta* target, size_t slot_id); static void SetOutRankWithSlot(AutogradMeta* target, size_t slot_id);
// This method will return an AutogradMeta pointer unsafely. // This method will return an AutogradMeta pointer unsafely.
static AutogradMeta* nullable_autograd_meta(const egr::EagerTensor& target);
static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target); static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target);
static std::vector<AutogradMeta*> unsafe_autograd_meta( static std::vector<AutogradMeta*> unsafe_autograd_meta(
const std::vector<egr::EagerTensor>& targets); const std::vector<egr::EagerTensor>& targets);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册