未验证 提交 30417999 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Fixed issues with generated scale operator (#40482)

* Fixed issues with generated scale operator

* Fixed minor issues
上级 187fcfa3
......@@ -56,23 +56,29 @@ static std::string LegalizeVariableName(const std::string& var_name) {
return ret;
}
static bool IgnoreGradAttribute(const std::string& op_type,
const std::string& attr_name) {
// Attributes in operators_with_attrs are created manually during code
// generation
// We should ignore these arbitrary attrs when setting up grad attribute map
if (operators_with_attrs.count(op_type)) {
if (operators_with_attrs[op_type].count(attr_name)) {
return true;
}
}
static std::string HandleDynamicGradAttributes(const std::string& fwd_op_type,
const std::string& attrs_name) {
std::string additional_grad_attrs_str = "";
if (fwd_op_type == "sum") {
const char* GRAD_ATTRS_TEMPLATE = " %s[\"%s\"] = %s;\n";
additional_grad_attrs_str = paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, attrs_name, "scale", "float(1.0)");
additional_grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)");
additional_grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)");
// Only allow SumOp
if (op_type != "sum") {
return true;
} else if (fwd_op_type == "scale") {
const char* GRAD_ATTRS_TEMPLATE = " %s[\"%s\"] = %s;\n";
additional_grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, attrs_name, "bias", "float(0.0f)");
additional_grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, attrs_name, "bias_after_scale", "bool(true)");
}
return false;
return additional_grad_attrs_str;
}
static void PrepareAttrMapForOps() {
......@@ -1866,18 +1872,9 @@ static std::string GenerateSingleOpBase(
const char* ATTRS_TEMPLATE = " auto& %s = this->attr_map_;\n";
std::string grad_attrs_str =
paddle::string::Sprintf(ATTRS_TEMPLATE, attrs_name);
for (const auto& iter : grad_attrs) {
if (IgnoreGradAttribute(fwd_op_type, iter.first)) continue;
std::pair<std::string, std::string> type_val =
GetAttrType(iter.second, false /*is_arg*/);
const char* GRAD_ATTRS_TEMPLATE =
" %s %s = %s;\n"
" %s[\"%s\"] = %s;\n";
std::string var_name = iter.first + std::to_string(*outs_size);
grad_attrs_str += paddle::string::Sprintf(
GRAD_ATTRS_TEMPLATE, type_val.first, var_name, type_val.second,
attrs_name, iter.first, var_name);
}
// Handle dynamic grad attributes
grad_attrs_str += HandleDynamicGradAttributes(fwd_op_type, attrs_name);
generated_grad_function_body += grad_attrs_str;
const char* TRACE_OP_TEMPLATE =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册