Unverified commit d422a1ed, authored by Zhanlue Yang, committed by GitHub

Handled special sum_grad_op code gen in Eager Dygraph (#38573)

* Handled special sum_grad_op code gen in Eager Dygraph

* Fixed merge issues
Parent 89c0877e
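For orientation, here is a minimal standalone sketch (not part of the commit) of the string composition this change introduces: when one OpBase is generated per element of a duplicable input, the single generated body is wrapped in a loop over the forward inputs. The grad op and slot names ("sum_grad", "X@GRAD", "outs0") are illustrative assumptions chosen to mirror OP_BASE_PER_DUP_INPUT_TEMPLATE and the emplace_back output template in the diff below.

// Standalone sketch: compose the per-duplicable-input grad body the way the
// generator does for sum_grad-like ops. Names are illustrative assumptions.
#include <cstdio>
#include <string>

int main() {
  // Body emitted once for the single kept OpBase (stand-in text).
  std::string single_op_base_body =
      "outputs[0].emplace_back(egr::EagerUtils::GetOutputs(outs0[\"X@GRAD\"])[0]);";

  // When is_op_base_per_duplicable_input is true, that body is wrapped in a
  // loop over the number of forward inputs.
  char generated[1024];
  std::snprintf(generated, sizeof(generated),
                "  for(int i = 0; i < this->OutputMeta()[0].Size(); i++) {\n"
                "    %s\n"
                "  }\n",
                single_op_base_body.c_str());
  std::puts(generated);
  return 0;
}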
@@ -27,6 +27,8 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#define NUM_CREATED_DUP_INPUTS 4
namespace paddle {
namespace framework {
@@ -46,6 +48,62 @@ static std::string LegalizeVariableName(const std::string& var_name) {
return ret;
}
static bool IgnoreGradAttribute(const std::string& op_type,
const std::string& attr_name) {
// Attributes in operators_with_attrs are created manually during code
// generation
// We should ignore these arbitrary attrs when setting up grad attribute map
if (operators_with_attrs.count(op_type)) {
if (operators_with_attrs[op_type].count(attr_name)) {
return true;
}
}
return false;
}
static void PrepareAttrMapForOps() {
// Handle "run_program_op"
static framework::ProgramDesc fake_prog;
operators_with_attrs["run_program"] = {};
operators_with_attrs["run_program"]["global_block"] =
fake_prog.MutableBlock(0);
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
// Handle "c_split"
operators_with_attrs["c_split"] = {};
operators_with_attrs["c_split"]["nranks"] = 1;
}
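The hand-written attributes above are exactly the ones IgnoreGradAttribute skips when the grad attribute map is emitted. Below is a minimal standalone sketch of that two-level lookup, using plain STL containers instead of paddle::framework::Attribute; the attribute names used in main() are only examples.

#include <iostream>
#include <map>
#include <set>
#include <string>

// Simplified stand-in for operators_with_attrs: op type -> attribute names
// injected manually during code generation.
static std::map<std::string, std::set<std::string>> manual_attrs = {
    {"cast", {"in_dtype", "out_dtype"}},
    {"reverse", {"axis"}},
};

static bool IgnoreGradAttributeSketch(const std::string& op_type,
                                      const std::string& attr_name) {
  auto it = manual_attrs.find(op_type);
  return it != manual_attrs.end() && it->second.count(attr_name) > 0;
}

int main() {
  std::cout << std::boolalpha
            << IgnoreGradAttributeSketch("cast", "out_dtype") << "\n"    // true: skipped
            << IgnoreGradAttributeSketch("cast", "use_mkldnn") << "\n";  // false: kept
}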
/* --- Helper Objects --- */
class ForwardGenerationInfo {
public:
@@ -136,6 +194,13 @@ class GradNodeGenerationInfo {
return &grad_outs_;
}
const paddle::framework::AttributeMap& GetGradAttrs() const {
return grad_attrs_;
}
paddle::framework::AttributeMap* GetMutableGradAttrs() {
return &grad_attrs_;
}
private:
std::string op_base_type_;
std::map<std::string, std::string> grad_outs_slotname_map_;
@@ -147,6 +212,7 @@ class GradNodeGenerationInfo {
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs_;
paddle::framework::AttributeMap grad_attrs_;
};
public:
@@ -677,6 +743,25 @@ static bool CollectGradInformationFromOpInfo(
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VarBase>>>
ins;
if (op_proto.inputs().size() == 1 && op_proto.outputs().size() == 1 &&
op_proto.inputs()[0].duplicable() &&
!op_proto.outputs()[0].duplicable()) {
VLOG(6) << "Handle op with special op_bases: " << op_type;
// @special case (sum_op): for ops with single duplicable input and single
// non-duplicable output
// feed in NUM_CREATED_DUP_INPUTS inputs to detect a
// special scenario.
const std::string& in_name = op_proto.inputs()[0].name();
ins[in_name] = {};
for (size_t i = 0; i < NUM_CREATED_DUP_INPUTS; i++) {
ins[in_name].emplace_back(std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name + "_" +
std::to_string(i))));
ins[in_name][i]->SetOverridedStopGradient(false);
ins[in_name][i]->MutableVar()->GetMutable<framework::LoDTensor>();
}
} else {
for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& in_name = input.name();
@@ -694,11 +779,13 @@ static bool CollectGradInformationFromOpInfo(
// We dont know the exact number of inputs expected,
// but we only need to identify the slot name order,
// therefore fill in 1 single input VarBase is enough in this scenario
ins[in_name] = {std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name))};
ins[in_name][0]->SetOverridedStopGradient(false);
ins[in_name][0]->MutableVar()->GetMutable<framework::LoDTensor>();
}
}
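A standalone sketch of the detection idea behind the dummy inputs above: the single duplicable slot is fed NUM_CREATED_DUP_INPUTS variables, and if the traced grad node later holds exactly that many OpBases, the grad op follows the one-OpBase-per-input pattern (sum_grad). The types and names below are simplified stand-ins, not Paddle APIs.

#include <cstddef>
#include <iostream>

constexpr std::size_t kNumCreatedDupInputs = 4;  // mirrors NUM_CREATED_DUP_INPUTS

struct OpShape {
  std::size_t num_input_slots;
  std::size_t num_output_slots;
  bool input_duplicable;
  bool output_duplicable;
};

// True when the op has one duplicable input slot, one non-duplicable output
// slot, and tracing produced one OpBase per dummy input.
static bool IsOpBasePerDuplicableInput(const OpShape& op,
                                       std::size_t traced_op_base_count) {
  return op.num_input_slots == 1 && op.num_output_slots == 1 &&
         op.input_duplicable && !op.output_duplicable &&
         traced_op_base_count == kNumCreatedDupInputs;
}

int main() {
  OpShape sum_like{1, 1, /*input_duplicable=*/true, /*output_duplicable=*/false};
  std::cout << std::boolalpha
            << IsOpBasePerDuplicableInput(sum_like, 4) << "\n"   // true
            << IsOpBasePerDuplicableInput(sum_like, 1) << "\n";  // false: single OpBase
}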
VLOG(6) << "Prepared Forward Ins Map, size = " << ins.size(); VLOG(6) << "Prepared Forward Ins Map, size = " << ins.size();
/* ------ Prepare "outs" ------ */ /* ------ Prepare "outs" ------ */
...@@ -725,7 +812,6 @@ static bool CollectGradInformationFromOpInfo( ...@@ -725,7 +812,6 @@ static bool CollectGradInformationFromOpInfo(
VLOG(6) << "Checking AttributeMap Settings"; VLOG(6) << "Checking AttributeMap Settings";
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true); attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
default_attrs = attr_checker->GetDefaultAttrMap(); default_attrs = attr_checker->GetDefaultAttrMap();
VLOG(6) << "AttributeMap Checking Passed";
} else { } else {
VLOG(6) << "Detected Null Attribute Checker, use empty default_attrs"; VLOG(6) << "Detected Null Attribute Checker, use empty default_attrs";
} }
@@ -797,13 +883,13 @@ static bool CollectGradInformationFromOpInfo(
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}
/* ------ Get Grad ins/outs/attrs ---- */
// In case of multiple OpBase, stitch all the respective ins/outs into one
VLOG(6) << "In function size: " << grad_node->size(); VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) { for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter); int index = std::distance(grad_node->begin(), iter);
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns(); auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts(); auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
auto* op_base_grad_attrs = (*op_base_infos)[index].GetMutableGradAttrs();
const paddle::imperative::OpBase& op_base = *iter; const paddle::imperative::OpBase& op_base = *iter;
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>& const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
...@@ -811,6 +897,8 @@ static bool CollectGradInformationFromOpInfo( ...@@ -811,6 +897,8 @@ static bool CollectGradInformationFromOpInfo(
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>& const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
g_outs = op_base.GetOutsMap(); g_outs = op_base.GetOutsMap();
*op_base_grad_attrs = op_base.Attrs();
for (const auto& it : g_ins) {
if (!op_base_grad_ins->count(it.first))
(*op_base_grad_ins)[it.first] = {};
@@ -1395,84 +1483,29 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
return {fwd_function_str, dygraph_function_declaration_str};
}
static std::string GenerateSingleOpBase(
const std::string& fwd_op_type, const std::string& op_base_type,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::vector<proto::OpProto::Var>& in_vars,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const paddle::framework::AttributeMap& grad_attrs,
bool is_op_base_per_duplicable_input, size_t* outs_size) {
std::string generated_grad_function_body = "";
const std::string& ins_name = "ins" + std::to_string(*outs_size);
const std::string& outs_name = "outs" + std::to_string(*outs_size);
const std::string& attrs_name = "attrs_map" + std::to_string(*outs_size);
// [Generation] Get Ins Map
std::string ins_contents_str = "";
@@ -1499,9 +1532,8 @@ static std::string GenerateGradNodeCCContents(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
@@ -1517,8 +1549,8 @@ static std::string GenerateGradNodeCCContents(
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
@@ -1535,8 +1567,7 @@ static std::string GenerateGradNodeCCContents(
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
@@ -1569,8 +1600,7 @@ static std::string GenerateGradNodeCCContents(
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
@@ -1584,13 +1614,13 @@ static std::string GenerateGradNodeCCContents(
} else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
@@ -1623,15 +1653,32 @@ static std::string GenerateGradNodeCCContents(
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map
const char* ATTRS_TEMPLATE = " auto %s = this->attr_map_;\n";
std::string grad_attrs_str =
paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
for (const auto& iter : grad_attrs) {
if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue;
std::pair<std::string, std::string> type_val =
GetAttrType(iter.second, false /*is_arg*/);
const char* GRAD_ATTRS_TEMPLATE =
" %s %s = %s;\n"
" %s[\"%s\"] = %s;\n";
std::string var_name = iter.first + std::to_string(*outs_size);
grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second,
attrs_name, iter.first, var_name);
}
generated_grad_function_body += grad_attrs_str;
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::legacy::RunOp(\"%s\", %s, %s, %s,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
std::string trace_opbase_str = paddle::string::Sprintf(
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name, attrs_name);
generated_grad_function_body += trace_opbase_str;
@@ -1646,10 +1693,19 @@ static std::string GenerateGradNodeCCContents(
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (!is_op_base_per_duplicable_input) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
} else {
const char* BWD_OUTPUT_TEMPLATE =
" "
"outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]"
");\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name,
grad_out_name);
}
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
@@ -1668,15 +1724,127 @@ static std::string GenerateGradNodeCCContents(
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, num_appended_outputs, outs_name, grad_out_name);
num_appended_outputs++;
}
}
generated_grad_function_body += outputs_str;
generated_grad_function_body += "\n";
*outs_size += grad_outs.size();
return generated_grad_function_body;
}
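A standalone sketch of how the shared outs_size counter keeps the emitted variable names unique when one grad node contains several OpBases: each call derives ins*/outs*/attrs_map* from the current counter and then advances it by that OpBase's number of grad outputs, mirroring the *outs_size handling above. The grad-output counts used in main() are made up for illustration.

#include <cstddef>
#include <iostream>
#include <vector>

// Emit the variable names the generator would use for one OpBase, then
// advance the counter the same way "*outs_size += grad_outs.size();" does.
static void EmitNames(std::size_t* outs_size, std::size_t num_grad_outs) {
  std::cout << "ins" << *outs_size << " outs" << *outs_size << " attrs_map"
            << *outs_size << "\n";
  *outs_size += num_grad_outs;
}

int main() {
  std::size_t outs_size = 0;
  // Two OpBases with 2 and 1 grad outputs: prints "ins0 outs0 attrs_map0"
  // followed by "ins2 outs2 attrs_map2".
  for (std::size_t n : std::vector<std::size_t>{2, 1}) {
    EmitNames(&outs_size, n);
  }
}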
/* ---------------------------------------------- */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static std::string GenerateGradNodeCCContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info --- */
const std::string& fwd_op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
VLOG(6) << "Generating Grad Node CC";
/* [Outline]
vector<vector<Tensor>> GradNodeXXX::operator()(vector<vector<Tensor>>& grads)
{
const std::shared_ptr<Tracer>& tracer = imperative::GetCurrentTracer();
// Comes from "grad_ins"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> ins =
{
"X" : this->"X", "Y" : this->"Y",
"Out0@Grad":
SyncToVars(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out0@Grad"]]"]),
"Out1@Grad":
TensorsToVarBases(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out1@Grad"]]"])
};
// Comes from "grad_outs"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> outs =
{
"X@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["X@Grad"]]"].Size()),
"Y@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["Y@Grad"]]"].Size())
};
// Visit each OpBase
for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) {
// Simply pass entire attribute map to kernels
egr::legacy::RunOp("iter->Type()", ins, outs, this->attr_map_,
egr::Controller::Instance().ExpectedPlace(), false, {});
}
vector<vector<egr::EagerTensor>> outputs(outs.size());
for(auto& kv : outs) {
outputs["fwd_inputs_name_pos_map[grad_outs_slotname_map[kv.first]]"] =
GetOutputs(outs["kv.first"]);
}
return outputs;
}
*/
// This is a Copy
auto op_base_infos = bwd_info.GetOpBaseInfos();
/* Special Case: ops such as sum_grad_op are implemented abnormally:
they unpack the duplicable GradX and create one OpBase
corresponding to each member GradX[i]
*/
bool is_op_base_per_duplicable_input = false;
if (in_vars.size() == 1 && out_vars.size() == 1 && in_vars[0].duplicable() &&
!out_vars[0].duplicable() &&
op_base_infos.size() == NUM_CREATED_DUP_INPUTS) {
is_op_base_per_duplicable_input = true;
// Only keep the first op_base
auto op_base_info = op_base_infos[0];
op_base_infos.clear();
op_base_infos.emplace_back(std::move(op_base_info));
}
std::string generated_grad_function_body = "";
size_t outs_size = 0;
for (size_t i = 0; i < op_base_infos.size(); i++) {
const auto& op_base_info = op_base_infos[i];
const auto& grad_ins_fwd_slotname_map =
op_base_info.GetGradInsFwdSlotnameMap();
const auto& grad_ins_grad_slotname_map =
op_base_info.GetGradInsGradSlotnameMap();
const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap();
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const auto& grad_attrs = op_base_info.GetGradAttrs();
const std::string& op_base_type = op_base_info.GetOpBaseType();
generated_grad_function_body += GenerateSingleOpBase(
fwd_op_type, op_base_type, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, in_vars, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs,
grad_attrs, is_op_base_per_duplicable_input, &outs_size);
}
if (is_op_base_per_duplicable_input) {
const char* OP_BASE_PER_DUP_INPUT_TEMPLATE =
" for(int i = 0; i < this->OutputMeta()[0].Size(); i++) {\n"
" %s\n"
" }\n";
generated_grad_function_body = paddle::string::Sprintf(
OP_BASE_PER_DUP_INPUT_TEMPLATE, generated_grad_function_body);
}
const char* BWD_RETURN_TEMPLATE =
@@ -2045,47 +2213,6 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
}
} // namespace framework
} // namespace paddle
......
@@ -321,7 +321,7 @@ class TestImperative(unittest.TestCase):
with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled())
def func_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs = []
@@ -338,7 +338,7 @@ class TestImperative(unittest.TestCase):
tmp = paddle.to_tensor(x)
tmp.stop_gradient = False
inputs2.append(tmp)
ret2 = paddle.add_n(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
loss2.backward()
@@ -349,6 +349,11 @@ class TestImperative(unittest.TestCase):
a = inputs2[0].gradient()
self.assertTrue(np.allclose(inputs2[0].gradient(), x))
def test_sum_op(self):
with _test_eager_guard():
self.func_sum_op()
self.func_sum_op()
def func_empty_var(self):
with fluid.dygraph.guard():
cur_program = fluid.Program()
......