Unverified Commit 30417999 authored by Zhanlue Yang, committed by GitHub

Fixed issues with generated scale operator (#40482)

* Fixed issues with generated scale operator

* Fixed minor issues
Parent 187fcfa3
@@ -56,23 +56,29 @@ static std::string LegalizeVariableName(const std::string& var_name) {
   return ret;
 }
 
-static bool IgnoreGradAttribute(const std::string& op_type,
-                                const std::string& attr_name) {
-  // Attributes in operators_with_attrs are created manually during code
-  // generation
-  // We should ignore these arbitrary attrs when setting up grad attribute map
-  if (operators_with_attrs.count(op_type)) {
-    if (operators_with_attrs[op_type].count(attr_name)) {
-      return true;
-    }
-  }
-
-  // Only allow SumOp
-  if (op_type != "sum") {
-    return true;
-  }
-
-  return false;
+static std::string HandleDynamicGradAttributes(const std::string& fwd_op_type,
+                                               const std::string& attrs_name) {
+  std::string additional_grad_attrs_str = "";
+
+  if (fwd_op_type == "sum") {
+    const char* GRAD_ATTRS_TEMPLATE = "  %s[\"%s\"] = %s;\n";
+    additional_grad_attrs_str = paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "scale", "float(1.0)");
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)");
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)");
+
+  } else if (fwd_op_type == "scale") {
+    const char* GRAD_ATTRS_TEMPLATE = "  %s[\"%s\"] = %s;\n";
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)");
+    additional_grad_attrs_str += paddle::string::Sprintf(
+        GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)");
+  }
+
+  return additional_grad_attrs_str;
 }
 
 static void PrepareAttrMapForOps() {
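
For reference, here is a minimal standalone sketch of what the new helper emits. It is not Paddle code: it mimics the paddle::string::Sprintf substitution with std::snprintf and prints the snippet HandleDynamicGradAttributes would return for a "scale" forward op, assuming the attribute map in the generated grad function is named "attrs" (the real name is chosen by the generator).

// Standalone illustration (assumed names; not the Paddle implementation).
#include <cstdio>
#include <string>

// Mirrors GRAD_ATTRS_TEMPLATE: emit one `map["attr"] = value;` line.
static std::string EmitGradAttr(const std::string& attrs_name,
                                const std::string& attr,
                                const std::string& value) {
  char buf[256];
  std::snprintf(buf, sizeof(buf), "  %s[\"%s\"] = %s;\n",
                attrs_name.c_str(), attr.c_str(), value.c_str());
  return std::string(buf);
}

// Mirrors the control flow of HandleDynamicGradAttributes.
static std::string SketchDynamicGradAttributes(const std::string& fwd_op_type,
                                               const std::string& attrs_name) {
  std::string out;
  if (fwd_op_type == "sum") {
    out += EmitGradAttr(attrs_name, "scale", "float(1.0)");
    out += EmitGradAttr(attrs_name, "bias", "float(0.0f)");
    out += EmitGradAttr(attrs_name, "bias_after_scale", "bool(true)");
  } else if (fwd_op_type == "scale") {
    out += EmitGradAttr(attrs_name, "bias", "float(0.0f)");
    out += EmitGradAttr(attrs_name, "bias_after_scale", "bool(true)");
  }
  return out;
}

int main() {
  // Prints the generated-code snippet for a "scale" forward op:
  //   attrs["bias"] = float(0.0f);
  //   attrs["bias_after_scale"] = bool(true);
  std::printf("%s", SketchDynamicGradAttributes("scale", "attrs").c_str());
  return 0;
}

The overrides match the math of the backward pass: for y = scale * x + bias, dx = scale * dy, so the forward scale is reused from attr_map_ while the bias must be reset to zero so it is not re-applied when tracing the grad op.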
@@ -1866,18 +1872,9 @@ static std::string GenerateSingleOpBase(
   const char* ATTRS_TEMPLATE = "  auto& %s = this->attr_map_;\n";
   std::string grad_attrs_str =
       paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
-  for (const auto& iter : grad_attrs) {
-    if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue;
-    std::pair<std::string, std::string> type_val =
-        GetAttrType(iter.second, false /*is_arg*/);
-    const char* GRAD_ATTRS_TEMPLATE =
-        "  %s %s = %s;\n"
-        "  %s[\"%s\"] = %s;\n";
-    std::string var_name = iter.first + std::to_string(*outs_size);
-    grad_attrs_str += paddle::string::Sprintf(
-        GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second,
-        attrs_name, iter.first, var_name);
-  }
+
+  // Handle dynamic grad attributes
+  grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name);
   generated_grad_function_body += grad_attrs_str;
 
   const char* TRACE_OP_TEMPLATE =
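
To see how the two pieces fit together, the sketch below assembles the attribute block the generator now emits for a "scale" grad node. The variable name "attrs0" and the plain string concatenation are illustrative assumptions; the real generator uses paddle::string::Sprintf and its own naming scheme.

// Standalone illustration of the composed grad-attribute block (assumed names).
#include <cstdio>
#include <string>

int main() {
  const std::string attrs_name = "attrs0";

  // ATTRS_TEMPLATE: bind the attribute map recorded on the grad node.
  std::string grad_attrs_str =
      "  auto& " + attrs_name + " = this->attr_map_;\n";

  // HandleDynamicGradAttributes("scale", attrs_name): override the attributes
  // that must not be reused verbatim from the forward op.
  grad_attrs_str += "  " + attrs_name + "[\"bias\"] = float(0.0f);\n";
  grad_attrs_str += "  " + attrs_name + "[\"bias_after_scale\"] = bool(true);\n";

  // In the generator this string is appended to generated_grad_function_body
  // just before the Trace-Op code is emitted. Printed output:
  //   auto& attrs0 = this->attr_map_;
  //   attrs0["bias"] = float(0.0f);
  //   attrs0["bias_after_scale"] = bool(true);
  std::printf("%s", grad_attrs_str.c_str());
  return 0;
}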