Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d980d251
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
d980d251
编写于
5月 18, 2020
作者:
L
Leo Chen
提交者:
GitHub
5月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
specify outs, test=develop (#24537)
上级
16817c70
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
97 addition
and
31 deletion
+97
-31
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+97
-31
未找到文件。
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
d980d251
...
@@ -24,13 +24,48 @@
...
@@ -24,13 +24,48 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/string_helper.h"
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// determined by the OP`s proto automatically, i.e., all the inputs registered
// in OpMaker.
// However, some OPs have dispensable inputs, which means the input can
// be none for some conditions. It is discovered that most dispensable inputs
// is not used in imperative mode, so we drop those inputs when generating OP
// functions. While, for very few OPs, the dispensable inputs are used, we
// need to manually specify them in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_ins_map
=
{
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_ins_map
=
{
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
{
"assign"
,
{
"X"
}},
{
"assign"
,
{
"X"
}},
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"X"
,
"InScale"
,
"InAccum"
,
"InState"
}},
};
};
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_passing_out_map
=
{
// NOTE(zhiqiu): Like op_ins_map.
// Commonly, the outputs in auto-generated OP function are determined by the
// OP`s proto automatically, i.e., all the outputs registered in OpMaker.
// However, some OPs have dispensable outputs, which means the output can
// be none for some conditions. It is discovered that most dispensable outputs
// is not used in imperative mode, so we drop those outputs when generating OP
// functions. While, for very few OPs, the dispensable outputs are used, we
// need to manually specify them in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_outs_map
=
{
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"Out"
,
"OutScale"
,
"OutAccum"
,
"OutState"
}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
// generated in C++ automatically.
// However, some OPs need to pass the outputs from Python instead of generating
// them in C++. There are mainly 2 reasons for that,
// (1) Optimizer OPs need to update the input param in-place, like sgd.
// So they need to pass the output which is same as input param.
// (2) Very few python APIs has out in their arguments, like fill_constant.
// So they need to pass the python output to C++.
// Actually, this is not a good design, since it may break the SSA graph,
// especially in declarative mode.
// For those OPs, we need to manually specify the outs need to pass in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_passing_outs_map
=
{
{
"sgd"
,
{
"ParamOut"
}},
{
"sgd"
,
{
"ParamOut"
}},
{
"adam"
,
{
"adam"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
}},
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
}},
...
@@ -38,7 +73,10 @@ std::map<std::string, std::set<std::string>> op_passing_out_map = {
...
@@ -38,7 +73,10 @@ std::map<std::string, std::set<std::string>> op_passing_out_map = {
{
"batch_norm"
,
{
"MeanOut"
,
"VarianceOut"
}},
{
"batch_norm"
,
{
"MeanOut"
,
"VarianceOut"
}},
{
"accuracy"
,
{
"Correct"
,
"Total"
}},
{
"accuracy"
,
{
"Correct"
,
"Total"
}},
{
"fill_constant"
,
{
"Out"
}},
{
"fill_constant"
,
{
"Out"
}},
{
"matmul"
,
{
"Out"
}}};
{
"matmul"
,
{
"Out"
}},
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"OutScale"
,
"OutAccum"
,
"OutState"
}},
};
// clang-format off
// clang-format off
const
char
*
OUT_INITIALIZER_TEMPLATE
=
const
char
*
OUT_INITIALIZER_TEMPLATE
=
...
@@ -47,17 +85,30 @@ const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableO
...
@@ -47,17 +85,30 @@ const char* OUT_DUPLICABLE_INITIALIZER_TEMPLATE = R"({"%s", ConstructDuplicableO
const
char
*
INPUT_INITIALIZER_TEMPLATE
=
R"({"%s", {%s}})"
;
const
char
*
INPUT_INITIALIZER_TEMPLATE
=
R"({"%s", {%s}})"
;
const
char
*
INPUT_LIST_INITIALIZER_TEMPLATE
=
R"({"%s", %s})"
;
const
char
*
INPUT_LIST_INITIALIZER_TEMPLATE
=
R"({"%s", %s})"
;
const
char
*
INPUT_INITIALIZER_TEMPLATE_WITH_NULL
=
R"(
if (%s != nullptr) {
const
char
*
INPUT_INITIALIZER_TEMPLATE_WITH_NULL
=
R"(
ins["%s"] = {%s};
if (%s != nullptr) {
}
ins["%s"] = {%s};
}
)"
;
)"
;
const
char
*
INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
=
R"(
if (%s != nullptr) {
const
char
*
INPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
=
R"(
ins["%s"] = %s;
if (%s.size() != 0) {
}
ins["%s"] = %s;
}
)"
;
const
char
*
OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL
=
R"(
if (%s != nullptr) {
outs["%s"] = {%s};
}
)"
;
)"
;
const
char
*
OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
=
R"(
if (%s.size() != 0) {
outs["%s"] = %s;
}
)"
;
// if inputs is list, no need {}
// if inputs is list, no need {}
const
char
*
ARG_OUT_NUM
=
R"(%sNum)"
;
const
char
*
ARG_OUT_NUM
=
R"(%sNum)"
;
const
char
*
ARG_OUT_NUM_TYPE
=
R"(size_t )"
;
const
char
*
ARG_OUT_NUM_TYPE
=
R"(size_t )"
;
...
@@ -95,14 +146,19 @@ R"(
...
@@ -95,14 +146,19 @@ R"(
const
char
*
PYBIND_ITEM_TEMPLATE
=
R"( %s.def("%s", &%s);)"
;
const
char
*
PYBIND_ITEM_TEMPLATE
=
R"( %s.def("%s", &%s);)"
;
// clang-format on
// clang-format on
static
inline
bool
FindIn
putInSpecialization
(
const
std
::
string
&
op_type
,
static
inline
bool
FindIn
sMap
(
const
std
::
string
&
op_type
,
const
std
::
string
&
in_name
)
{
const
std
::
string
&
in_name
)
{
return
op_ins_map
[
op_type
].
count
(
in_name
);
return
op_ins_map
[
op_type
].
count
(
in_name
);
}
}
static
inline
bool
FindOutoutInSpecialization
(
const
std
::
string
&
op_type
,
static
inline
bool
FindOutsMap
(
const
std
::
string
&
op_type
,
const
std
::
string
&
out_name
)
{
const
std
::
string
&
out_name
)
{
return
op_passing_out_map
[
op_type
].
count
(
out_name
);
return
op_outs_map
[
op_type
].
count
(
out_name
);
}
static
inline
bool
FindPassingOutsMap
(
const
std
::
string
&
op_type
,
const
std
::
string
&
out_name
)
{
return
op_passing_outs_map
[
op_type
].
count
(
out_name
);
}
}
static
std
::
tuple
<
std
::
vector
<
std
::
string
>
,
std
::
vector
<
std
::
string
>>
static
std
::
tuple
<
std
::
vector
<
std
::
string
>
,
std
::
vector
<
std
::
string
>>
...
@@ -131,7 +187,7 @@ GenerateOpFunctions(const std::string& module_name) {
...
@@ -131,7 +187,7 @@ GenerateOpFunctions(const std::string& module_name) {
for
(
auto
&
input
:
op_proto
->
inputs
())
{
for
(
auto
&
input
:
op_proto
->
inputs
())
{
auto
&
in_name
=
input
.
name
();
auto
&
in_name
=
input
.
name
();
// skip those dispensable inputs, like ResidualData in conv2d
// skip those dispensable inputs, like ResidualData in conv2d
if
(
input
.
dispensable
()
&&
!
FindIn
putInSpecialization
(
op_type
,
in_name
))
{
if
(
input
.
dispensable
()
&&
!
FindIn
sMap
(
op_type
,
in_name
))
{
continue
;
continue
;
}
}
const
auto
in_type
=
input
.
duplicable
()
?
VAR_LIST_TYPE
:
VAR_TYPE
;
const
auto
in_type
=
input
.
duplicable
()
?
VAR_LIST_TYPE
:
VAR_TYPE
;
...
@@ -165,30 +221,41 @@ GenerateOpFunctions(const std::string& module_name) {
...
@@ -165,30 +221,41 @@ GenerateOpFunctions(const std::string& module_name) {
// Generate outs initializer
// Generate outs initializer
std
::
string
outs_initializer
=
"{"
;
std
::
string
outs_initializer
=
"{"
;
std
::
string
outs_initializer_with_null
=
""
;
std
::
string
return_type
=
""
;
std
::
string
return_type
=
""
;
std
::
string
return_str
=
""
;
std
::
string
return_str
=
""
;
int
outs_num
=
0
;
int
outs_num
=
0
;
for
(
auto
&
output
:
op_proto
->
outputs
())
{
for
(
auto
&
output
:
op_proto
->
outputs
())
{
if
(
output
.
dispensable
())
{
auto
&
out_name
=
output
.
name
();
// skip those dispensable oututs
if
(
output
.
dispensable
()
&&
!
FindOutsMap
(
op_type
,
out_name
))
{
continue
;
continue
;
}
}
const
auto
out_type
=
output
.
duplicable
()
?
VAR_LIST_TYPE
:
VAR_TYPE
;
const
auto
out_type
=
output
.
duplicable
()
?
VAR_LIST_TYPE
:
VAR_TYPE
;
const
auto
return_template
=
const
auto
return_template
=
output
.
duplicable
()
?
RETURN_LIST_TEMPLATE
:
RETURN_TEMPLATE
;
output
.
duplicable
()
?
RETURN_LIST_TEMPLATE
:
RETURN_TEMPLATE
;
auto
&
out_name
=
output
.
name
();
if
(
FindPassingOutsMap
(
op_type
,
out_name
))
{
std
::
string
out_initializer_str
;
if
(
FindOutoutInSpecialization
(
op_type
,
out_name
))
{
if
(
input_args
!=
""
)
{
if
(
input_args
!=
""
)
{
input_args
+=
","
;
input_args
+=
","
;
}
}
input_args
+=
out_type
;
input_args
+=
out_type
;
input_args
+=
out_name
;
input_args
+=
out_name
;
const
auto
out_template
=
output
.
duplicable
()
?
INPUT_LIST_INITIALIZER_TEMPLATE
if
(
output
.
dispensable
())
{
:
INPUT_INITIALIZER_TEMPLATE
;
const
auto
out_template
=
out_initializer_str
+=
output
.
duplicable
()
?
OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL_LIST
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
out_name
);
:
OUTPUT_INITIALIZER_TEMPLATE_WITH_NULL
;
outs_initializer_with_null
+=
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
out_name
,
out_name
);
}
else
{
const
auto
out_template
=
output
.
duplicable
()
?
INPUT_LIST_INITIALIZER_TEMPLATE
:
INPUT_INITIALIZER_TEMPLATE
;
outs_initializer
+=
paddle
::
string
::
Sprintf
(
out_template
,
out_name
,
out_name
);
outs_initializer
+=
","
;
}
}
else
{
}
else
{
// There are few Operators that have duplicable output, like `Out` in
// There are few Operators that have duplicable output, like `Out` in
// split op. We need to specify the number of variables for the
// split op. We need to specify the number of variables for the
...
@@ -200,12 +267,13 @@ GenerateOpFunctions(const std::string& module_name) {
...
@@ -200,12 +267,13 @@ GenerateOpFunctions(const std::string& module_name) {
auto
out_num_str
=
paddle
::
string
::
Sprintf
(
ARG_OUT_NUM
,
out_name
);
auto
out_num_str
=
paddle
::
string
::
Sprintf
(
ARG_OUT_NUM
,
out_name
);
input_args
+=
ARG_OUT_NUM_TYPE
;
input_args
+=
ARG_OUT_NUM_TYPE
;
input_args
+=
out_num_str
;
input_args
+=
out_num_str
;
out
_initializer_str
=
paddle
::
string
::
Sprintf
(
out
s_initializer
+
=
paddle
::
string
::
Sprintf
(
OUT_DUPLICABLE_INITIALIZER_TEMPLATE
,
out_name
,
out_num_str
);
OUT_DUPLICABLE_INITIALIZER_TEMPLATE
,
out_name
,
out_num_str
);
}
else
{
}
else
{
out
_initializer_str
=
out
s_initializer
+
=
paddle
::
string
::
Sprintf
(
OUT_INITIALIZER_TEMPLATE
,
out_name
);
paddle
::
string
::
Sprintf
(
OUT_INITIALIZER_TEMPLATE
,
out_name
);
}
}
outs_initializer
+=
","
;
}
}
return_type
+=
out_type
;
return_type
+=
out_type
;
...
@@ -213,9 +281,6 @@ GenerateOpFunctions(const std::string& module_name) {
...
@@ -213,9 +281,6 @@ GenerateOpFunctions(const std::string& module_name) {
return_str
+=
paddle
::
string
::
Sprintf
(
return_template
,
out_name
);
return_str
+=
paddle
::
string
::
Sprintf
(
return_template
,
out_name
);
return_str
+=
","
;
return_str
+=
","
;
outs_num
+=
1
;
outs_num
+=
1
;
outs_initializer
+=
out_initializer_str
;
outs_initializer
+=
","
;
}
}
if
(
outs_initializer
.
back
()
==
','
)
{
if
(
outs_initializer
.
back
()
==
','
)
{
outs_initializer
.
pop_back
();
outs_initializer
.
pop_back
();
...
@@ -241,7 +306,8 @@ GenerateOpFunctions(const std::string& module_name) {
...
@@ -241,7 +306,8 @@ GenerateOpFunctions(const std::string& module_name) {
// generate op funtcion body
// generate op funtcion body
auto
op_function_str
=
paddle
::
string
::
Sprintf
(
auto
op_function_str
=
paddle
::
string
::
Sprintf
(
OP_FUNCTION_TEMPLATE
,
return_type
,
func_name
,
function_args
,
OP_FUNCTION_TEMPLATE
,
return_type
,
func_name
,
function_args
,
outs_initializer
,
ins_initializer
,
ins_initializer_with_null
,
op_type
,
outs_initializer
,
ins_initializer
,
ins_initializer_with_null
+
outs_initializer_with_null
,
op_type
,
return_str
);
return_str
);
// generate pybind item
// generate pybind item
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录