Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
832e58d6
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
832e58d6
编写于
5月 06, 2022
作者:
J
Jiabin Yang
提交者:
GitHub
5月 06, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix stray error (#42509)
* fix @ stray error in dygraph * fix @ stray error in dygraph
上级
06927016
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
147 addition
and
91 deletion
+147
-91
paddle/fluid/eager/auto_code_generator/eager_generator.cc
paddle/fluid/eager/auto_code_generator/eager_generator.cc
+97
-62
paddle/fluid/pybind/eager_op_function_generator.cc
paddle/fluid/pybind/eager_op_function_generator.cc
+20
-12
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+30
-17
未找到文件。
paddle/fluid/eager/auto_code_generator/eager_generator.cc
浏览文件 @
832e58d6
...
@@ -56,6 +56,13 @@ static std::unordered_set<std::string> black_ops_list = {"run_program"};
...
@@ -56,6 +56,13 @@ static std::unordered_set<std::string> black_ops_list = {"run_program"};
static
std
::
string
LegalizeVariableName
(
const
std
::
string
&
var_name
)
{
static
std
::
string
LegalizeVariableName
(
const
std
::
string
&
var_name
)
{
std
::
string
ret
=
var_name
;
std
::
string
ret
=
var_name
;
std
::
replace
(
ret
.
begin
(),
ret
.
end
(),
'-'
,
'_'
);
// replace all '-' to '_'
std
::
replace
(
ret
.
begin
(),
ret
.
end
(),
'-'
,
'_'
);
// replace all '-' to '_'
std
::
replace
(
ret
.
begin
(),
ret
.
end
(),
'@'
,
'_'
);
// replace all '-' to '_'
return
ret
;
}
static
std
::
string
LegalizeVarName
(
const
std
::
string
&
var_name
)
{
std
::
string
ret
=
var_name
;
std
::
replace
(
ret
.
begin
(),
ret
.
end
(),
'@'
,
'_'
);
// replace all '-' to '_'
return
ret
;
return
ret
;
}
}
...
@@ -1024,7 +1031,8 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1024,7 +1031,8 @@ static std::string GenerateGradNodeCreationContent(
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
for
(
const
proto
::
OpProto
::
Var
&
output
:
out_vars
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
out_vars
)
{
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
output_name
;
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
LegalizeVarName
(
output_name
);
// output autograd_meta should be got after running TraceOP.
// output autograd_meta should be got after running TraceOP.
if
(
output
.
duplicable
())
{
if
(
output
.
duplicable
())
{
...
@@ -1032,12 +1040,13 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1032,12 +1040,13 @@ static std::string GenerateGradNodeCreationContent(
" std::vector<egr::AutogradMeta*> %s = "
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::autograd_meta(&%s);
\n
"
;
"egr::EagerUtils::autograd_meta(&%s);
\n
"
;
get_output_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
get_output_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
GET_MULTI_AUTOGRAD_META_TEMPLATE
,
output_autograd_name
,
output_name
);
GET_MULTI_AUTOGRAD_META_TEMPLATE
,
output_autograd_name
,
LegalizeVarName
(
output_name
));
}
else
{
}
else
{
// In inplace op, the case where output is duplicable is not considered.
// In inplace op, the case where output is duplicable is not considered.
// Replace output directly with input in inplace op.
// Replace output directly with input in inplace op.
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
output_name
))
{
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
output_name
))
{
auto
inplace_input_name
=
inplace_map
[
output_name
]
;
auto
inplace_input_name
=
LegalizeVarName
(
inplace_map
[
output_name
])
;
const
std
::
string
&
inplace_input_autograd_name
=
const
std
::
string
&
inplace_input_autograd_name
=
"p_autograd_"
+
inplace_input_name
;
"p_autograd_"
+
inplace_input_name
;
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
...
@@ -1049,9 +1058,9 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1049,9 +1058,9 @@ static std::string GenerateGradNodeCreationContent(
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
" egr::AutogradMeta* %s = "
" egr::AutogradMeta* %s = "
"egr::EagerUtils::autograd_meta(&%s);
\n
"
;
"egr::EagerUtils::autograd_meta(&%s);
\n
"
;
get_output_autograd_meta_str
+=
get_output_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
GET_SINGLE_AUTOGRAD_META_TEMPLATE
,
GET_SINGLE_AUTOGRAD_META_TEMPLATE
,
output_autograd_name
,
output_autograd_name
,
output_name
);
LegalizeVarName
(
output_name
)
);
}
}
}
}
}
}
...
@@ -1061,28 +1070,32 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1061,28 +1070,32 @@ static std::string GenerateGradNodeCreationContent(
// inplace).
// inplace).
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
input_name
;
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
LegalizeVarName
(
input_name
);
if
(
input
.
duplicable
())
{
if
(
input
.
duplicable
())
{
const
char
*
GET_MULTI_AUTOGRAD_META_TEMPLATE
=
const
char
*
GET_MULTI_AUTOGRAD_META_TEMPLATE
=
" std::vector<egr::AutogradMeta*> %s = "
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);
\n
"
;
"egr::EagerUtils::nullable_autograd_meta(%s);
\n
"
;
get_input_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
get_input_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
GET_MULTI_AUTOGRAD_META_TEMPLATE
,
input_autograd_name
,
input_name
);
GET_MULTI_AUTOGRAD_META_TEMPLATE
,
input_autograd_name
,
LegalizeVarName
(
input_name
));
}
else
if
(
input
.
dispensable
())
{
}
else
if
(
input
.
dispensable
())
{
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
" egr::AutogradMeta* %s = "
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);
\n
"
;
"egr::EagerUtils::nullable_autograd_meta(%s);
\n
"
;
get_input_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
get_input_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
GET_SINGLE_AUTOGRAD_META_TEMPLATE
,
input_autograd_name
,
input_name
);
GET_SINGLE_AUTOGRAD_META_TEMPLATE
,
input_autograd_name
,
LegalizeVarName
(
input_name
));
}
else
{
}
else
{
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
const
char
*
GET_SINGLE_AUTOGRAD_META_TEMPLATE
=
" egr::AutogradMeta* %s = "
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);
\n
"
;
"egr::EagerUtils::nullable_autograd_meta(%s);
\n
"
;
get_input_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
get_input_autograd_meta_str
+=
paddle
::
string
::
Sprintf
(
GET_SINGLE_AUTOGRAD_META_TEMPLATE
,
input_autograd_name
,
input_name
);
GET_SINGLE_AUTOGRAD_META_TEMPLATE
,
input_autograd_name
,
LegalizeVarName
(
input_name
));
}
}
}
}
VLOG
(
6
)
<<
"Generated inputs autograd_meta"
;
VLOG
(
6
)
<<
"Generated inputs autograd_meta"
;
...
@@ -1096,7 +1109,7 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1096,7 +1109,7 @@ static std::string GenerateGradNodeCreationContent(
" egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
" egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
"require_any_grad);
\n
"
;
"require_any_grad);
\n
"
;
for
(
auto
&
inplace_pair
:
inplace_map
)
{
for
(
auto
&
inplace_pair
:
inplace_map
)
{
std
::
string
inplace_name
=
inplace_pair
.
second
;
std
::
string
inplace_name
=
LegalizeVarName
(
inplace_pair
.
second
)
;
check_inplace_str
+=
paddle
::
string
::
Sprintf
(
CHECKING_INPLACE_TEMPLATE
,
check_inplace_str
+=
paddle
::
string
::
Sprintf
(
CHECKING_INPLACE_TEMPLATE
,
inplace_name
,
inplace_name
);
inplace_name
,
inplace_name
);
}
}
...
@@ -1159,12 +1172,12 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1159,12 +1172,12 @@ static std::string GenerateGradNodeCreationContent(
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
tensor_wrapper_name
))
{
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
tensor_wrapper_name
))
{
auto
inplace_input_name
=
inplace_map
[
tensor_wrapper_name
];
auto
inplace_input_name
=
inplace_map
[
tensor_wrapper_name
];
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
SET_TENSOR_WRAPPER_TEMPLATE
,
tensor_wrapper_name
,
SET_TENSOR_WRAPPER_TEMPLATE
,
LegalizeVarName
(
tensor_wrapper_name
)
,
inplace_input_name
,
full_reserved
);
LegalizeVarName
(
inplace_input_name
)
,
full_reserved
);
}
else
{
}
else
{
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
SET_TENSOR_WRAPPER_TEMPLATE
,
tensor_wrapper_name
,
SET_TENSOR_WRAPPER_TEMPLATE
,
LegalizeVarName
(
tensor_wrapper_name
)
,
tensor_wrapper_name
,
full_reserved
);
LegalizeVarName
(
tensor_wrapper_name
)
,
full_reserved
);
}
}
}
}
}
}
...
@@ -1176,7 +1189,8 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1176,7 +1189,8 @@ static std::string GenerateGradNodeCreationContent(
std
::
string
compute_require_grad_args
=
"trace_backward"
;
std
::
string
compute_require_grad_args
=
"trace_backward"
;
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
input_name
;
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
LegalizeVarName
(
input_name
);
if
(
!
input
.
duplicable
())
{
if
(
!
input
.
duplicable
())
{
compute_require_grad_args
+=
", "
+
input_autograd_name
;
compute_require_grad_args
+=
", "
+
input_autograd_name
;
...
@@ -1184,8 +1198,9 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1184,8 +1198,9 @@ static std::string GenerateGradNodeCreationContent(
const
char
*
SET_GRAD_OUT_META_TEMPLATE
=
const
char
*
SET_GRAD_OUT_META_TEMPLATE
=
" grad_node->SetGradOutMeta(%s, %d);
\n
"
;
" grad_node->SetGradOutMeta(%s, %d);
\n
"
;
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
SET_GRAD_OUT_META_TEMPLATE
,
input_name
,
input_position
);
paddle
::
string
::
Sprintf
(
SET_GRAD_OUT_META_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_position
);
}
else
{
}
else
{
compute_require_grad_args
+=
", &"
+
input_autograd_name
;
compute_require_grad_args
+=
", &"
+
input_autograd_name
;
...
@@ -1193,8 +1208,9 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1193,8 +1208,9 @@ static std::string GenerateGradNodeCreationContent(
const
char
*
SET_GRAD_OUT_META_TEMPLATE
=
const
char
*
SET_GRAD_OUT_META_TEMPLATE
=
" grad_node->SetGradOutMeta(%s, %d);
\n
"
;
" grad_node->SetGradOutMeta(%s, %d);
\n
"
;
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
SET_GRAD_OUT_META_TEMPLATE
,
input_name
,
input_position
);
paddle
::
string
::
Sprintf
(
SET_GRAD_OUT_META_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_position
);
}
}
}
}
...
@@ -1208,7 +1224,7 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1208,7 +1224,7 @@ static std::string GenerateGradNodeCreationContent(
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
output_name
))
{
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
output_name
))
{
auto
inplace_input_name
=
inplace_map
[
output_name
];
auto
inplace_input_name
=
inplace_map
[
output_name
];
const
std
::
string
&
inplace_input_autograd_name
=
const
std
::
string
&
inplace_input_autograd_name
=
"p_autograd_"
+
inplace_input_name
;
"p_autograd_"
+
LegalizeVarName
(
inplace_input_name
)
;
size_t
output_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
size_t
output_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
// Intermediate Tensor does not require SetHistory, nor RetainGrad
// Intermediate Tensor does not require SetHistory, nor RetainGrad
...
@@ -1228,18 +1244,20 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1228,18 +1244,20 @@ static std::string GenerateGradNodeCreationContent(
const
char
*
SET_GRAD_IN_META_TEMPLATE
=
const
char
*
SET_GRAD_IN_META_TEMPLATE
=
" grad_node->SetGradInMeta(%s, %d);
\n
"
;
" grad_node->SetGradInMeta(%s, %d);
\n
"
;
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
SET_GRAD_IN_META_TEMPLATE
,
inplace_input_name
,
output_position
);
SET_GRAD_IN_META_TEMPLATE
,
LegalizeVarName
(
inplace_input_name
),
output_position
);
// Intermediate Tensor does not require CheckAndRetainGrad
// Intermediate Tensor does not require CheckAndRetainGrad
if
(
!
output
.
intermediate
())
{
if
(
!
output
.
intermediate
())
{
VLOG
(
6
)
<<
"Generated Call RetainGradForTensor"
;
VLOG
(
6
)
<<
"Generated Call RetainGradForTensor"
;
const
char
*
RETAIN_GRAD_TEMPLATE
=
const
char
*
RETAIN_GRAD_TEMPLATE
=
" egr::EagerUtils::CheckAndRetainGrad(%s);
\n
"
;
" egr::EagerUtils::CheckAndRetainGrad(%s);
\n
"
;
grad_node_creation_str
+=
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
RETAIN_GRAD_TEMPLATE
,
inplace_input_name
);
RETAIN_GRAD_TEMPLATE
,
LegalizeVarName
(
inplace_input_name
)
);
}
}
}
else
{
}
else
{
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
output_name
;
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
LegalizeVarName
(
output_name
);
size_t
output_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
size_t
output_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
// Intermediate Tensor does not require SetHistory, nor RetainGrad
// Intermediate Tensor does not require SetHistory, nor RetainGrad
...
@@ -1261,7 +1279,8 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1261,7 +1279,8 @@ static std::string GenerateGradNodeCreationContent(
const
char
*
SET_GRAD_IN_META_TEMPLATE
=
const
char
*
SET_GRAD_IN_META_TEMPLATE
=
" grad_node->SetGradInMeta(%s, %d);
\n
"
;
" grad_node->SetGradInMeta(%s, %d);
\n
"
;
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
SET_GRAD_IN_META_TEMPLATE
,
output_name
,
output_position
);
SET_GRAD_IN_META_TEMPLATE
,
LegalizeVarName
(
output_name
),
output_position
);
}
else
{
}
else
{
pass_stop_gradient_args
+=
", "
+
output_autograd_name
;
pass_stop_gradient_args
+=
", "
+
output_autograd_name
;
...
@@ -1280,7 +1299,8 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1280,7 +1299,8 @@ static std::string GenerateGradNodeCreationContent(
const
char
*
SET_GRAD_IN_META_TEMPLATE
=
const
char
*
SET_GRAD_IN_META_TEMPLATE
=
" grad_node->SetGradInMeta(%s, %d);
\n
"
;
" grad_node->SetGradInMeta(%s, %d);
\n
"
;
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
SET_GRAD_IN_META_TEMPLATE
,
output_name
,
output_position
);
SET_GRAD_IN_META_TEMPLATE
,
LegalizeVarName
(
output_name
),
output_position
);
}
}
// Intermediate Tensor does not require CheckAndRetainGrad
// Intermediate Tensor does not require CheckAndRetainGrad
...
@@ -1288,8 +1308,8 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -1288,8 +1308,8 @@ static std::string GenerateGradNodeCreationContent(
VLOG
(
6
)
<<
"Generated Call RetainGradForTensor"
;
VLOG
(
6
)
<<
"Generated Call RetainGradForTensor"
;
const
char
*
RETAIN_GRAD_TEMPLATE
=
const
char
*
RETAIN_GRAD_TEMPLATE
=
" egr::EagerUtils::CheckAndRetainGrad(%s);
\n
"
;
" egr::EagerUtils::CheckAndRetainGrad(%s);
\n
"
;
grad_node_creation_str
+=
grad_node_creation_str
+=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
RETAIN_GRAD_TEMPLATE
,
output_name
);
RETAIN_GRAD_TEMPLATE
,
LegalizeVarName
(
output_name
)
);
}
}
}
}
}
}
...
@@ -1412,9 +1432,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1412,9 +1432,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if
(
input
.
duplicable
())
{
if
(
input
.
duplicable
())
{
const
char
*
FWD_INS_ARG_TEMPLATE
=
const
char
*
FWD_INS_ARG_TEMPLATE
=
"const std::vector<paddle::experimental::Tensor>& %s"
;
"const std::vector<paddle::experimental::Tensor>& %s"
;
input_args_str_list
[
input_position
]
=
input_args_str_list
[
input_position
]
=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
FWD_INS_ARG_TEMPLATE
,
input_name
);
FWD_INS_ARG_TEMPLATE
,
LegalizeVarName
(
input_name
));
amp_function_call_args_str_list
[
input_position
]
=
" NEW_"
+
input_name
;
amp_function_call_args_str_list
[
input_position
]
=
" NEW_"
+
LegalizeVarName
(
input_name
);
core_ops_args_type_info
[
op_type
][
input_position
]
=
"list"
;
core_ops_args_type_info
[
op_type
][
input_position
]
=
"list"
;
}
else
{
}
else
{
...
@@ -1433,9 +1454,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1433,9 +1454,10 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if
(
!
flag_find_input_name
)
{
if
(
!
flag_find_input_name
)
{
FWD_INS_ARG_TEMPLATE
=
"const paddle::experimental::Tensor& %s"
;
FWD_INS_ARG_TEMPLATE
=
"const paddle::experimental::Tensor& %s"
;
}
}
input_args_str_list
[
input_position
]
=
input_args_str_list
[
input_position
]
=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
FWD_INS_ARG_TEMPLATE
,
input_name
);
FWD_INS_ARG_TEMPLATE
,
LegalizeVarName
(
input_name
));
amp_function_call_args_str_list
[
input_position
]
=
" NEW_"
+
input_name
;
amp_function_call_args_str_list
[
input_position
]
=
" NEW_"
+
LegalizeVarName
(
input_name
);
core_ops_args_type_info
[
op_type
][
input_position
]
=
"tensor"
;
core_ops_args_type_info
[
op_type
][
input_position
]
=
"tensor"
;
}
}
...
@@ -1445,8 +1467,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1445,8 +1467,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const
char
*
FWD_INS_CONTENT_TEMPLATE
=
const
char
*
FWD_INS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, egr::EagerUtils::TrySyncToVars(%s) },"
;
"{
\"
%s
\"
, egr::EagerUtils::TrySyncToVars(%s) },"
;
ins_contents_str
+=
paddle
::
string
::
Sprintf
(
FWD_INS_CONTENT_TEMPLATE
,
ins_contents_str
+=
paddle
::
string
::
Sprintf
(
input_name
,
input_name
);
FWD_INS_CONTENT_TEMPLATE
,
input_name
,
LegalizeVarName
(
input_name
)
);
if
(
input
.
duplicable
())
{
if
(
input
.
duplicable
())
{
const
char
*
AMP_TENSORS_VECTOR_TEMPLATE
=
"%s,"
;
const
char
*
AMP_TENSORS_VECTOR_TEMPLATE
=
"%s,"
;
amp_tensors_vector_str
+=
amp_tensors_vector_str
+=
...
@@ -1455,16 +1477,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1455,16 +1477,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" auto NEW_%s = egr::AmpAutoCasts(
\"
%s
\"
, %s, amp_dst_dtype, "
" auto NEW_%s = egr::AmpAutoCasts(
\"
%s
\"
, %s, amp_dst_dtype, "
"
\"
%s
\"
);
\n
"
;
"
\"
%s
\"
);
\n
"
;
amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
AMP_AUTO_CAST_TEMPLATE
,
input_name
,
input_name
,
input_name
,
op_type
);
AMP_AUTO_CAST_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_name
,
LegalizeVarName
(
input_name
),
op_type
);
}
else
{
}
else
{
const
char
*
AMP_TENSORS_VECTOR_TEMPLATE
=
"{%s},"
;
const
char
*
AMP_TENSORS_VECTOR_TEMPLATE
=
"{%s},"
;
amp_tensors_vector_str
+=
amp_tensors_vector_str
+=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
AMP_TENSORS_VECTOR_TEMPLATE
,
input_name
);
AMP_TENSORS_VECTOR_TEMPLATE
,
LegalizeVarName
(
input_name
)
);
const
char
*
AMP_AUTO_CAST_TEMPLATE
=
const
char
*
AMP_AUTO_CAST_TEMPLATE
=
" auto NEW_%s = egr::AmpAutoCast(
\"
%s
\"
, %s, amp_dst_dtype, "
" auto NEW_%s = egr::AmpAutoCast(
\"
%s
\"
, %s, amp_dst_dtype, "
"
\"
%s
\"
);
\n
"
;
"
\"
%s
\"
);
\n
"
;
amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
AMP_AUTO_CAST_TEMPLATE
,
input_name
,
input_name
,
input_name
,
op_type
);
AMP_AUTO_CAST_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_name
,
LegalizeVarName
(
input_name
),
op_type
);
}
}
}
}
if
(
ins_contents_str
.
size
()
>
0
)
if
(
ins_contents_str
.
size
()
>
0
)
...
@@ -1500,35 +1524,41 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1500,35 +1524,41 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" if(%s.size() > 0) "
" if(%s.size() > 0) "
"ins[
\"
%s
\"
] = egr::EagerUtils::TrySyncToVars(%s);
\n
"
;
"ins[
\"
%s
\"
] = egr::EagerUtils::TrySyncToVars(%s);
\n
"
;
dispensable_ins_contents_str
+=
paddle
::
string
::
Sprintf
(
dispensable_ins_contents_str
+=
paddle
::
string
::
Sprintf
(
FWD_INS_CONTENT_TEMPLATE
,
input_name
,
input_name
,
input_name
);
FWD_INS_CONTENT_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_name
,
LegalizeVarName
(
input_name
));
const
char
*
FWD_AMP_TENSORS_VECTOR_TEMPLATE
=
const
char
*
FWD_AMP_TENSORS_VECTOR_TEMPLATE
=
" if(%s.size() > 0) "
" if(%s.size() > 0) "
"amp_tensors_vector.push_back(%s);
\n
"
;
"amp_tensors_vector.push_back(%s);
\n
"
;
dispensable_amp_tensors_vector_str
+=
paddle
::
string
::
Sprintf
(
dispensable_amp_tensors_vector_str
+=
paddle
::
string
::
Sprintf
(
FWD_AMP_TENSORS_VECTOR_TEMPLATE
,
input_name
,
input_name
);
FWD_AMP_TENSORS_VECTOR_TEMPLATE
,
LegalizeVarName
(
input_name
),
LegalizeVarName
(
input_name
));
const
char
*
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
=
const
char
*
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
=
" auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(
\"
%s
\"
, "
" auto NEW_%s = ((%s.size() > 0) ? egr::AmpAutoCasts(
\"
%s
\"
, "
"%s, amp_dst_dtype,
\"
%s
\"
) : %s);
\n
"
;
"%s, amp_dst_dtype,
\"
%s
\"
) : %s);
\n
"
;
dispensable_amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
dispensable_amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
,
input_name
,
input_name
,
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_name
,
input_name
,
op_type
,
input_name
);
LegalizeVarName
(
input_name
),
input_name
,
LegalizeVarName
(
input_name
),
op_type
,
LegalizeVarName
(
input_name
));
}
else
{
}
else
{
const
char
*
FWD_INS_CONTENT_TEMPLATE
=
const
char
*
FWD_INS_CONTENT_TEMPLATE
=
" if(%s.initialized()) "
" if(%s.initialized()) "
"ins[
\"
%s
\"
] = egr::EagerUtils::TrySyncToVars(%s);
\n
"
;
"ins[
\"
%s
\"
] = egr::EagerUtils::TrySyncToVars(%s);
\n
"
;
dispensable_ins_contents_str
+=
paddle
::
string
::
Sprintf
(
dispensable_ins_contents_str
+=
paddle
::
string
::
Sprintf
(
FWD_INS_CONTENT_TEMPLATE
,
input_name
,
input_name
,
input_name
);
FWD_INS_CONTENT_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_name
,
LegalizeVarName
(
input_name
));
const
char
*
FWD_AMP_TENSORS_VECTOR_TEMPLATE
=
const
char
*
FWD_AMP_TENSORS_VECTOR_TEMPLATE
=
" if(%s.initialized()) "
" if(%s.initialized()) "
"amp_tensors_vector.push_back({ %s });
\n
"
;
"amp_tensors_vector.push_back({ %s });
\n
"
;
dispensable_amp_tensors_vector_str
+=
paddle
::
string
::
Sprintf
(
dispensable_amp_tensors_vector_str
+=
paddle
::
string
::
Sprintf
(
FWD_AMP_TENSORS_VECTOR_TEMPLATE
,
input_name
,
input_name
);
FWD_AMP_TENSORS_VECTOR_TEMPLATE
,
LegalizeVarName
(
input_name
),
LegalizeVarName
(
input_name
));
const
char
*
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
=
const
char
*
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
=
" auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(
\"
%s
\"
, "
" auto NEW_%s = ((%s.initialized()) ? egr::AmpAutoCast(
\"
%s
\"
, "
"%s, amp_dst_dtype,
\"
%s
\"
) : %s);
\n
"
;
"%s, amp_dst_dtype,
\"
%s
\"
) : %s);
\n
"
;
dispensable_amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
dispensable_amp_auto_cast_str
+=
paddle
::
string
::
Sprintf
(
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
,
input_name
,
input_name
,
DISPENSABLE_AMP_AUTO_CAST_TEMPLATE
,
LegalizeVarName
(
input_name
),
input_name
,
input_name
,
op_type
,
input_name
);
LegalizeVarName
(
input_name
),
input_name
,
LegalizeVarName
(
input_name
),
op_type
,
LegalizeVarName
(
input_name
));
}
}
}
}
}
}
...
@@ -1550,18 +1580,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1550,18 +1580,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if
(
output
.
duplicable
())
{
if
(
output
.
duplicable
())
{
const
char
*
FWD_NUM_ARG_TEMPLATE
=
const
char
*
FWD_NUM_ARG_TEMPLATE
=
", std::vector<paddle::experimental::Tensor*>& %s"
;
", std::vector<paddle::experimental::Tensor*>& %s"
;
std
::
string
arg_str
=
std
::
string
arg_str
=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
FWD_NUM_ARG_TEMPLATE
,
output_var_name
);
FWD_NUM_ARG_TEMPLATE
,
LegalizeVarName
(
output_var_name
)
);
dygraph_function_args_str
+=
arg_str
;
dygraph_function_args_str
+=
arg_str
;
amp_function_call_args_str
+=
(
", "
+
output_var_name
);
amp_function_call_args_str
+=
(
", "
+
LegalizeVarName
(
output_var_name
)
);
core_ops_args_type_info
[
op_type
].
push_back
(
"list"
);
core_ops_args_type_info
[
op_type
].
push_back
(
"list"
);
}
else
{
}
else
{
const
char
*
FWD_NUM_ARG_TEMPLATE
=
", paddle::experimental::Tensor* %s"
;
const
char
*
FWD_NUM_ARG_TEMPLATE
=
", paddle::experimental::Tensor* %s"
;
std
::
string
arg_str
=
std
::
string
arg_str
=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
FWD_NUM_ARG_TEMPLATE
,
output_var_name
);
FWD_NUM_ARG_TEMPLATE
,
LegalizeVarName
(
output_var_name
)
);
dygraph_function_args_str
+=
arg_str
;
dygraph_function_args_str
+=
arg_str
;
amp_function_call_args_str
+=
(
", "
+
output_var_name
);
amp_function_call_args_str
+=
(
", "
+
LegalizeVarName
(
output_var_name
)
);
core_ops_args_type_info
[
op_type
].
push_back
(
"tensor"
);
core_ops_args_type_info
[
op_type
].
push_back
(
"tensor"
);
}
}
...
@@ -1577,8 +1607,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1577,8 +1607,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
else
{
}
else
{
const
char
*
FWD_OUTS_CONTENT_TEMPLATE
=
const
char
*
FWD_OUTS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, egr::EagerUtils::TrySyncToVars(%s) },"
;
"{
\"
%s
\"
, egr::EagerUtils::TrySyncToVars(%s) },"
;
outs_contents_str
+=
paddle
::
string
::
Sprintf
(
outs_contents_str
+=
FWD_OUTS_CONTENT_TEMPLATE
,
output_name
,
output_var_name
);
paddle
::
string
::
Sprintf
(
FWD_OUTS_CONTENT_TEMPLATE
,
output_name
,
LegalizeVarName
(
output_var_name
));
}
}
core_ops_args_info
[
op_type
].
push_back
(
output_name
);
core_ops_args_info
[
op_type
].
push_back
(
output_name
);
...
@@ -1773,7 +1804,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1773,7 +1804,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std
::
vector
<
std
::
string
>
return_types
(
output_size
);
std
::
vector
<
std
::
string
>
return_types
(
output_size
);
for
(
const
proto
::
OpProto
::
Var
&
output
:
out_vars
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
out_vars
)
{
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
output_var_args_name
=
output_name
+
"Var"
;
const
std
::
string
output_var_args_name
=
LegalizeVariableName
(
output_name
+
"Var"
);
std
::
string
out_tensor_str
;
std
::
string
out_tensor_str
;
size_t
return_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
size_t
return_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
std
::
string
output_varname
=
LegalizeVariableName
(
output_name
);
std
::
string
output_varname
=
LegalizeVariableName
(
output_name
);
...
@@ -1837,9 +1869,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1837,9 +1869,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" %s.bump_inplace_version();
\n
"
" %s.bump_inplace_version();
\n
"
" VLOG(3) <<
\"
Tensor(
\"
<< %s.name() <<
\"
) uses Inplace "
" VLOG(3) <<
\"
Tensor(
\"
<< %s.name() <<
\"
) uses Inplace "
"Strategy.
\"
;
\n
"
;
"Strategy.
\"
;
\n
"
;
out_tensor_str
=
paddle
::
string
::
Sprintf
(
out_tensor_str
=
FWD_OUT_TENSOR_TEMPLATE
,
output_name
,
inplace_input_name
,
paddle
::
string
::
Sprintf
(
FWD_OUT_TENSOR_TEMPLATE
,
output_name
,
inplace_input_name
,
inplace_input_name
);
LegalizeVarName
(
inplace_input_name
),
LegalizeVarName
(
inplace_input_name
),
LegalizeVarName
(
inplace_input_name
));
}
else
{
}
else
{
const
char
*
FWD_OUT_TENSOR_TEMPLATE
=
const
char
*
FWD_OUT_TENSOR_TEMPLATE
=
" paddle::experimental::Tensor %s;
\n
"
" paddle::experimental::Tensor %s;
\n
"
...
@@ -1854,7 +1888,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1854,7 +1888,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
output_name
))
{
if
(
!
inplace_map
.
empty
()
&&
inplace_map
.
count
(
output_name
))
{
// Replace output directly with input in inplace op.
// Replace output directly with input in inplace op.
return_contents
[
return_position
]
=
inplace_map
[
output_name
];
return_contents
[
return_position
]
=
LegalizeVarName
(
inplace_map
[
output_name
]);
}
else
{
}
else
{
return_contents
[
return_position
]
=
output_varname
;
return_contents
[
return_position
]
=
output_varname
;
}
}
...
...
paddle/fluid/pybind/eager_op_function_generator.cc
浏览文件 @
832e58d6
...
@@ -36,6 +36,11 @@
...
@@ -36,6 +36,11 @@
// phi
// phi
#include "paddle/phi/kernels/declarations.h"
#include "paddle/phi/kernels/declarations.h"
static
std
::
string
LegalizeVarName
(
const
std
::
string
&
var_name
)
{
std
::
string
ret
=
var_name
;
std
::
replace
(
ret
.
begin
(),
ret
.
end
(),
'@'
,
'_'
);
// replace all '-' to '_'
return
ret
;
}
// clang-format off
// clang-format off
const
char
*
OUT_INITIALIZER_TEMPLATE
=
const
char
*
OUT_INITIALIZER_TEMPLATE
=
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"
;
R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})"
;
...
@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody(
...
@@ -185,18 +190,19 @@ std::string GenerateOpFunctionsBody(
continue
;
continue
;
}
}
const
auto
in_type
=
input
.
duplicable
()
?
IN_VAR_LIST_TYPE
:
IN_VAR_TYPE
;
const
auto
in_type
=
input
.
duplicable
()
?
IN_VAR_LIST_TYPE
:
IN_VAR_TYPE
;
auto
input_arg
=
auto
input_arg
=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
ARG_TEMPLATE
,
in_type
,
TempName
(
in_name
));
ARG_TEMPLATE
,
in_type
,
TempName
(
LegalizeVarName
(
in_name
)
));
input_args
+=
input_arg
;
input_args
+=
input_arg
;
input_args
+=
","
;
input_args
+=
","
;
input_args_num
++
;
input_args_num
++
;
const
auto
in_cast_type
=
const
auto
in_cast_type
=
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
auto
dispensable
=
input
.
dispensable
()
?
"true"
:
"false"
;
auto
dispensable
=
input
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
op_type
,
ins_cast_str
+=
in_name
,
arg_idx
++
,
dispensable
);
paddle
::
string
::
Sprintf
(
in_cast_type
,
LegalizeVarName
(
in_name
),
op_type
,
in_name
,
arg_idx
++
,
dispensable
);
call_api_str
+=
in_name
+
", "
;
call_api_str
+=
LegalizeVarName
(
in_name
)
+
", "
;
}
}
if
(
!
input_args
.
empty
()
&&
input_args
.
back
()
==
','
)
{
if
(
!
input_args
.
empty
()
&&
input_args
.
back
()
==
','
)
{
...
@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody(
...
@@ -224,7 +230,7 @@ std::string GenerateOpFunctionsBody(
input_args
+=
","
;
input_args
+=
","
;
}
}
input_args
+=
out_type
;
input_args
+=
out_type
;
input_args
+=
out_name
;
input_args
+=
LegalizeVarName
(
out_name
)
;
input_args_num
++
;
input_args_num
++
;
if
(
output
.
dispensable
())
{
if
(
output
.
dispensable
())
{
...
@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody(
...
@@ -237,18 +243,19 @@ std::string GenerateOpFunctionsBody(
const
auto
out_template
=
output
.
duplicable
()
const
auto
out_template
=
output
.
duplicable
()
?
INPUT_LIST_INITIALIZER_TEMPLATE
?
INPUT_LIST_INITIALIZER_TEMPLATE
:
INPUT_INITIALIZER_TEMPLATE
;
:
INPUT_INITIALIZER_TEMPLATE
;
outs_initializer
+=
outs_initializer
+=
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
out_name
);
LegalizeVarName
(
out_name
)
);
outs_initializer
+=
","
;
outs_initializer
+=
","
;
}
}
const
auto
in_cast_type
=
output
.
duplicable
()
?
CAST_VAR_PTR_LIST_TEMPLATE
const
auto
in_cast_type
=
output
.
duplicable
()
?
CAST_VAR_PTR_LIST_TEMPLATE
:
CAST_VAR_PTR_TEMPLATE
;
:
CAST_VAR_PTR_TEMPLATE
;
auto
dispensable
=
output
.
dispensable
()
?
"true"
:
"false"
;
auto
dispensable
=
output
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
out_name
,
op_type
,
ins_cast_str
+=
out_name
,
arg_idx
++
,
dispensable
);
paddle
::
string
::
Sprintf
(
in_cast_type
,
LegalizeVarName
(
out_name
),
op_type
,
out_name
,
arg_idx
++
,
dispensable
);
call_api_str
+=
out_name
+
", "
;
call_api_str
+=
LegalizeVarName
(
out_name
)
+
", "
;
}
else
{
}
else
{
// There are few Operators that have duplicable output, like `Out` in
// There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the
// split op. We need to specify the number of variables for the
...
@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody(
...
@@ -257,7 +264,8 @@ std::string GenerateOpFunctionsBody(
if
(
input_args
!=
""
)
{
if
(
input_args
!=
""
)
{
input_args
+=
","
;
input_args
+=
","
;
}
}
auto
out_num_str
=
paddle
::
string
::
Sprintf
(
ARG_OUT_NUM
,
out_name
);
auto
out_num_str
=
paddle
::
string
::
Sprintf
(
ARG_OUT_NUM
,
LegalizeVarName
(
out_name
));
input_args
+=
ARG_OUT_NUM_TYPE
;
input_args
+=
ARG_OUT_NUM_TYPE
;
input_args
+=
out_num_str
;
input_args
+=
out_num_str
;
input_args_num
++
;
input_args_num
++
;
...
...
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
832e58d6
...
@@ -35,6 +35,12 @@
...
@@ -35,6 +35,12 @@
// phi
// phi
#include "paddle/phi/kernels/declarations.h"
#include "paddle/phi/kernels/declarations.h"
static
std
::
string
LegalizeVarName
(
const
std
::
string
&
var_name
)
{
std
::
string
ret
=
var_name
;
std
::
replace
(
ret
.
begin
(),
ret
.
end
(),
'@'
,
'_'
);
// replace all '-' to '_'
return
ret
;
}
// NOTE(pangyoki): Inplace OP with duplicable input.
// NOTE(pangyoki): Inplace OP with duplicable input.
// The set includes inplace ops that have duplicable input.
// The set includes inplace ops that have duplicable input.
// The first Varbase in input needs to be specified for the inplace strategy
// The first Varbase in input needs to be specified for the inplace strategy
...
@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody(
...
@@ -201,28 +207,31 @@ std::string GenerateOpFunctionsBody(
continue
;
continue
;
}
}
const
auto
in_type
=
input
.
duplicable
()
?
IN_VAR_LIST_TYPE
:
IN_VAR_TYPE
;
const
auto
in_type
=
input
.
duplicable
()
?
IN_VAR_LIST_TYPE
:
IN_VAR_TYPE
;
auto
input_arg
=
auto
input_arg
=
paddle
::
string
::
Sprintf
(
paddle
::
string
::
Sprintf
(
ARG_TEMPLATE
,
in_type
,
TempName
(
in_name
));
ARG_TEMPLATE
,
in_type
,
LegalizeVarName
(
TempName
(
in_name
)
));
input_args
+=
input_arg
;
input_args
+=
input_arg
;
input_args
+=
","
;
input_args
+=
","
;
input_args_num
++
;
input_args_num
++
;
const
auto
in_cast_type
=
const
auto
in_cast_type
=
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
input
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
auto
dispensable
=
input
.
dispensable
()
?
"true"
:
"false"
;
auto
dispensable
=
input
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
in_name
,
in_name
,
ins_cast_str
+=
arg_idx
++
,
dispensable
);
paddle
::
string
::
Sprintf
(
in_cast_type
,
LegalizeVarName
(
in_name
),
in_name
,
arg_idx
++
,
dispensable
);
if
(
input
.
dispensable
())
{
if
(
input
.
dispensable
())
{
const
auto
in_template
=
input
.
duplicable
()
const
auto
in_template
=
input
.
duplicable
()
?
INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
?
INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
:
INPUT_INITIALIZER_TEMPLATE_WITH_NULL
;
:
INPUT_INITIALIZER_TEMPLATE_WITH_NULL
;
ins_initializer_with_null
+=
ins_initializer_with_null
+=
paddle
::
string
::
Sprintf
(
in_template
,
in_name
,
in_name
,
in_name
);
paddle
::
string
::
Sprintf
(
in_template
,
LegalizeVarName
(
in_name
),
in_name
,
LegalizeVarName
(
in_name
));
}
else
{
}
else
{
const
auto
in_template
=
input
.
duplicable
()
const
auto
in_template
=
input
.
duplicable
()
?
INPUT_LIST_INITIALIZER_TEMPLATE
?
INPUT_LIST_INITIALIZER_TEMPLATE
:
INPUT_INITIALIZER_TEMPLATE
;
:
INPUT_INITIALIZER_TEMPLATE
;
ins_initializer
+=
paddle
::
string
::
Sprintf
(
in_template
,
in_name
,
in_name
);
ins_initializer
+=
paddle
::
string
::
Sprintf
(
in_template
,
in_name
,
LegalizeVarName
(
in_name
));
ins_initializer
+=
","
;
ins_initializer
+=
","
;
}
}
}
}
...
@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody(
...
@@ -259,7 +268,7 @@ std::string GenerateOpFunctionsBody(
input_args
+=
","
;
input_args
+=
","
;
}
}
input_args
+=
out_type
;
input_args
+=
out_type
;
input_args
+=
out_name
;
input_args
+=
LegalizeVarName
(
out_name
)
;
input_args_num
++
;
input_args_num
++
;
if
(
output
.
dispensable
())
{
if
(
output
.
dispensable
())
{
...
@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody(
...
@@ -272,16 +281,17 @@ std::string GenerateOpFunctionsBody(
const
auto
out_template
=
output
.
duplicable
()
const
auto
out_template
=
output
.
duplicable
()
?
INPUT_LIST_INITIALIZER_TEMPLATE
?
INPUT_LIST_INITIALIZER_TEMPLATE
:
INPUT_INITIALIZER_TEMPLATE
;
:
INPUT_INITIALIZER_TEMPLATE
;
outs_initializer
+=
outs_initializer
+=
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
out_name
);
LegalizeVarName
(
out_name
)
);
outs_initializer
+=
","
;
outs_initializer
+=
","
;
}
}
const
auto
in_cast_type
=
const
auto
in_cast_type
=
output
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
output
.
duplicable
()
?
CAST_VAR_LIST_TEMPLATE
:
CAST_VAR_TEMPLATE
;
auto
dispensable
=
output
.
dispensable
()
?
"true"
:
"false"
;
auto
dispensable
=
output
.
dispensable
()
?
"true"
:
"false"
;
ins_cast_str
+=
paddle
::
string
::
Sprintf
(
in_cast_type
,
out_name
,
out_name
,
ins_cast_str
+=
arg_idx
++
,
dispensable
);
paddle
::
string
::
Sprintf
(
in_cast_type
,
LegalizeVarName
(
out_name
),
out_name
,
arg_idx
++
,
dispensable
);
}
else
if
(
use_inplace_strategy
&&
inplace_map
.
count
(
out_name
))
{
}
else
if
(
use_inplace_strategy
&&
inplace_map
.
count
(
out_name
))
{
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
inplace_map
[
out_name
],
""
,
inplace_map
[
out_name
],
""
,
...
@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody(
...
@@ -307,11 +317,13 @@ std::string GenerateOpFunctionsBody(
// Leaf Var that doesn't stop gradient can't use inplace strategy.
// Leaf Var that doesn't stop gradient can't use inplace strategy.
// Increase inplace_version.
// Increase inplace_version.
inplace_strategy_str
+=
paddle
::
string
::
Sprintf
(
inplace_strategy_str
+=
paddle
::
string
::
Sprintf
(
INPLACE_STRATEGY_TEMPLATE
,
inplace_input_name
,
inplace_input_name
,
INPLACE_STRATEGY_TEMPLATE
,
LegalizeVarName
(
inplace_input_name
),
INPLACE_LEAF_ERROR_MESSAGE
,
inplace_input_name
,
inplace_input_name
,
LegalizeVarName
(
inplace_input_name
),
INPLACE_LEAF_ERROR_MESSAGE
,
inplace_input_name
);
LegalizeVarName
(
inplace_input_name
),
outs_initializer
+=
LegalizeVarName
(
inplace_input_name
),
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
inplace_input_name
);
LegalizeVarName
(
inplace_input_name
));
outs_initializer
+=
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
LegalizeVarName
(
inplace_input_name
));
outs_initializer
+=
","
;
outs_initializer
+=
","
;
}
else
{
}
else
{
// There are few Operators that have duplicable output, like `Out` in
// There are few Operators that have duplicable output, like `Out` in
...
@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody(
...
@@ -321,7 +333,8 @@ std::string GenerateOpFunctionsBody(
if
(
input_args
!=
""
)
{
if
(
input_args
!=
""
)
{
input_args
+=
","
;
input_args
+=
","
;
}
}
auto
out_num_str
=
paddle
::
string
::
Sprintf
(
ARG_OUT_NUM
,
out_name
);
auto
out_num_str
=
paddle
::
string
::
Sprintf
(
ARG_OUT_NUM
,
LegalizeVarName
(
out_name
));
input_args
+=
ARG_OUT_NUM_TYPE
;
input_args
+=
ARG_OUT_NUM_TYPE
;
input_args
+=
out_num_str
;
input_args
+=
out_num_str
;
input_args_num
++
;
input_args_num
++
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录