diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 6a2e5e7ac6cd75068bba4e9b675ab67588c38366..bf838b27615028167e35a8e85e7636dd4c834016 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -56,23 +56,29 @@ static std::string LegalizeVariableName(const std::string& var_name) { return ret; } -static bool IgnoreGradAttribute(const std::string& op_type, - const std::string& attr_name) { - // Attributes in operators_with_attrs are created manually during code - // generation - // We should ignore these arbitrary attrs when setting up grad attribute map - if (operators_with_attrs.count(op_type)) { - if (operators_with_attrs[op_type].count(attr_name)) { - return true; - } - } +static std::string HandleDynamicGradAttributes(const std::string& fwd_op_type, + const std::string& attrs_name) { + std::string additional_grad_attrs_str = ""; + + if (fwd_op_type == "sum") { + const char* GRAD_ATTRS_TEMPLATE = " %s[\"%s\"] = %s;\n"; + additional_grad_attrs_str = paddle::string::Sprintf( + GRAD_ATTRS_TEMPLATE, attrs_name, "scale", "float(1.0)"); + additional_grad_attrs_str += paddle::string::Sprintf( + GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)"); + additional_grad_attrs_str += paddle::string::Sprintf( + GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)"); - // Only allow SumOp - if (op_type != "sum") { - return true; + } else if (fwd_op_type == "scale") { + const char* GRAD_ATTRS_TEMPLATE = " %s[\"%s\"] = %s;\n"; + + additional_grad_attrs_str += paddle::string::Sprintf( + GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)"); + additional_grad_attrs_str += paddle::string::Sprintf( + GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)"); } - return false; + return additional_grad_attrs_str; } static void PrepareAttrMapForOps() { @@ -1866,18 +1872,9 @@ static std::string GenerateSingleOpBase( const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n"; std::string grad_attrs_str = paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name); - for (const auto& iter : grad_attrs) { - if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue; - std::pair type_val = - GetAttrType(iter.second, false /*is_arg*/); - const char* GRAD_ATTRS_TEMPLATE = - " %s %s = %s;\n" - " %s[\"%s\"] = %s;\n"; - std::string var_name = iter.first + std::to_string(*outs_size); - grad_attrs_str += paddle::string::Sprintf( - GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second, - attrs_name, iter.first, var_name); - } + + // Handle dynamic grad attributes + grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name); generated_grad_function_body += grad_attrs_str; const char* TRACE_OP_TEMPLATE =