Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
06c3cce9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
06c3cce9
编写于
12月 01, 2021
作者:
Z
Zhanlue Yang
提交者:
GitHub
12月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Handled dispensable tensors in AutoCodeGen for Eager Dygraph (#37723)
上级
f91e2331
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
314 addition
and
175 deletion
+314
-175
paddle/fluid/eager/auto_code_generator/eager_generator.cc
paddle/fluid/eager/auto_code_generator/eager_generator.cc
+191
-73
paddle/fluid/pybind/op_function_generator.cc
paddle/fluid/pybind/op_function_generator.cc
+2
-102
paddle/fluid/pybind/op_function_generator.h
paddle/fluid/pybind/op_function_generator.h
+121
-0
未找到文件。
paddle/fluid/eager/auto_code_generator/eager_generator.cc
浏览文件 @
06c3cce9
...
@@ -22,6 +22,7 @@
...
@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/op_function_generator.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/string_helper.h"
...
@@ -358,18 +359,149 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
...
@@ -358,18 +359,149 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
return
true
;
return
true
;
}
}
/* --------------------------------------- */
/* --------- Preprocess Ins/Outs --------- */
/* --------------------------------------- */
static
void
PurifyOpProto
(
const
proto
::
OpProto
&
op_proto
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
fwd_inputs_name_pos_map
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
fwd_outputs_name_pos_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
in_vars
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
out_vars
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_ins
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_outs
)
{
// Op Name
const
std
::
string
op_name
=
op_proto
.
type
();
// Handle dispensable inputs
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
())
{
std
::
string
input_name
=
input
.
name
();
// Delete dispensable tensor unless specified in op_ins_map
if
(
input
.
dispensable
())
{
if
(
!
op_ins_map
.
count
(
op_name
)
||
!
op_ins_map
[
op_name
].
count
(
input_name
))
{
VLOG
(
6
)
<<
"Removing Dispensable Input: "
<<
input_name
;
// in_vars
auto
iter
=
in_vars
->
begin
();
for
(
iter
=
in_vars
->
begin
();
iter
!=
in_vars
->
end
();
iter
++
)
{
if
(
iter
->
name
()
==
input_name
)
{
break
;
}
}
in_vars
->
erase
(
iter
);
// grad_outs_slotname_map
auto
grad_outs_slotname_map_purified
=
*
grad_outs_slotname_map
;
for
(
const
auto
&
iter
:
*
grad_outs_slotname_map
)
{
const
std
::
string
&
grad_output_name
=
iter
.
first
;
const
std
::
string
&
matched_input_name
=
iter
.
second
;
if
(
matched_input_name
==
input_name
)
{
grad_outs_slotname_map_purified
.
erase
(
grad_output_name
);
PADDLE_ENFORCE
(
grad_outs
->
count
(
grad_output_name
)
>
0
,
paddle
::
platform
::
errors
::
Fatal
(
"Unable to find gradient output name in grad_outs."
));
// grad_outs
grad_outs
->
erase
(
grad_output_name
);
}
}
*
grad_outs_slotname_map
=
grad_outs_slotname_map_purified
;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if
(
grad_ins_fwd_slotname_map
->
count
(
input_name
))
grad_ins_fwd_slotname_map
->
erase
(
input_name
);
// grad_ins: output as tensorwrapper
if
(
grad_ins
->
count
(
input_name
))
grad_ins
->
erase
(
input_name
);
}
}
}
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
std
::
string
output_name
=
output
.
name
();
// Delete dispensable tensor unless specified in op_outs_map
if
(
output
.
dispensable
())
{
if
(
!
op_outs_map
.
count
(
op_name
)
||
!
op_outs_map
[
op_name
].
count
(
output_name
))
{
VLOG
(
6
)
<<
"Removing Dispensable Output: "
<<
output_name
;
// out_vars
auto
iter
=
out_vars
->
begin
();
for
(
iter
=
out_vars
->
begin
();
iter
!=
out_vars
->
end
();
iter
++
)
{
if
(
iter
->
name
()
==
output_name
)
{
break
;
}
}
out_vars
->
erase
(
iter
);
// grad_ins_grad_slotname_map
auto
grad_ins_grad_slotname_map_purified
=
*
grad_ins_grad_slotname_map
;
for
(
const
auto
&
iter
:
*
grad_ins_grad_slotname_map
)
{
const
std
::
string
&
grad_input_name
=
iter
.
first
;
const
std
::
string
&
matched_output_name
=
iter
.
second
;
if
(
matched_output_name
==
output_name
)
{
grad_ins_grad_slotname_map_purified
.
erase
(
grad_input_name
);
PADDLE_ENFORCE
(
grad_ins
->
count
(
grad_input_name
)
>
0
,
paddle
::
platform
::
errors
::
Fatal
(
"Unable to find gradient input name in grad_ins."
));
// grad_ins
grad_ins
->
erase
(
grad_input_name
);
}
}
*
grad_ins_grad_slotname_map
=
grad_ins_grad_slotname_map_purified
;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if
(
grad_ins_fwd_slotname_map
->
count
(
output_name
))
grad_ins_fwd_slotname_map
->
erase
(
output_name
);
// grad_ins: output as tensorwrapper
if
(
grad_ins
->
count
(
output_name
))
grad_ins
->
erase
(
output_name
);
}
}
}
/* ------ Maping forward slot name to fwd position ------ */
size_t
in_pos
=
0
;
for
(
const
auto
&
var
:
*
in_vars
)
{
VLOG
(
6
)
<<
"Mapping input tensor: "
<<
var
.
name
()
<<
" To position: "
<<
in_pos
;
(
*
fwd_inputs_name_pos_map
)[
var
.
name
()]
=
in_pos
;
in_pos
++
;
}
size_t
out_pos
=
0
;
for
(
const
auto
&
var
:
*
out_vars
)
{
VLOG
(
6
)
<<
"Mapping output tensor: "
<<
var
.
name
()
<<
" To position: "
<<
out_pos
;
(
*
fwd_outputs_name_pos_map
)[
var
.
name
()]
=
out_pos
;
out_pos
++
;
}
}
/* -------------------------------- */
/* -------------------------------- */
/* --------- Collect Info --------- */
/* --------- Collect Info --------- */
/* -------------------------------- */
/* -------------------------------- */
static
bool
CollectInformationFromOpInfo
(
static
bool
CollectInformationFromOpInfo
(
const
paddle
::
framework
::
OpInfo
&
op_info
,
const
paddle
::
framework
::
OpInfo
&
op_info
,
std
::
vector
<
paddle
::
framework
::
AttributeMap
>*
grad_node_default_attr_maps
,
std
::
vector
<
std
::
string
>*
grad_op_types
,
std
::
vector
<
std
::
string
>*
grad_op_types
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
fwd_inputs_name_pos_map
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
fwd_outputs_name_pos_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
in_vars
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
out_vars
,
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_ins
,
grad_ins
,
...
@@ -380,6 +512,13 @@ static bool CollectInformationFromOpInfo(
...
@@ -380,6 +512,13 @@ static bool CollectInformationFromOpInfo(
const
std
::
string
&
op_type
=
op_proto
.
type
();
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
vector
<
int64_t
>
dims
=
{
1
,
1
,
1
,
1
};
std
::
vector
<
int64_t
>
dims
=
{
1
,
1
,
1
,
1
};
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
())
{
in_vars
->
push_back
(
input
);
}
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
out_vars
->
push_back
(
output
);
}
/* ------ Prepare "ins" ------ */
/* ------ Prepare "ins" ------ */
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>>>
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>>>
...
@@ -494,7 +633,6 @@ static bool CollectInformationFromOpInfo(
...
@@ -494,7 +633,6 @@ static bool CollectInformationFromOpInfo(
for
(
auto
iter
=
grad_node
->
begin
();
iter
<
grad_node
->
end
();
iter
++
)
{
for
(
auto
iter
=
grad_node
->
begin
();
iter
<
grad_node
->
end
();
iter
++
)
{
// Each OpBase
// Each OpBase
paddle
::
imperative
::
OpBase
&
op_base
=
*
iter
;
paddle
::
imperative
::
OpBase
&
op_base
=
*
iter
;
grad_node_default_attr_maps
->
push_back
(
op_base
.
DefaultAttrsMap
());
grad_op_types
->
push_back
(
op_base
.
Type
());
grad_op_types
->
push_back
(
op_base
.
Type
());
}
}
...
@@ -538,22 +676,6 @@ static bool CollectInformationFromOpInfo(
...
@@ -538,22 +676,6 @@ static bool CollectInformationFromOpInfo(
grad_outs_slotname_map
);
grad_outs_slotname_map
);
VLOG
(
6
)
<<
"Finished Slotname Matching for Grad_Outs"
;
VLOG
(
6
)
<<
"Finished Slotname Matching for Grad_Outs"
;
/* ------ Maping forward slot name to fwd position ------ */
size_t
in_pos
=
0
;
for
(
const
auto
&
iter
:
ins
)
{
VLOG
(
6
)
<<
"Mapping input tensor: "
<<
iter
.
first
<<
" To position: "
<<
in_pos
;
(
*
fwd_inputs_name_pos_map
)[
iter
.
first
]
=
in_pos
;
in_pos
++
;
}
size_t
out_pos
=
0
;
for
(
const
auto
&
iter
:
outs
)
{
VLOG
(
6
)
<<
"Mapping output tensor: "
<<
iter
.
first
<<
" To position: "
<<
out_pos
;
(
*
fwd_outputs_name_pos_map
)[
iter
.
first
]
=
out_pos
;
out_pos
++
;
}
return
true
;
return
true
;
}
}
...
@@ -561,16 +683,13 @@ static bool CollectInformationFromOpInfo(
...
@@ -561,16 +683,13 @@ static bool CollectInformationFromOpInfo(
/* --------- CodeGen: Forward GradNode Creation ------ */
/* --------- CodeGen: Forward GradNode Creation ------ */
/* --------------------------------------------------- */
/* --------------------------------------------------- */
static
std
::
string
GenerateGradNodeCreationContent
(
static
std
::
string
GenerateGradNodeCreationContent
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
proto
::
OpProto
&
op_proto
)
{
const
std
::
string
&
op_type
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
in_vars
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
out_vars
)
{
VLOG
(
6
)
<<
"Generating GradNode Creation codes"
;
VLOG
(
6
)
<<
"Generating GradNode Creation codes"
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
// [Generation] Construct GradOpNode
// [Generation] Construct GradOpNode
// Run ComputeRequiredGrad
// Run ComputeRequiredGrad
...
@@ -578,7 +697,7 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -578,7 +697,7 @@ static std::string GenerateGradNodeCreationContent(
// then generate: "egr::AutogradMeta* p_autograd_out =
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
std
::
string
get_autograd_meta_str
=
" // Prepare Autograd Meta
\n
"
;
std
::
string
get_autograd_meta_str
=
" // Prepare Autograd Meta
\n
"
;
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
input_name
;
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
input_name
;
...
@@ -602,7 +721,7 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -602,7 +721,7 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable,
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
p_proto
.
outputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
ut_vars
)
{
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
output_name
;
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
output_name
;
...
@@ -636,8 +755,8 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -636,8 +755,8 @@ static std::string GenerateGradNodeCreationContent(
// [GradOpNode] Generation
// [GradOpNode] Generation
std
::
string
grad_node_creation_str
=
""
;
std
::
string
grad_node_creation_str
=
""
;
size_t
bwd_in_slot_num
=
o
p_proto
.
outputs
()
.
size
();
size_t
bwd_in_slot_num
=
o
ut_vars
.
size
();
size_t
bwd_out_slot_num
=
op_proto
.
inputs
()
.
size
();
size_t
bwd_out_slot_num
=
in_vars
.
size
();
const
char
*
GRAD_OP_NODE_TEMPLATE
=
const
char
*
GRAD_OP_NODE_TEMPLATE
=
" auto grad_node = std::make_shared<GradNode%s>(%d, %d);
\n
"
;
" auto grad_node = std::make_shared<GradNode%s>(%d, %d);
\n
"
;
grad_node_creation_str
+=
" // Create GradOpNode
\n
"
;
grad_node_creation_str
+=
" // Create GradOpNode
\n
"
;
...
@@ -669,7 +788,7 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -669,7 +788,7 @@ static std::string GenerateGradNodeCreationContent(
// [GradOpNode] SetGradOutMeta
// [GradOpNode] SetGradOutMeta
// [GradOpNode] Add Edges
// [GradOpNode] Add Edges
std
::
string
compute_require_grad_args
=
"trace_backward"
;
std
::
string
compute_require_grad_args
=
"trace_backward"
;
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
input_name
;
const
std
::
string
&
input_autograd_name
=
"p_autograd_"
+
input_name
;
compute_require_grad_args
+=
", &"
+
input_autograd_name
;
compute_require_grad_args
+=
", &"
+
input_autograd_name
;
...
@@ -689,7 +808,7 @@ static std::string GenerateGradNodeCreationContent(
...
@@ -689,7 +808,7 @@ static std::string GenerateGradNodeCreationContent(
// [AutogradMeta] SetOutRank
// [AutogradMeta] SetOutRank
// [AutogradMeta] SetHistory
// [AutogradMeta] SetHistory
std
::
string
pass_stop_gradient_args
=
"false"
;
std
::
string
pass_stop_gradient_args
=
"false"
;
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
p_proto
.
outputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
ut_vars
)
{
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
output_name
;
const
std
::
string
&
output_autograd_name
=
"p_autograd_"
+
output_name
;
pass_stop_gradient_args
+=
", &"
+
output_autograd_name
;
pass_stop_gradient_args
+=
", &"
+
output_autograd_name
;
...
@@ -743,8 +862,6 @@ static std::string AppendUseOp(const std::string& op_type) {
...
@@ -743,8 +862,6 @@ static std::string AppendUseOp(const std::string& op_type) {
/* --------- CodeGen: Forward ----- */
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
/* -------------------------------- */
static
std
::
pair
<
std
::
string
,
std
::
string
>
GenerateForwardFunctionContents
(
static
std
::
pair
<
std
::
string
,
std
::
string
>
GenerateForwardFunctionContents
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
...
@@ -758,7 +875,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -758,7 +875,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
std
::
string
,
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
grad_outs
,
grad_outs
,
const
proto
::
OpProto
&
op_proto
)
{
const
std
::
string
&
op_type
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
in_vars
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
out_vars
)
{
/*
/*
// Forward Function Example:
// Forward Function Example:
std::tuple<vector<Tensor>, Tensor, vector<Tensor>>
std::tuple<vector<Tensor>, Tensor, vector<Tensor>>
...
@@ -779,6 +897,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -779,6 +897,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
,ConstructDuplicableOutput(Out1Num)} };
,ConstructDuplicableOutput(Out1Num)} };
// According to op_proto->attrs()
// According to op_proto->attrs()
egr::legacy::RunOp("op_type", ins, outs, attr_map,
egr::legacy::RunOp("op_type", ins, outs, attr_map,
Controller.Instance().GetExpectedPlace(), {});
Controller.Instance().GetExpectedPlace(), {});
...
@@ -795,8 +914,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -795,8 +914,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
*/
*/
VLOG
(
6
)
<<
"Generating Dygraph Forward Function"
;
VLOG
(
6
)
<<
"Generating Dygraph Forward Function"
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
string
generated_function_body
=
""
;
std
::
string
generated_function_body
=
""
;
std
::
string
dygraph_function_args_str
=
""
;
std
::
string
dygraph_function_args_str
=
""
;
...
@@ -806,8 +923,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -806,8 +923,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Ins Map
// [Generation] Get Ins Map
std
::
string
ins_contents_str
=
""
;
std
::
string
ins_contents_str
=
""
;
std
::
vector
<
std
::
string
>
input_args_str_list
(
op_proto
.
inputs
()
.
size
());
std
::
vector
<
std
::
string
>
input_args_str_list
(
in_vars
.
size
());
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
const
std
::
string
&
input_name
=
input
.
name
();
const
std
::
string
&
input_name
=
input
.
name
();
size_t
input_position
=
fwd_inputs_name_pos_map
.
at
(
input_name
);
size_t
input_position
=
fwd_inputs_name_pos_map
.
at
(
input_name
);
if
(
input
.
duplicable
())
{
if
(
input
.
duplicable
())
{
...
@@ -848,7 +965,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -848,7 +965,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Outs Map
// [Generation] Get Outs Map
std
::
string
outs_contents_str
=
""
;
std
::
string
outs_contents_str
=
""
;
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
p_proto
.
outputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
ut_vars
)
{
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_name
=
output
.
name
();
std
::
string
outnum
=
"1"
;
std
::
string
outnum
=
"1"
;
if
(
output
.
duplicable
())
{
if
(
output
.
duplicable
())
{
...
@@ -898,17 +1015,17 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -898,17 +1015,17 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
" egr::Controller::Instance().GetExpectedPlace(),
\n
"
" egr::Controller::Instance().GetExpectedPlace(),
\n
"
" &default_attrs, true, {});
\n
"
;
" &default_attrs, true, {});
\n
"
;
std
::
string
trace_op_str
=
std
::
string
trace_op_str
=
paddle
::
string
::
Sprintf
(
FWD_TRACE_OP_TEMPLATE
,
op_
proto
.
type
()
);
paddle
::
string
::
Sprintf
(
FWD_TRACE_OP_TEMPLATE
,
op_
type
);
generated_function_body
+=
trace_op_str
;
generated_function_body
+=
trace_op_str
;
generated_function_body
+=
"
\n
"
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated AttrMap & TraceOp"
;
VLOG
(
6
)
<<
"Generated AttrMap & TraceOp"
;
// [Generation] Convert output VarBase to Vector/Tensor
// [Generation] Convert output VarBase to Vector/Tensor
size_t
output_size
=
o
p_proto
.
outputs
()
.
size
();
size_t
output_size
=
o
ut_vars
.
size
();
std
::
vector
<
std
::
string
>
return_contents
(
output_size
);
std
::
vector
<
std
::
string
>
return_contents
(
output_size
);
std
::
vector
<
std
::
string
>
return_types
(
output_size
);
std
::
vector
<
std
::
string
>
return_types
(
output_size
);
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
p_proto
.
outputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
ut_vars
)
{
const
std
::
string
&
output_name
=
output
.
name
();
const
std
::
string
&
output_name
=
output
.
name
();
std
::
string
out_tensor_str
;
std
::
string
out_tensor_str
;
size_t
return_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
size_t
return_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
...
@@ -937,8 +1054,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -937,8 +1054,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] ComputeRequireGrad -> GradNodeCreation
// [Generation] ComputeRequireGrad -> GradNodeCreation
std
::
string
grad_node_creation_body_str
=
GenerateGradNodeCreationContent
(
std
::
string
grad_node_creation_body_str
=
GenerateGradNodeCreationContent
(
grad_node_default_attr_maps
,
fwd_in
puts_name_pos_map
,
fwd_inputs_name_pos_map
,
fwd_out
puts_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
op_proto
);
grad_ins_fwd_slotname_map
,
op_type
,
in_vars
,
out_vars
);
generated_function_body
+=
grad_node_creation_body_str
;
generated_function_body
+=
grad_node_creation_body_str
;
generated_function_body
+=
"
\n
"
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated GradNode Creation codes"
;
VLOG
(
6
)
<<
"Generated GradNode Creation codes"
;
...
@@ -1004,8 +1121,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
...
@@ -1004,8 +1121,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
/* --------- CodeGen: GradNode::operator() ------ */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
/* ---------------------------------------------- */
static
std
::
string
GenerateGradNodeCCContents
(
static
std
::
string
GenerateGradNodeCCContents
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
vector
<
std
::
string
>&
grad_op_types
,
const
std
::
vector
<
std
::
string
>&
grad_op_types
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
...
@@ -1020,7 +1135,8 @@ static std::string GenerateGradNodeCCContents(
...
@@ -1020,7 +1135,8 @@ static std::string GenerateGradNodeCCContents(
std
::
string
,
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
grad_outs
,
grad_outs
,
const
proto
::
OpProto
&
op_proto
)
{
const
std
::
string
&
op_type
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
in_vars
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
out_vars
)
{
VLOG
(
6
)
<<
"Generating Grad Node CC"
;
VLOG
(
6
)
<<
"Generating Grad Node CC"
;
/* [Outline]
/* [Outline]
...
@@ -1066,7 +1182,6 @@ static std::string GenerateGradNodeCCContents(
...
@@ -1066,7 +1182,6 @@ static std::string GenerateGradNodeCCContents(
}
}
*/
*/
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
string
generated_grad_function_body
=
""
;
std
::
string
generated_grad_function_body
=
""
;
// [Generation] Get Tracer
// [Generation] Get Tracer
...
@@ -1122,7 +1237,7 @@ static std::string GenerateGradNodeCCContents(
...
@@ -1122,7 +1237,7 @@ static std::string GenerateGradNodeCCContents(
// [Generation] Get Outs Map
// [Generation] Get Outs Map
std
::
unordered_set
<
std
::
string
>
duplicable_input_name_set
;
std
::
unordered_set
<
std
::
string
>
duplicable_input_name_set
;
for
(
const
auto
&
in
:
op_proto
.
inputs
()
)
{
for
(
const
auto
&
in
:
in_vars
)
{
if
(
in
.
duplicable
())
duplicable_input_name_set
.
insert
(
in
.
name
());
if
(
in
.
duplicable
())
duplicable_input_name_set
.
insert
(
in
.
name
());
}
}
...
@@ -1173,7 +1288,7 @@ static std::string GenerateGradNodeCCContents(
...
@@ -1173,7 +1288,7 @@ static std::string GenerateGradNodeCCContents(
// [Generation] Get Attrs Map
// [Generation] Get Attrs Map
std
::
string
trace_opbase_str
=
""
;
std
::
string
trace_opbase_str
=
""
;
for
(
size_t
i
=
0
;
i
<
grad_
node_default_attr_map
s
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
grad_
op_type
s
.
size
();
i
++
)
{
const
std
::
string
&
op_base_type
=
grad_op_types
[
i
];
const
std
::
string
&
op_base_type
=
grad_op_types
[
i
];
const
char
*
TRACE_OP_TEMPLATE
=
const
char
*
TRACE_OP_TEMPLATE
=
...
@@ -1230,10 +1345,9 @@ static std::string GenerateGradNodeCCContents(
...
@@ -1230,10 +1345,9 @@ static std::string GenerateGradNodeCCContents(
/* --------- CodeGen: GradNode Header ------ */
/* --------- CodeGen: GradNode Header ------ */
/* ----------------------------------------- */
/* ----------------------------------------- */
static
std
::
string
GenerateGradNodeHeaderContents
(
static
std
::
string
GenerateGradNodeHeaderContents
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
proto
::
OpProto
&
op_proto
)
{
const
std
::
string
&
op_type
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
in_vars
,
const
std
::
vector
<
proto
::
OpProto
::
Var
>&
out_vars
)
{
VLOG
(
6
)
<<
"Generating Grad Node Header"
;
VLOG
(
6
)
<<
"Generating Grad Node Header"
;
const
char
*
GRAD_NODE_TEMPLATE
=
const
char
*
GRAD_NODE_TEMPLATE
=
...
@@ -1261,8 +1375,6 @@ static std::string GenerateGradNodeHeaderContents(
...
@@ -1261,8 +1375,6 @@ static std::string GenerateGradNodeHeaderContents(
"%s
\n
"
"%s
\n
"
"};"
;
"};"
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
// [Generation] Handle Attributes
// [Generation] Handle Attributes
std
::
string
set_attr_map_str
=
std
::
string
set_attr_map_str
=
" void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {
\n
"
" void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {
\n
"
...
@@ -1279,12 +1391,12 @@ static std::string GenerateGradNodeHeaderContents(
...
@@ -1279,12 +1391,12 @@ static std::string GenerateGradNodeHeaderContents(
// [Generation] Handle TensorWrappers
// [Generation] Handle TensorWrappers
std
::
unordered_set
<
std
::
string
>
duplicable_tensors
;
std
::
unordered_set
<
std
::
string
>
duplicable_tensors
;
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
input
:
in_vars
)
{
if
(
input
.
duplicable
())
{
if
(
input
.
duplicable
())
{
duplicable_tensors
.
insert
(
input
.
name
());
duplicable_tensors
.
insert
(
input
.
name
());
}
}
}
}
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
p_proto
.
outputs
()
)
{
for
(
const
proto
::
OpProto
::
Var
&
output
:
o
ut_vars
)
{
if
(
output
.
duplicable
())
{
if
(
output
.
duplicable
())
{
duplicable_tensors
.
insert
(
output
.
name
());
duplicable_tensors
.
insert
(
output
.
name
());
}
}
...
@@ -1454,13 +1566,12 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
...
@@ -1454,13 +1566,12 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* ----------------------------- */
/* ----------------------------- */
/* ---- Collect Information ---- */
/* ---- Collect Information ---- */
/* ----------------------------- */
/* ----------------------------- */
std
::
vector
<
paddle
::
framework
::
AttributeMap
>
grad_node_default_attr_maps
;
std
::
vector
<
std
::
string
>
grad_op_types
;
std
::
vector
<
std
::
string
>
grad_op_types
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_inputs_name_pos_map
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_outputs_name_pos_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_outs_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_outs_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_fwd_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_fwd_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_grad_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_grad_slotname_map
;
std
::
vector
<
proto
::
OpProto
::
Var
>
in_vars
;
std
::
vector
<
proto
::
OpProto
::
Var
>
out_vars
;
std
::
map
<
std
::
string
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>
grad_ins
;
grad_ins
;
...
@@ -1470,13 +1581,20 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
...
@@ -1470,13 +1581,20 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG
(
6
)
<<
"-------- CollectInformationFromOpInfo -------"
;
VLOG
(
6
)
<<
"-------- CollectInformationFromOpInfo -------"
;
bool
is_available
=
CollectInformationFromOpInfo
(
bool
is_available
=
CollectInformationFromOpInfo
(
op_info
,
&
grad_node_default_attr_maps
,
&
grad_op_types
,
op_info
,
&
grad_op_types
,
&
grad_outs_slotname_map
,
&
fwd_inputs_name_pos_map
,
&
fwd_outputs_name_pos_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
in_vars
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
out_vars
,
&
grad_ins
,
&
grad_outs
);
&
grad_ins_grad_slotname_map
,
&
grad_ins
,
&
grad_outs
);
if
(
!
is_available
)
continue
;
if
(
!
is_available
)
continue
;
VLOG
(
6
)
<<
"-------- PurifyOpProto -------"
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_inputs_name_pos_map
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_outputs_name_pos_map
;
PurifyOpProto
(
*
op_proto
,
&
fwd_inputs_name_pos_map
,
&
fwd_outputs_name_pos_map
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
in_vars
,
&
out_vars
,
&
grad_ins
,
&
grad_outs
);
/* --------------------------- */
/* --------------------------- */
/* --------- CodeGen --------- */
/* --------- CodeGen --------- */
/* --------------------------- */
/* --------------------------- */
...
@@ -1484,10 +1602,10 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
...
@@ -1484,10 +1602,10 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG
(
6
)
<<
"-------- GenerateForwardFunctionContents -------"
;
VLOG
(
6
)
<<
"-------- GenerateForwardFunctionContents -------"
;
std
::
pair
<
std
::
string
,
std
::
string
>
body_and_declaration
=
std
::
pair
<
std
::
string
,
std
::
string
>
body_and_declaration
=
GenerateForwardFunctionContents
(
GenerateForwardFunctionContents
(
grad_node_default_attr_maps
,
fwd_in
puts_name_pos_map
,
fwd_inputs_name_pos_map
,
fwd_out
puts_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fw
d_slotname_map
,
grad_ins_fwd_slotname_map
,
grad_ins_gra
d_slotname_map
,
grad_
ins_grad_slotname_map
,
grad_outs_slotname_map
,
grad_in
s
,
grad_
outs_slotname_map
,
grad_ins
,
grad_outs
,
op_type
,
in_var
s
,
grad_outs
,
*
op_proto
);
out_vars
);
std
::
string
fwd_function_str
=
body_and_declaration
.
first
;
std
::
string
fwd_function_str
=
body_and_declaration
.
first
;
GenerateForwardDygraphFile
(
op_type
,
output_dir
,
fwd_function_str
);
GenerateForwardDygraphFile
(
op_type
,
output_dir
,
fwd_function_str
);
...
@@ -1498,16 +1616,16 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
...
@@ -1498,16 +1616,16 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* ---- xxx_node.h ---- */
/* ---- xxx_node.h ---- */
VLOG
(
6
)
<<
"-------- GenerateGradNodeHeaderContents -------"
;
VLOG
(
6
)
<<
"-------- GenerateGradNodeHeaderContents -------"
;
std
::
string
grad_node_h_str
=
GenerateGradNodeHeaderContents
(
std
::
string
grad_node_h_str
=
GenerateGradNodeHeaderContents
(
grad_
node_default_attr_maps
,
grad_ins_fwd_slotname_map
,
*
op_proto
);
grad_
ins_fwd_slotname_map
,
op_type
,
in_vars
,
out_vars
);
GenerateNodeHFile
(
op_type
,
output_dir
,
grad_node_h_str
);
GenerateNodeHFile
(
op_type
,
output_dir
,
grad_node_h_str
);
/* ---- xxx_node.cc ---- */
/* ---- xxx_node.cc ---- */
VLOG
(
6
)
<<
"-------- GenerateGradNodeCCContents -------"
;
VLOG
(
6
)
<<
"-------- GenerateGradNodeCCContents -------"
;
std
::
string
grad_node_cc_str
=
GenerateGradNodeCCContents
(
std
::
string
grad_node_cc_str
=
GenerateGradNodeCCContents
(
grad_
node_default_attr_maps
,
grad_op_types
,
fwd_in
puts_name_pos_map
,
grad_
op_types
,
fwd_inputs_name_pos_map
,
fwd_out
puts_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fw
d_slotname_map
,
grad_ins_fwd_slotname_map
,
grad_ins_gra
d_slotname_map
,
grad_
ins_grad_slotname_map
,
grad_outs_slotname_map
,
grad_ins
,
grad_out
s
,
grad_
outs_slotname_map
,
grad_ins
,
grad_outs
,
op_type
,
in_var
s
,
*
op_proto
);
out_vars
);
GenerateNodeCCFile
(
op_type
,
output_dir
,
grad_node_cc_str
);
GenerateNodeCCFile
(
op_type
,
output_dir
,
grad_node_cc_str
);
VLOG
(
6
)
<<
op_type
<<
": Finished Generation"
;
VLOG
(
6
)
<<
op_type
<<
": Finished Generation"
;
...
...
paddle/fluid/pybind/op_function_generator.cc
浏览文件 @
06c3cce9
...
@@ -12,6 +12,8 @@
...
@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/fluid/pybind/op_function_generator.h"
#include <algorithm>
#include <algorithm>
#include <fstream>
#include <fstream>
#include <iostream>
#include <iostream>
...
@@ -30,108 +32,6 @@
...
@@ -30,108 +32,6 @@
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#include "paddle/fluid/framework/fleet/ascend_wrapper.h"
#endif
#endif
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// determined by the OP`s proto automatically, i.e., all the inputs registered
// in OpMaker.
// However, some OPs have dispensable inputs, which means the input can
// be none for some conditions. It is discovered that most dispensable inputs
// is not used in imperative mode, so we drop those inputs when generating OP
// functions. While, for very few OPs, the dispensable inputs are used, we
// need to manually specify them in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_ins_map
=
{
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"bincount"
,
{
"X"
,
"Weights"
}},
{
"fused_attention"
,
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"SrcMask"
,
"OutLinearW"
,
"OutLinearBias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"instance_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
{
"assign"
,
{
"X"
}},
{
"reshape2"
,
{
"X"
,
"Shape"
}},
{
"expand"
,
{
"X"
,
"ExpandTimes"
}},
{
"slice"
,
{
"Input"
,
"StartsTensor"
,
"EndsTensor"
}},
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"X"
,
"InScale"
,
"InAccum"
,
"InState"
}},
{
"nll_loss"
,
{
"X"
,
"Label"
,
"Weight"
}},
{
"bilinear_tensor_product"
,
{
"X"
,
"Y"
,
"Weight"
,
"Bias"
}},
{
"gather"
,
{
"X"
,
"Index"
,
"Axis"
}},
{
"roi_pool"
,
{
"X"
,
"ROIs"
,
"RoisNum"
}},
{
"roi_align"
,
{
"X"
,
"ROIs"
,
"RoisNum"
}},
{
"psroi_pool"
,
{
"X"
,
"ROIs"
,
"RoisNum"
}},
{
"collect_fpn_proposals"
,
{
"MultiLevelRois"
,
"MultiLevelScores"
,
"MultiLevelRoIsNum"
}},
{
"distribute_fpn_proposals"
,
{
"FpnRois"
,
"RoisNum"
}},
{
"warpctc"
,
{
"Logits"
,
"Label"
,
"LogitsLength"
,
"LabelLength"
}},
{
"hierarchical_sigmoid"
,
{
"X"
,
"W"
,
"Label"
,
"PathTable"
,
"PathCode"
,
"Bias"
}},
{
"moving_average_abs_max_scale"
,
{
"X"
,
"InAccum"
,
"InState"
}},
{
"multiclass_nms3"
,
{
"BBoxes"
,
"Scores"
,
"RoisNum"
}},
{
"box_coder"
,
{
"PriorBox"
,
"PriorBoxVar"
,
"TargetBox"
}},
{
"momentum"
,
{
"Param"
,
"Grad"
,
"Velocity"
,
"LearningRate"
,
"MasterParam"
}},
{
"sparse_momentum"
,
{
"Param"
,
"Grad"
,
"Velocity"
,
"Index"
,
"LearningRate"
}},
{
"rnn"
,
{
"Input"
,
"PreState"
,
"WeightList"
,
"SequenceLength"
}},
{
"run_program"
,
{
"X"
,
"Params"
}},
{
"fused_feedforward"
,
{
"Dropout1Seed"
,
"Dropout2Seed"
,
"Linear1Bias"
,
"Linear2Bias"
,
"Ln1Scale"
,
"Ln1Bias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"faster_tokenizer"
,
{
"Text"
,
"Vocab"
,
"TextPair"
}},
{
"matrix_rank"
,
{
"X"
,
"TolTensor"
}},
{
"adam"
,
{
"Param"
,
"Grad"
,
"LearningRate"
,
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
,
"MasterParam"
}},
{
"adamw"
,
{
"Param"
,
"Grad"
,
"LearningRate"
,
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
,
"MasterParam"
}},
};
// NOTE(zhiqiu): Like op_ins_map.
// Commonly, the outputs in auto-generated OP function are determined by the
// OP`s proto automatically, i.e., all the outputs registered in OpMaker.
// However, some OPs have dispensable outputs, which means the output can
// be none for some conditions. It is discovered that most dispensable outputs
// is not used in imperative mode, so we drop those outputs when generating OP
// functions. While, for very few OPs, the dispensable outputs are used, we
// need to manually specify them in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_outs_map
=
{
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"Out"
,
"OutScale"
,
"OutAccum"
,
"OutState"
}},
{
"batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
{
"fused_attention"
,
{
"LnMean"
,
"LnVariance"
,
"LnOut"
,
"QKVOut"
,
"QKVBiasOut"
,
"TransposeOut2"
,
"QKOut"
,
"QKTVOut"
,
"SoftmaxOut"
,
"AttnDropoutMaskOut"
,
"AttnDropoutOut"
,
"SrcMaskOut"
,
"FMHAOut"
,
"OutLinearOut"
,
"DropoutMaskOut"
,
"Ln2Mean"
,
"Ln2Variance"
,
"BiasDropoutResidualOut"
,
"Y"
}},
{
"sync_batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
{
"unique"
,
{
"Out"
,
"Index"
,
"Indices"
,
"Counts"
}},
{
"unique_consecutive"
,
{
"Out"
,
"Index"
,
"Counts"
}},
{
"generate_proposals"
,
{
"RpnRois"
,
"RpnRoiProbs"
,
"RpnRoisNum"
}},
{
"collect_fpn_proposals"
,
{
"FpnRois"
,
"RoisNum"
}},
{
"matrix_nms"
,
{
"Out"
,
"Index"
,
"RoisNum"
}},
{
"distribute_fpn_proposals"
,
{
"MultiFpnRois"
,
"RestoreIndex"
,
"MultiLevelRoIsNum"
}},
{
"moving_average_abs_max_scale"
,
{
"Out"
,
"OutScale"
,
"OutAccum"
,
"OutState"
}},
{
"multiclass_nms3"
,
{
"Out"
,
"NmsRoisNum"
}},
{
"generate_proposals_v2"
,
{
"RpnRois"
,
"RpnRoiProbs"
,
"RpnRoisNum"
}},
{
"momentum"
,
{
"ParamOut"
,
"VelocityOut"
,
"MasterParamOut"
}},
{
"sparse_momentum"
,
{
"ParamOut"
,
"VelocityOut"
}},
{
"rnn"
,
{
"DropoutState"
,
"Reserve"
,
"Out"
,
"State"
}},
{
"lamb"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
}},
{
"run_program"
,
{
"DOut"
}},
{
"adam"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
{
"adamw"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
// generated in C++ automatically.
// generated in C++ automatically.
// However, some OPs need to pass the outputs from Python instead of generating
// However, some OPs need to pass the outputs from Python instead of generating
...
...
paddle/fluid/pybind/op_function_generator.h
0 → 100644
浏览文件 @
06c3cce9
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <set>
#include <string>
// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are
// determined by the OP`s proto automatically, i.e., all the inputs registered
// in OpMaker.
// However, some OPs have dispensable inputs, which means the input can
// be none for some conditions. It is discovered that most dispensable inputs
// is not used in imperative mode, so we drop those inputs when generating OP
// functions. While, for very few OPs, the dispensable inputs are used, we
// need to manually specify them in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_ins_map
=
{
{
"layer_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"bincount"
,
{
"X"
,
"Weights"
}},
{
"fused_attention"
,
{
"X"
,
"LnScale"
,
"LnBias"
,
"QKVW"
,
"QKVBias"
,
"SrcMask"
,
"OutLinearW"
,
"OutLinearBias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"instance_norm"
,
{
"X"
,
"Scale"
,
"Bias"
}},
{
"gru_unit"
,
{
"Input"
,
"HiddenPrev"
,
"Weight"
,
"Bias"
}},
{
"label_smooth"
,
{
"X"
,
"PriorDist"
}},
{
"assign"
,
{
"X"
}},
{
"reshape2"
,
{
"X"
,
"Shape"
}},
{
"expand"
,
{
"X"
,
"ExpandTimes"
}},
{
"slice"
,
{
"Input"
,
"StartsTensor"
,
"EndsTensor"
}},
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"X"
,
"InScale"
,
"InAccum"
,
"InState"
}},
{
"nll_loss"
,
{
"X"
,
"Label"
,
"Weight"
}},
{
"bilinear_tensor_product"
,
{
"X"
,
"Y"
,
"Weight"
,
"Bias"
}},
{
"gather"
,
{
"X"
,
"Index"
,
"Axis"
}},
{
"roi_pool"
,
{
"X"
,
"ROIs"
,
"RoisNum"
}},
{
"roi_align"
,
{
"X"
,
"ROIs"
,
"RoisNum"
}},
{
"psroi_pool"
,
{
"X"
,
"ROIs"
,
"RoisNum"
}},
{
"collect_fpn_proposals"
,
{
"MultiLevelRois"
,
"MultiLevelScores"
,
"MultiLevelRoIsNum"
}},
{
"distribute_fpn_proposals"
,
{
"FpnRois"
,
"RoisNum"
}},
{
"warpctc"
,
{
"Logits"
,
"Label"
,
"LogitsLength"
,
"LabelLength"
}},
{
"hierarchical_sigmoid"
,
{
"X"
,
"W"
,
"Label"
,
"PathTable"
,
"PathCode"
,
"Bias"
}},
{
"moving_average_abs_max_scale"
,
{
"X"
,
"InAccum"
,
"InState"
}},
{
"multiclass_nms3"
,
{
"BBoxes"
,
"Scores"
,
"RoisNum"
}},
{
"box_coder"
,
{
"PriorBox"
,
"PriorBoxVar"
,
"TargetBox"
}},
{
"momentum"
,
{
"Param"
,
"Grad"
,
"Velocity"
,
"LearningRate"
,
"MasterParam"
}},
{
"sparse_momentum"
,
{
"Param"
,
"Grad"
,
"Velocity"
,
"Index"
,
"LearningRate"
}},
{
"rnn"
,
{
"Input"
,
"PreState"
,
"WeightList"
,
"SequenceLength"
}},
{
"run_program"
,
{
"X"
,
"Params"
}},
{
"fused_feedforward"
,
{
"Dropout1Seed"
,
"Dropout2Seed"
,
"Linear1Bias"
,
"Linear2Bias"
,
"Ln1Scale"
,
"Ln1Bias"
,
"Ln2Scale"
,
"Ln2Bias"
}},
{
"faster_tokenizer"
,
{
"Text"
,
"Vocab"
,
"TextPair"
}},
{
"matrix_rank"
,
{
"X"
,
"TolTensor"
}},
{
"adam"
,
{
"Param"
,
"Grad"
,
"LearningRate"
,
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
,
"MasterParam"
}},
{
"adamw"
,
{
"Param"
,
"Grad"
,
"LearningRate"
,
"Moment1"
,
"Moment2"
,
"Beta1Pow"
,
"Beta2Pow"
,
"MasterParam"
}},
};
// NOTE(zhiqiu): Like op_ins_map.
// Commonly, the outputs in auto-generated OP function are determined by the
// OP`s proto automatically, i.e., all the outputs registered in OpMaker.
// However, some OPs have dispensable outputs, which means the output can
// be none for some conditions. It is discovered that most dispensable outputs
// is not used in imperative mode, so we drop those outputs when generating OP
// functions. While, for very few OPs, the dispensable outputs are used, we
// need to manually specify them in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_outs_map
=
{
{
"fake_quantize_dequantize_moving_average_abs_max"
,
{
"Out"
,
"OutScale"
,
"OutAccum"
,
"OutState"
}},
{
"batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
{
"fused_attention"
,
{
"LnMean"
,
"LnVariance"
,
"LnOut"
,
"QKVOut"
,
"QKVBiasOut"
,
"TransposeOut2"
,
"QKOut"
,
"QKTVOut"
,
"SoftmaxOut"
,
"AttnDropoutMaskOut"
,
"AttnDropoutOut"
,
"SrcMaskOut"
,
"FMHAOut"
,
"OutLinearOut"
,
"DropoutMaskOut"
,
"Ln2Mean"
,
"Ln2Variance"
,
"BiasDropoutResidualOut"
,
"Y"
}},
{
"sync_batch_norm"
,
{
"Y"
,
"MeanOut"
,
"VarianceOut"
,
"SavedMean"
,
"SavedVariance"
,
"ReserveSpace"
}},
{
"unique"
,
{
"Out"
,
"Index"
,
"Indices"
,
"Counts"
}},
{
"unique_consecutive"
,
{
"Out"
,
"Index"
,
"Counts"
}},
{
"generate_proposals"
,
{
"RpnRois"
,
"RpnRoiProbs"
,
"RpnRoisNum"
}},
{
"collect_fpn_proposals"
,
{
"FpnRois"
,
"RoisNum"
}},
{
"matrix_nms"
,
{
"Out"
,
"Index"
,
"RoisNum"
}},
{
"distribute_fpn_proposals"
,
{
"MultiFpnRois"
,
"RestoreIndex"
,
"MultiLevelRoIsNum"
}},
{
"moving_average_abs_max_scale"
,
{
"Out"
,
"OutScale"
,
"OutAccum"
,
"OutState"
}},
{
"multiclass_nms3"
,
{
"Out"
,
"NmsRoisNum"
}},
{
"generate_proposals_v2"
,
{
"RpnRois"
,
"RpnRoiProbs"
,
"RpnRoisNum"
}},
{
"momentum"
,
{
"ParamOut"
,
"VelocityOut"
,
"MasterParamOut"
}},
{
"sparse_momentum"
,
{
"ParamOut"
,
"VelocityOut"
}},
{
"rnn"
,
{
"DropoutState"
,
"Reserve"
,
"Out"
,
"State"
}},
{
"lamb"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
}},
{
"run_program"
,
{
"DOut"
}},
{
"adam"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
{
"adamw"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
};
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录