未验证 提交 832e58d6 编写于 作者: J Jiabin Yang 提交者: GitHub

Fix stray error (#42509)

* fix @ stray error in dygraph

* fix @ stray error in dygraph
上级 06927016
......@@ -36,6 +36,11 @@
// phi
#include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// clang-format off
const char* OUT_INITIALIZER_TEMPLATE =
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
......@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody(
continue;
}
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg =
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
auto input_arg = paddle::string::Sprintf(
ARG_TEMPLATE, in_type, TempName(LegalizeVarName(in_name)));
input_args += input_arg;
input_args += ",";
input_args_num++;
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, op_type,
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), op_type,
in_name, arg_idx++, dispensable);
call_api_str += in_name + ", ";
call_api_str += LegalizeVarName(in_name) + ", ";
}
if (!input_args.empty() && input_args.back() == ',') {
......@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody(
input_args += ",";
}
input_args += out_type;
input_args += out_name;
input_args += LegalizeVarName(out_name);
input_args_num++;
if (output.dispensable()) {
......@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += paddle::string::Sprintf(out_template, out_name,
LegalizeVarName(out_name));
outs_initializer += ",";
}
const auto in_cast_type = output.duplicable() ? CAST_VAR_PTR_LIST_TEMPLATE
: CAST_VAR_PTR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, op_type,
out_name, arg_idx++, dispensable);
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
op_type, out_name, arg_idx++, dispensable);
call_api_str += out_name + ", ";
call_api_str += LegalizeVarName(out_name) + ", ";
} else {
// There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the
......@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") {
input_args += ",";
}
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str;
input_args_num++;
......
......@@ -35,6 +35,12 @@
// phi
#include "paddle/phi/kernels/declarations.h"
static std::string LegalizeVarName(const std::string& var_name) {
std::string ret = var_name;
std::replace(ret.begin(), ret.end(), '@', '_'); // replace all '-' to '_'
return ret;
}
// NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy
......@@ -201,15 +207,16 @@ std::string GenerateOpFunctionsBody(
continue;
}
const auto in_type = input.duplicable() ? IN_VAR_LIST_TYPE : IN_VAR_TYPE;
auto input_arg =
paddle::string::Sprintf(ARG_TEMPLATE, in_type, TempName(in_name));
auto input_arg = paddle::string::Sprintf(
ARG_TEMPLATE, in_type, LegalizeVarName(TempName(in_name)));
input_args += input_arg;
input_args += ",";
input_args_num++;
const auto in_cast_type =
input.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = input.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, in_name, in_name,
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(in_name), in_name,
arg_idx++, dispensable);
if (input.dispensable()) {
......@@ -217,12 +224,14 @@ std::string GenerateOpFunctionsBody(
? INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
: INPUT_INITIALIZER_TEMPLATE_WITH_NULL;
ins_initializer_with_null +=
paddle::string::Sprintf(in_template, in_name, in_name, in_name);
paddle::string::Sprintf(in_template, LegalizeVarName(in_name),
in_name, LegalizeVarName(in_name));
} else {
const auto in_template = input.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
ins_initializer += paddle::string::Sprintf(in_template, in_name, in_name);
ins_initializer += paddle::string::Sprintf(in_template, in_name,
LegalizeVarName(in_name));
ins_initializer += ",";
}
}
......@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody(
input_args += ",";
}
input_args += out_type;
input_args += out_name;
input_args += LegalizeVarName(out_name);
input_args_num++;
if (output.dispensable()) {
......@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody(
const auto out_template = output.duplicable()
? INPUT_LIST_INITIALIZER_TEMPLATE
: INPUT_INITIALIZER_TEMPLATE;
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, out_name);
outs_initializer += paddle::string::Sprintf(out_template, out_name,
LegalizeVarName(out_name));
outs_initializer += ",";
}
const auto in_cast_type =
output.duplicable() ? CAST_VAR_LIST_TEMPLATE : CAST_VAR_TEMPLATE;
auto dispensable = output.dispensable() ? "true" : "false";
ins_cast_str += paddle::string::Sprintf(in_cast_type, out_name, out_name,
arg_idx++, dispensable);
ins_cast_str +=
paddle::string::Sprintf(in_cast_type, LegalizeVarName(out_name),
out_name, arg_idx++, dispensable);
} else if (use_inplace_strategy && inplace_map.count(out_name)) {
PADDLE_ENFORCE_NE(
inplace_map[out_name], "",
......@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody(
// Leaf Var that doesn't stop gradient can't use inplace strategy.
// Increase inplace_version.
inplace_strategy_str += paddle::string::Sprintf(
INPLACE_STRATEGY_TEMPLATE, inplace_input_name, inplace_input_name,
INPLACE_LEAF_ERROR_MESSAGE, inplace_input_name, inplace_input_name,
inplace_input_name);
outs_initializer +=
paddle::string::Sprintf(out_template, out_name, inplace_input_name);
INPLACE_STRATEGY_TEMPLATE, LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name), INPLACE_LEAF_ERROR_MESSAGE,
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name),
LegalizeVarName(inplace_input_name));
outs_initializer += paddle::string::Sprintf(
out_template, out_name, LegalizeVarName(inplace_input_name));
outs_initializer += ",";
} else {
// There are few Operators that have duplicable output, like `Out` in
......@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody(
if (input_args != "") {
input_args += ",";
}
auto out_num_str = paddle::string::Sprintf(ARG_OUT_NUM, out_name);
auto out_num_str =
paddle::string::Sprintf(ARG_OUT_NUM, LegalizeVarName(out_name));
input_args += ARG_OUT_NUM_TYPE;
input_args += out_num_str;
input_args_num++;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册