Unverified commit d422a1ed, authored by Zhanlue Yang, committed by GitHub

Handled special sum_grad_op code gen in Eager Dygraph (#38573)

* Handled special sum_grad_op code gen in Eager Dygraph

* Fixed merge issues
Parent 89c0877e
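For context, the special path this patch adds applies when an op has exactly one duplicable input and exactly one non-duplicable output (e.g. sum). Below is a minimal sketch of that condition, which the patch checks inline in CollectGradInformationFromOpInfo and GenerateGradNodeCCContents; the helper name IsSumLikeOp is hypothetical and not part of the change.

static bool IsSumLikeOp(const proto::OpProto& op_proto) {
  // One duplicable input slot and one non-duplicable output slot: the grad
  // node of such an op may hold one OpBase per member of the duplicable
  // input, so the generator feeds NUM_CREATED_DUP_INPUTS placeholder inputs
  // to expose all of them during tracing.
  return op_proto.inputs().size() == 1 && op_proto.outputs().size() == 1 &&
         op_proto.inputs()[0].duplicable() &&
         !op_proto.outputs()[0].duplicable();
}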
......@@ -27,6 +27,8 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#define NUM_CREATED_DUP_INPUTS 4
namespace paddle {
namespace framework {
......@@ -46,6 +48,62 @@ static std::string LegalizeVariableName(const std::string& var_name) {
return ret;
}
static bool IgnoreGradAttribute(const std::string& op_type,
const std::string& attr_name) {
// Attributes in operators_with_attrs are created manually during code
// generation
// We should ignore these arbitrary attrs when setting up grad attribute map
if (operators_with_attrs.count(op_type)) {
if (operators_with_attrs[op_type].count(attr_name)) {
return true;
}
}
return false;
}
static void PrepareAttrMapForOps() {
// Handle "run_program_op"
static framework::ProgramDesc fake_prog;
operators_with_attrs["run_program"] = {};
operators_with_attrs["run_program"]["global_block"] =
fake_prog.MutableBlock(0);
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
// Handle "c_split"
operators_with_attrs["c_split"] = {};
operators_with_attrs["c_split"]["nranks"] = 1;
}
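A brief standalone sketch of the lookup behavior above (the container type and the attribute name "not_injected" are simplified stand-ins, not the real operators_with_attrs contents):

#include <map>
#include <set>
#include <string>

// Simplified analogue of IgnoreGradAttribute(): an attribute is ignored when
// it was injected manually for its op by PrepareAttrMapForOps().
static bool IsInjectedAttr(
    const std::map<std::string, std::set<std::string>>& injected,
    const std::string& op_type, const std::string& attr_name) {
  auto it = injected.find(op_type);
  return it != injected.end() && it->second.count(attr_name) > 0;
}

int main() {
  const std::map<std::string, std::set<std::string>> injected = {
      {"run_program", {"global_block"}},
      {"cast", {"in_dtype", "out_dtype"}}};
  bool a = IsInjectedAttr(injected, "cast", "out_dtype");    // true: injected
  bool b = IsInjectedAttr(injected, "cast", "not_injected"); // false: made-up name
  return (a && !b) ? 0 : 1;
}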
/* --- Helper Objects --- */
class ForwardGenerationInfo {
public:
......@@ -136,6 +194,13 @@ class GradNodeGenerationInfo {
return &grad_outs_;
}
const paddle::framework::AttributeMap& GetGradAttrs() const {
return grad_attrs_;
}
paddle::framework::AttributeMap* GetMutableGradAttrs() {
return &grad_attrs_;
}
private:
std::string op_base_type_;
std::map<std::string, std::string> grad_outs_slotname_map_;
......@@ -147,6 +212,7 @@ class GradNodeGenerationInfo {
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs_;
paddle::framework::AttributeMap grad_attrs_;
};
public:
......@@ -677,27 +743,48 @@ static bool CollectGradInformationFromOpInfo(
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VarBase>>>
ins;
for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& in_name = input.name();
// Handle dispensable input:
// 1. At python level, dispensable input will be detected at Python-C
// interface and filled with an empty vector
// 2. At C++ level, customers should always pass an empty vector for any
// dispensable input
// 3. During further lowering, there will always be a placeholder VarBase
// in ins/outs no matter whether it's dispensable or not
// As a result, we always create input VarBase regardless of its
// dispensability.
// Handle duplicable input: list(VarBase) or VarBase
// We don't know the exact number of inputs expected,
// but we only need to identify the slot name order,
// so filling in a single input VarBase is enough in this scenario
ins[in_name] = {std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name))};
ins[in_name][0]->SetOverridedStopGradient(false);
ins[in_name][0]->MutableVar()->GetMutable<framework::LoDTensor>();
if (op_proto.inputs().size() == 1 && op_proto.outputs().size() == 1 &&
op_proto.inputs()[0].duplicable() &&
!op_proto.outputs()[0].duplicable()) {
VLOG(6) << "Handle op with special op_bases: " << op_type;
// Special case (e.g. sum_op): for ops with a single duplicable input and a
// single non-duplicable output, feed in NUM_CREATED_DUP_INPUTS inputs so
// that this special scenario can be detected.
const std::string& in_name = op_proto.inputs()[0].name();
ins[in_name] = {};
for (size_t i = 0; i < NUM_CREATED_DUP_INPUTS; i++) {
ins[in_name].emplace_back(std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name + "_" +
std::to_string(i))));
ins[in_name][i]->SetOverridedStopGradient(false);
ins[in_name][i]->MutableVar()->GetMutable<framework::LoDTensor>();
}
} else {
for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& in_name = input.name();
// Handle dispensable input:
// 1. At python level, dispensable input will be detected at Python-C
// interface and filled with an empty vector
// 2. At C++ level, customers should always pass an empty vector for any
// dispensable input
// 3. During further lowering, there will always be a placeholder VarBase
// in ins/outs no matter whether it's dispensable or not
// As a result, we always create input VarBase regardless of its
// dispensability.
// Handle duplicable input: list(VarBase) or VarBase
// We don't know the exact number of inputs expected,
// but we only need to identify the slot name order,
// so filling in a single input VarBase is enough in this scenario
ins[in_name] = {std::shared_ptr<paddle::imperative::VarBase>(
new paddle::imperative::VarBase("auto_" + in_name))};
ins[in_name][0]->SetOverridedStopGradient(false);
ins[in_name][0]->MutableVar()->GetMutable<framework::LoDTensor>();
}
}
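For reference, a standalone sketch of the two ins-map shapes prepared above (slot names "X"/"Y" are illustrative, and plain strings stand in for the placeholder VarBases):

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  constexpr int kNumCreatedDupInputs = 4;  // mirrors NUM_CREATED_DUP_INPUTS

  // Regular op: one placeholder per input slot.
  std::map<std::string, std::vector<std::string>> regular_ins = {
      {"X", {"auto_X"}}, {"Y", {"auto_Y"}}};

  // Sum-like op (single duplicable input, single non-duplicable output):
  // several placeholders in the one duplicable slot, so the traced grad node
  // materializes one OpBase per input member.
  std::map<std::string, std::vector<std::string>> sum_like_ins;
  for (int i = 0; i < kNumCreatedDupInputs; ++i) {
    sum_like_ins["X"].push_back("auto_X_" + std::to_string(i));
  }

  std::cout << regular_ins.size() << " slots vs "
            << sum_like_ins["X"].size() << " placeholders in slot X\n";
  return 0;
}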
VLOG(6) << "Prepared Forward Ins Map, size = " << ins.size();
......@@ -725,7 +812,6 @@ static bool CollectGradInformationFromOpInfo(
VLOG(6) << "Checking AttributeMap Settings";
attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
default_attrs = attr_checker->GetDefaultAttrMap();
VLOG(6) << "AttributeMap Checking Passed";
} else {
VLOG(6) << "Detected Null Attribute Checker, use empty default_attrs";
}
......@@ -797,13 +883,13 @@ static bool CollectGradInformationFromOpInfo(
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}
/* ------ Get Grad ins/outs ---- */
// In case of multiple OpBase, stitch all the respective ins/outs into one
/* ------ Get Grad ins/outs/attrs ---- */
VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter);
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
auto* op_base_grad_attrs = (*op_base_infos)[index].GetMutableGradAttrs();
const paddle::imperative::OpBase& op_base = *iter;
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
......@@ -811,6 +897,8 @@ static bool CollectGradInformationFromOpInfo(
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
g_outs = op_base.GetOutsMap();
*op_base_grad_attrs = op_base.Attrs();
for (const auto& it : g_ins) {
if (!op_base_grad_ins->count(it.first))
(*op_base_grad_ins)[it.first] = {};
......@@ -1395,6 +1483,261 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
return {fwd_function_str, dygraph_function_declaration_str};
}
static std::string GenerateSingleOpBase(
const std::string& fwd_op_type, const std::string& op_base_type,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::vector<proto::OpProto::Var>& in_vars,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const paddle::framework::AttributeMap& grad_attrs,
bool is_op_base_per_duplicable_input, size_t* outs_size) {
std::string generated_grad_function_body = "";
const std::string& ins_name = "ins" + std::to_string(*outs_size);
const std::string& outs_name = "outs" + std::to_string(*outs_size);
const std::string& attrs_name = "attrs_map" + std::to_string(*outs_size);
// [Generation] Get Ins Map
std::string ins_contents_str = "";
for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
// Fwd Tensor
std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper("
"&"
"this->%s, "
"nullptr)) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name);
} else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
// Fwd Tensor's Grad
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_input_name));
}
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back();  // Remove trailing ","
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::unordered_set<std::string> duplicable_input_name_set;
for (const auto& in : in_vars) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
std::string outs_contents_str = "";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
Forward:
Ids W
| |
PullSparseOp
|
Out
Backward:
Ids GradOut W
| | |
PullSparseGradOp
|
GradOut
Its grad output "GradOut" corresponds to forward output "Out",
where there is a hidden inplace involved. So we find "GradOut"'s index
in grads, and perform the inplace operation by constructing outs =
{{"Out", grads[i]}}
GradOut -> Out -> fwd_output_pos -> grads position -> grads[i]
outs = {{"Out", grads[i]}}
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grads_position);
} else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name) &&
!is_op_base_per_duplicable_input) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance("
")."
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
}
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_output_name));
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back();  // Remove trailing ","
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string outs_map_str = paddle::string::Sprintf(
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n";
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map
const char* ATTRS_TEMPLATE = " auto %s = this->attr_map_;\n";
std::string grad_attrs_str =
paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
for (const auto& iter : grad_attrs) {
if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue;
std::pair<std::string, std::string> type_val =
GetAttrType(iter.second, false /*is_arg*/);
const char* GRAD_ATTRS_TEMPLATE =
" %s %s = %s;\n"
" %s[\"%s\"] = %s;\n";
std::string var_name = iter.first + std::to_string(*outs_size);
grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second,
attrs_name, iter.first, var_name);
}
generated_grad_function_body += grad_attrs_str;
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::legacy::RunOp(\"%s\", %s, %s, %s,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
std::string trace_opbase_str = paddle::string::Sprintf(
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name, attrs_name);
generated_grad_function_body += trace_opbase_str;
VLOG(6) << "Generated Attrs Map";
// [Generation] Get Return
std::string outputs_str = "";
size_t num_appended_outputs = 0;
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (!is_op_base_per_duplicable_input) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
} else {
const char* BWD_OUTPUT_TEMPLATE =
" "
"outputs[0].emplace_back(egr::EagerUtils::GetOutputs(%s[\"%s\"])[0]"
");\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outs_name,
grad_out_name);
}
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
}
}
/* Handle Special Case: "PullSparseOp", etc
For returns, append "GradOut" to the very end of return list. */
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, num_appended_outputs, outs_name, grad_out_name);
num_appended_outputs++;
}
}
generated_grad_function_body += outputs_str;
generated_grad_function_body += "\n";
*outs_size += grad_outs.size();
return generated_grad_function_body;
}
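To make the templates above concrete, here is a hedged sketch of the kind of body GenerateSingleOpBase emits into the generated GradNode::operator() (the op name "example_grad", the slot names "X"/"Out@GRAD"/"X@GRAD" and the "axis" attribute are placeholders; the exact text follows the templates and the suffix counter above):

  // Illustrative generated excerpt only -- not compilable on its own.
  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins0 = {
      { "X", egr::EagerUtils::SyncToVars(
                 egr::EagerUtils::RecoverTensorWrapper(&this->X_, nullptr)) },
      { "Out@GRAD", egr::EagerUtils::SyncToVars(grads[0]) } };
  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs0 = {
      { "X@GRAD",
        {std::make_shared<egr::EagerTensor>(
            egr::Controller::Instance().GenerateUniqueName())} } };
  auto attrs_map0 = this->attr_map_;
  int axis0 = 0;               // per-OpBase grad attribute override
  attrs_map0["axis"] = axis0;
  // Pass the entire attribute map to TraceOp
  // The underlying kernel will pick up whatever attributes it needs at runtime
  egr::legacy::RunOp("example_grad", ins0, outs0, attrs_map0,
                     egr::Controller::Instance().GetExpectedPlace(),
                     &this->default_attr_map_, false, {});
  outputs[0] = egr::EagerUtils::GetOutputs(outs0["X@GRAD"]);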
/* ---------------------------------------------- */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
......@@ -1408,6 +1751,7 @@ static std::string GenerateGradNodeCCContents(
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
VLOG(6) << "Generating Grad Node CC";
......@@ -1454,9 +1798,26 @@ static std::string GenerateGradNodeCCContents(
}
*/
// This is a Copy
auto op_base_infos = bwd_info.GetOpBaseInfos();
/* Special Case: ops such as sum_grad_op are implemented abnormally,
where the duplicable GradX is unpacked and one OpBase is created
for each member GradX[i]
*/
bool is_op_base_per_duplicable_input = false;
if (in_vars.size() == 1 && out_vars.size() == 1 && in_vars[0].duplicable() &&
!out_vars[0].duplicable() &&
op_base_infos.size() == NUM_CREATED_DUP_INPUTS) {
is_op_base_per_duplicable_input = true;
// Only keep the first op_base
auto op_base_info = op_base_infos[0];
op_base_infos.clear();
op_base_infos.emplace_back(std::move(op_base_info));
}
std::string generated_grad_function_body = "";
size_t outs_size = 0;
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
for (size_t i = 0; i < op_base_infos.size(); i++) {
const auto& op_base_info = op_base_infos[i];
......@@ -1467,216 +1828,23 @@ static std::string GenerateGradNodeCCContents(
const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap();
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const auto& grad_attrs = op_base_info.GetGradAttrs();
const std::string& op_base_type = op_base_info.GetOpBaseType();
const std::string& ins_name = "ins" + std::to_string(i);
const std::string& outs_name = "outs" + std::to_string(i);
outs_size += grad_outs.size();
// [Generation] Get Ins Map
std::string ins_contents_str = "";
for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
// Fwd Tensor
std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper("
"&"
"this->%s, "
"nullptr)) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name);
} else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
// Fwd Tensor's Grad
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_GRAD_CONTENT_TEMPLATE,
grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_input_name));
}
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back();  // Remove trailing ","
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string ins_map_str = paddle::string::Sprintf(
BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::unordered_set<std::string> duplicable_input_name_set;
for (const auto& in : in_vars) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
std::string outs_contents_str = "";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name =
grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
Forward:
Ids W
| |
PullSparseOp
|
Out
Backward:
Ids GradOut W
| | |
PullSparseGradOp
|
GradOut
Its grad output "GradOut" corresponds to forward output "Out",
where there is a hidden inplace involved. So we find "GradOut"'s index
in grads, and perform the inplace operation by constructing outs =
{{"Out", grads[i]}}
GradOut -> Out -> fwd_output_pos -> grads position -> grads[i]
outs = {{"Out", grads[i]}}
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(
fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grads_position);
} else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance("
")."
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
}
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_output_name));
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back();  // Remove trailing ","
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string outs_map_str = paddle::string::Sprintf(
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n";
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::legacy::RunOp(\"%s\", %s, %s, this->attr_map_,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
std::string trace_opbase_str = paddle::string::Sprintf(
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name);
generated_grad_function_body += trace_opbase_str;
VLOG(6) << "Generated Attrs Map";
// [Generation] Get Return
std::string outputs_str = "";
size_t num_appended_outputs = 0;
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
}
}
/* Handle Special Case: "PullSparseOp", etc
For returns, append "GradOut" to the very end of return list. */
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str +=
paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, num_appended_outputs,
outs_name, grad_out_name);
num_appended_outputs++;
}
}
generated_grad_function_body += GenerateSingleOpBase(
fwd_op_type, op_base_type, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, in_vars, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs,
grad_attrs, is_op_base_per_duplicable_input, &outs_size);
}
generated_grad_function_body += outputs_str;
generated_grad_function_body += "\n";
if (is_op_base_per_duplicable_input) {
const char* OP_BASE_PER_DUP_INPUT_TEMPLATE =
" for(int i = 0; i < this->OutputMeta()[0].Size(); i++) {\n"
" %s\n"
" }\n";
generated_grad_function_body = paddle::string::Sprintf(
OP_BASE_PER_DUP_INPUT_TEMPLATE, generated_grad_function_body);
}
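And a hedged sketch of the sum-like case after the wrapping above: the single retained OpBase body is replayed once per member of the duplicable input, filling outputs[0] element by element (op and slot names are placeholders, as before):

  // Illustrative generated excerpt only -- not compilable on its own.
  for(int i = 0; i < this->OutputMeta()[0].Size(); i++) {
    std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins0 = {
        { "Out@GRAD", egr::EagerUtils::SyncToVars(grads[0]) } };
    std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs0 = {
        { "X@GRAD",
          {std::make_shared<egr::EagerTensor>(
              egr::Controller::Instance().GenerateUniqueName())} } };
    auto attrs_map0 = this->attr_map_;
    egr::legacy::RunOp("example_grad", ins0, outs0, attrs_map0,
                       egr::Controller::Instance().GetExpectedPlace(),
                       &this->default_attr_map_, false, {});
    outputs[0].emplace_back(egr::EagerUtils::GetOutputs(outs0["X@GRAD"])[0]);
  }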
const char* BWD_RETURN_TEMPLATE =
......@@ -2045,47 +2213,6 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
}
static void PrepareAttrMapForOps() {
// Handle "run_program_op"
static framework::ProgramDesc fake_prog;
operators_with_attrs["run_program"] = {};
operators_with_attrs["run_program"]["global_block"] =
fake_prog.MutableBlock(0);
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
functor_list;
// Handle "fused_elemwise_activation"
operators_with_attrs["fused_elemwise_activation"] = {};
operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
functor_list;
// Handle "reverse"
std::vector<int> axis = {0};
operators_with_attrs["reverse"] = {};
operators_with_attrs["reverse"]["axis"] = axis;
// Handle "flip"
operators_with_attrs["flip"] = {};
operators_with_attrs["flip"]["axis"] = axis;
// Handle "cast"
operators_with_attrs["cast"] = {};
operators_with_attrs["cast"]["out_dtype"] = 5;
operators_with_attrs["cast"]["in_dtype"] = 5;
// Handle "transfer_dtype"
operators_with_attrs["transfer_dtype"] = {};
operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
// Handle "c_split"
operators_with_attrs["c_split"] = {};
operators_with_attrs["c_split"]["nranks"] = 1;
}
} // namespace framework
} // namespace paddle
......
......@@ -321,7 +321,7 @@ class TestImperative(unittest.TestCase):
with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled())
def test_sum_op(self):
def func_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs = []
......@@ -338,7 +338,7 @@ class TestImperative(unittest.TestCase):
tmp = paddle.to_tensor(x)
tmp.stop_gradient = False
inputs2.append(tmp)
ret2 = fluid.layers.sums(inputs2)
ret2 = paddle.add_n(inputs2)
loss2 = fluid.layers.reduce_sum(ret2)
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
loss2.backward()
......@@ -349,6 +349,11 @@ class TestImperative(unittest.TestCase):
a = inputs2[0].gradient()
self.assertTrue(np.allclose(inputs2[0].gradient(), x))
def test_sum_op(self):
with _test_eager_guard():
self.func_sum_op()
self.func_sum_op()
def func_empty_var(self):
with fluid.dygraph.guard():
cur_program = fluid.Program()
......