未验证 提交 832e58d6 编写于 作者: J Jiabin Yang 提交者: GitHub

Fix stray error (#42509)

* fix @ stray error in dygraph

* fix @ stray error in dygraph
上级 06927016
...@@ -36,6 +36,11 @@ ...@@ -36,6 +36,11 @@
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// clang-format off // clang-format off
const char* OUT_INITIALIZER_TEMPLATE = const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"; R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
...@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody( ...@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody(
continue; continue;
} }
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg = auto input_arg = paddle::string::Sprintf(
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); ARG_TEMPLATE, in_type, TempName(LegalizeVarName(in_name)));
input_args += input_arg; input_args += input_arg;
input_args += ","; input_args += ",";
input_args_num++; input_args_num++;
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type, ins_cast_str +=
in_name, arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), op_type,
in_name, arg_idx++, dispensable);
call_api_str += in_name + ", "; call_api_str += LegalizeVarName(in_name) + ", ";
} }
if (!input_args.empty() && input_args.back() == ',') { if (!input_args.empty() && input_args.back() == ',') {
...@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody( ...@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody(
input_args += ","; input_args += ",";
} }
input_args += out_type; input_args += out_type;
input_args += out_name; input_args += LegalizeVarName(out_name);
input_args_num++; input_args_num++;
if (output.dispensable()) { if (output.dispensable()) {
...@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody( ...@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable() const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE ? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE; : INPUT_INITIALIZER_TEMPLATE;
outs_initializer += outs_initializer += paddle::string::Sprintf(out_template, out_name,
paddle::string::Sprintf(out_template, out_name, out_name); LegalizeVarName(out_name));
outs_initializer += ","; outs_initializer += ",";
} }
const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE
: CAST_VAR_PTR_TEMPLATE; : CAST_VAR_PTR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false"; auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type, ins_cast_str +=
out_name, arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
op_type, out_name, arg_idx++, dispensable);
call_api_str += out_name + ", "; call_api_str += LegalizeVarName(out_name) + ", ";
} else { } else {
// There are few Operators that have duplicable output, like `Out` in // There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the // split op. We need to specify the number of variables for the
...@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody( ...@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") { if (input_args != "") {
input_args += ","; input_args += ",";
} }
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE; input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str; input_args += out_num_str;
input_args_num++; input_args_num++;
......
...@@ -35,6 +35,12 @@ ...@@ -35,6 +35,12 @@
// phi // phi
#include "paddle/phi/kernels/declarations.h" #include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// NOTE(pangyoki): Inplace OP with duplicable input. // NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input. // The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy // The first Varbase in input needs to be specified for the inplace strategy
...@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody( ...@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody(
continue; continue;
} }
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE; const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg = auto input_arg = paddle::string::Sprintf(
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name)); ARG_TEMPLATE, in_type, LegalizeVarName(TempName(in_name)));
input_args += input_arg; input_args += input_arg;
input_args += ","; input_args += ",";
input_args_num++; input_args_num++;
const auto in_cast_type = const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false"; auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name, ins_cast_str +=
arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), in_name,
arg_idx++, dispensable);
if (input.dispensable()) { if (input.dispensable()) {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST ? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
: INPUT_INITIALIZER_TEMPLATE_WITH_NULL; : INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
ins_initializer_with_null += ins_initializer_with_null +=
paddle::string::Sprintf(in_template, in_name, in_name, in_name); paddle::string::Sprintf(in_template, LegalizeVarName(in_name),
in_name, LegalizeVarName(in_name));
} else { } else {
const auto in_template = input.duplicable() const auto in_template = input.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE ? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE; : INPUT_INITIALIZER_TEMPLATE;
ins_initializer += paddle::string::Sprintf(in_template, in_name, in_name); ins_initializer += paddle::string::Sprintf(in_template, in_name,
LegalizeVarName(in_name));
ins_initializer += ","; ins_initializer += ",";
} }
} }
...@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody( ...@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody(
input_args += ","; input_args += ",";
} }
input_args += out_type; input_args += out_type;
input_args += out_name; input_args += LegalizeVarName(out_name);
input_args_num++; input_args_num++;
if (output.dispensable()) { if (output.dispensable()) {
...@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody( ...@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable() const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE ? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE; : INPUT_INITIALIZER_TEMPLATE;
outs_initializer += outs_initializer += paddle::string::Sprintf(out_template, out_name,
paddle::string::Sprintf(out_template, out_name, out_name); LegalizeVarName(out_name));
outs_initializer += ","; outs_initializer += ",";
} }
const auto in_cast_type = const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE; output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false"; auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name, ins_cast_str +=
arg_idx++, dispensable); paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) { } else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
inplace_map[out_name], "", inplace_map[out_name], "",
...@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody( ...@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody(
// Leaf Var that doesn't stop gradient can't use inplace strategy. // Leaf Var that doesn't stop gradient can't use inplace strategy.
// Increase inplace_version. // Increase inplace_version.
inplace_strategy_str += paddle::string::Sprintf( inplace_strategy_str += paddle::string::Sprintf(
INPLACE_STRATEGY_TEMPLATE, inplace_input_name, inplace_input_name, INPLACE_STRATEGY_TEMPLATE, LegalizeVarName(inplace_input_name),
INPLACE_LEAF_ERROR_MESSAGE, inplace_input_name, inplace_input_name, LegalizeVarName(inplace_input_name), INPLACE_LEAF_ERROR_MESSAGE,
inplace_input_name); LegalizeVarName(inplace_input_name),
outs_initializer += LegalizeVarName(inplace_input_name),
paddle::string::Sprintf(out_template, out_name, inplace_input_name); LegalizeVarName(inplace_input_name));
outs_initializer += paddle::string::Sprintf(
out_template, out_name, LegalizeVarName(inplace_input_name));
outs_initializer += ","; outs_initializer += ",";
} else { } else {
// There are few Operators that have duplicable output, like `Out` in // There are few Operators that have duplicable output, like `Out` in
...@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody( ...@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") { if (input_args != "") {
input_args += ","; input_args += ",";
} }
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name); auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE; input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str; input_args += out_num_str;
input_args_num++; input_args_num++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册