Commit de874cdd (unverified)
Authored Dec 07, 2021 by Zhanlue Yang; committed via GitHub on Dec 07, 2021
Enabled generation for special operators, the GradNode/Inputs/Outputs of which are empty (#37837)
Parent: 27d1f811

Showing 2 changed files with 200 additions and 144 deletions (+200 −144)
paddle/fluid/eager/auto_code_generator/eager_generator.cc  +200 −143
paddle/fluid/eager/auto_code_generator/op_list.txt  +0 −1
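At a glance: before this change the generator dropped any operator for which grad-node information could not be collected; now such operators degrade to forward-only generation. A minimal sketch of the new control flow, using a toy stand-in struct rather than paddle::framework::OpInfo:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct OpInfo { bool has_grad_op_maker; };  // stand-in, not Paddle's OpInfo

// Mirrors the new out-parameter pattern: when no GradOpMaker (or an
// EmptyGradOpMaker) is found, request forward-only codegen instead of skipping.
static bool CollectGradInformation(const OpInfo& op, bool* generate_forward_only) {
  if (!op.has_grad_op_maker) {
    *generate_forward_only = true;
    return false;
  }
  return true;
}

int main() {
  std::vector<std::pair<std::string, OpInfo>> ops = {
      {"matmul_v2", {true}}, {"minus", {false}}};
  for (const auto& op : ops) {
    bool generate_forward_only = false;
    bool is_available = CollectGradInformation(op.second, &generate_forward_only);
    if (!is_available && !generate_forward_only) continue;  // truly skipped
    std::cout << op.first << ": forward generated\n";
    if (generate_forward_only) continue;  // empty GradNode: no grad codegen
    std::cout << op.first << ": grad node generated\n";
  }
}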
paddle/fluid/eager/auto_code_generator/eager_generator.cc
...
@@ -12,7 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <gflags/gflags.h>
 #include <algorithm>
 #include <fstream>
 #include <iostream>
...
@@ -27,69 +26,21 @@
 #include "paddle/fluid/pybind/pybind.h"
 #include "paddle/fluid/string/string_helper.h"
 
-DEFINE_bool(generate_all, false,
-            "Generate all operators currently registered in Paddle");
+namespace paddle {
+namespace framework {
 
 static std::unordered_map<std::string, paddle::framework::AttributeMap>
     operators_with_attrs = {};
 
 static std::unordered_set<std::string> operators_to_skip = {
-    "chunk_eval",  // Stupid tensor name
-    "minus",
     "pull_sparse",     "pull_box_extended_sparse", "pull_sparse_v2",
     "pull_box_sparse", "fused_attention",          "diag_v2",
     "c_split"};
 
 static std::unordered_set<std::string> operators_to_codegen = {};
 static std::unordered_set<std::string> skipped_operators = {};
 
-static void PrepareAttrMapForOps() {
-  // Handle "fused_elemwise_add_activation"
-  std::vector<std::string> functor_list = {"a", "b"};
-  operators_with_attrs["fused_elemwise_add_activation"] = {};
-  operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
-      functor_list;
-
-  // Handle "fused_elemwise_activation"
-  operators_with_attrs["fused_elemwise_activation"] = {};
-  operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
-      functor_list;
-
-  // Handle "reverse"
-  std::vector<int> axis = {0};
-  operators_with_attrs["reverse"] = {};
-  operators_with_attrs["reverse"]["axis"] = axis;
-
-  // Handle "flip"
-  operators_with_attrs["flip"] = {};
-  operators_with_attrs["flip"]["axis"] = axis;
-
-  // Handle "cast"
-  operators_with_attrs["cast"] = {};
-  operators_with_attrs["cast"]["out_dtype"] = 5;
-  operators_with_attrs["cast"]["in_dtype"] = 5;
-
-  // Handle "transfer_dtype"
-  operators_with_attrs["transfer_dtype"] = {};
-  operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
-  operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
-}
-
-static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
-  std::string line;
-  std::ifstream op_list_file(op_list_path);
-  if (op_list_file.is_open()) {
-    while (getline(op_list_file, line)) {
-      operators_to_codegen.insert(line);
-    }
-    op_list_file.close();
-  } else {
-    PADDLE_THROW(
-        paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
-  }
-}
-
-namespace paddle {
-namespace framework {
-
 static std::string AttrTypeToString(const proto::AttrType& type) {
   std::string ret;
   switch (type) {
...
@@ -392,10 +343,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
   // Only handle matmul_v2 for now
   VLOG(1) << "------ Analyzing Op ------: " << op_type;
 
-  if (!FLAGS_generate_all) {
-    if (!operators_to_codegen.count(op_type)) return false;
-  }
+  if (!operators_to_codegen.count(op_type)) return false;
 
   if (operators_to_skip.count(op_type)) return false;
 
   return true;
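With FLAGS_generate_all gone, membership in op_list.txt is the sole opt-in and operators_to_skip the sole opt-out. A self-contained sketch of the resulting check (example op names are illustrative):

#include <iostream>
#include <string>
#include <unordered_set>

static std::unordered_set<std::string> operators_to_codegen = {"matmul_v2"};
static std::unordered_set<std::string> operators_to_skip = {"pull_sparse"};

static bool ShouldCodegen(const std::string& op_type) {
  if (!operators_to_codegen.count(op_type)) return false;  // not in op_list.txt
  if (operators_to_skip.count(op_type)) return false;      // hard-coded skip
  return true;
}

int main() {
  std::cout << std::boolalpha << ShouldCodegen("matmul_v2") << "\n";    // true
  std::cout << std::boolalpha << ShouldCodegen("pull_sparse") << "\n";  // false
}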
...
@@ -404,21 +352,12 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
 /* --------------------------------------- */
 /* --------- Preprocess Ins/Outs --------- */
 /* --------------------------------------- */
-static void PurifyOpProto(
+static void PurifyForwardOpProto(
     const proto::OpProto& op_proto,
     std::unordered_map<std::string, size_t>* fwd_inputs_name_pos_map,
     std::unordered_map<std::string, size_t>* fwd_outputs_name_pos_map,
-    std::map<std::string, std::string>* grad_outs_slotname_map,
-    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
-    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
     std::vector<proto::OpProto::Var>* in_vars,
-    std::vector<proto::OpProto::Var>* out_vars,
-    std::map<std::string, std::vector<std::shared_ptr<
-                 paddle::imperative::VariableWrapper>>>* grad_ins,
-    std::map<std::string, std::vector<std::shared_ptr<
-                 paddle::imperative::VariableWrapper>>>* grad_outs) {
+    std::vector<proto::OpProto::Var>* out_vars) {
   // Op Name
   const std::string op_name = op_proto.type();
...
@@ -440,6 +379,72 @@ static void PurifyOpProto(
         }
       }
       in_vars->erase(iter);
     }
   }
 
+  for (const proto::OpProto::Var& output : op_proto.outputs()) {
+    std::string output_name = output.name();
+
+    // Delete dispensable tensor unless specified in op_outs_map
+    if (output.dispensable()) {
+      if (!op_outs_map.count(op_name) ||
+          !op_outs_map[op_name].count(output_name)) {
+        VLOG(6) << "Removing Dispensable Output: " << output_name;
+
+        // out_vars
+        auto iter = out_vars->begin();
+        for (iter = out_vars->begin(); iter != out_vars->end(); iter++) {
+          if (iter->name() == output_name) {
+            break;
+          }
+        }
+        out_vars->erase(iter);
+      }
+    }
+  }
+
+  /* ------ Maping forward slot name to fwd position ------ */
+  size_t in_pos = 0;
+  for (const auto& var : *in_vars) {
+    VLOG(6) << "Mapping input tensor: " << var.name()
+            << " To position: " << in_pos;
+    (*fwd_inputs_name_pos_map)[var.name()] = in_pos;
+    in_pos++;
+  }
+
+  size_t out_pos = 0;
+  for (const auto& var : *out_vars) {
+    VLOG(6) << "Mapping output tensor: " << var.name()
+            << " To position: " << out_pos;
+    (*fwd_outputs_name_pos_map)[var.name()] = out_pos;
+    out_pos++;
+  }
+}
+
+static void PurifyGradOpProto(
+    const proto::OpProto& op_proto,
+    std::map<std::string, std::string>* grad_outs_slotname_map,
+    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
+    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
+    std::map<std::string, std::vector<std::shared_ptr<
+                 paddle::imperative::VariableWrapper>>>* grad_ins,
+    std::map<std::string, std::vector<std::shared_ptr<
+                 paddle::imperative::VariableWrapper>>>* grad_outs) {
+  // Op Name
+  const std::string op_name = op_proto.type();
+
+  // Handle dispensable inputs
+  for (const proto::OpProto::Var& input : op_proto.inputs()) {
+    std::string input_name = input.name();
+
+    // Delete dispensable tensor unless specified in op_ins_map
+    if (input.dispensable()) {
+      if (!op_ins_map.count(op_name) ||
+          !op_ins_map[op_name].count(input_name)) {
+        VLOG(6) << "Removing Dispensable Input: " << input_name;
 
         // grad_outs_slotname_map
         auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
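The name-to-position maps built at the end of PurifyForwardOpProto feed later codegen stages that index tensors by slot. A toy version of the same loop, assuming plain strings in place of proto::OpProto::Var:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<std::string> in_vars = {"X", "Y"};  // purified forward inputs
  std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map;
  size_t in_pos = 0;
  for (const auto& name : in_vars) {
    fwd_inputs_name_pos_map[name] = in_pos;  // slot name -> argument position
    in_pos++;
  }
  std::cout << "Y -> position " << fwd_inputs_name_pos_map["Y"] << "\n";  // 1
}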
...
@@ -478,15 +483,6 @@ static void PurifyOpProto(
           !op_outs_map[op_name].count(output_name)) {
         VLOG(6) << "Removing Dispensable Output: " << output_name;
 
-        // out_vars
-        auto iter = out_vars->begin();
-        for (iter = out_vars->begin(); iter != out_vars->end(); iter++) {
-          if (iter->name() == output_name) {
-            break;
-          }
-        }
-        out_vars->erase(iter);
-
         // grad_ins_grad_slotname_map
         auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map;
         for (const auto& iter : *grad_ins_grad_slotname_map) {
...
@@ -514,52 +510,40 @@ static void PurifyOpProto(
       }
     }
   }
-
-  /* ------ Maping forward slot name to fwd position ------ */
-  size_t in_pos = 0;
-  for (const auto& var : *in_vars) {
-    VLOG(6) << "Mapping input tensor: " << var.name()
-            << " To position: " << in_pos;
-    (*fwd_inputs_name_pos_map)[var.name()] = in_pos;
-    in_pos++;
-  }
-
-  size_t out_pos = 0;
-  for (const auto& var : *out_vars) {
-    VLOG(6) << "Mapping output tensor: " << var.name()
-            << " To position: " << out_pos;
-    (*fwd_outputs_name_pos_map)[var.name()] = out_pos;
-    out_pos++;
-  }
 }
 
 /* -------------------------------- */
 /* --------- Collect Info --------- */
 /* -------------------------------- */
-static bool CollectInformationFromOpInfo(
+static void CollectForwardInformationFromOpInfo(
     const paddle::framework::OpInfo& op_info,
-    std::vector<std::string>* grad_op_types,
-    std::map<std::string, std::string>* grad_outs_slotname_map,
-    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,
-    std::map<std::string, std::string>* grad_ins_grad_slotname_map,
     std::vector<proto::OpProto::Var>* in_vars,
-    std::vector<proto::OpProto::Var>* out_vars,
-    std::map<std::string, std::vector<std::shared_ptr<
-                 paddle::imperative::VariableWrapper>>>* grad_ins,
-    std::map<std::string, std::vector<std::shared_ptr<
-                 paddle::imperative::VariableWrapper>>>* grad_outs) {
+    std::vector<proto::OpProto::Var>* out_vars) {
   const proto::OpProto& op_proto = *op_info.proto_;
-  const std::string& op_type = op_proto.type();
-  std::vector<int64_t> dims = {1, 1, 1, 1};
 
   for (const proto::OpProto::Var& input : op_proto.inputs()) {
     in_vars->push_back(input);
   }
   for (const proto::OpProto::Var& output : op_proto.outputs()) {
     out_vars->push_back(output);
   }
+}
+
+static bool CollectGradInformationFromOpInfo(
+    const paddle::framework::OpInfo& op_info, bool* generate_forward_only,
+    std::vector<std::string>* grad_op_types,                         // grad
+    std::map<std::string, std::string>* grad_outs_slotname_map,      // grad
+    std::map<std::string, std::string>* grad_ins_fwd_slotname_map,   // grad
+    std::map<std::string, std::string>* grad_ins_grad_slotname_map,  // grad
+    std::map<std::string, std::vector<std::shared_ptr<
+                 paddle::imperative::VariableWrapper>>>* grad_ins,   // grad
+    std::map<std::string, std::vector<std::shared_ptr<
+                 paddle::imperative::VariableWrapper>>>* grad_outs   // grad
+    ) {
+  const proto::OpProto& op_proto = *op_info.proto_;
+  const std::string& op_type = op_proto.type();
+  std::vector<int64_t> dims = {1, 1, 1, 1};
 
   /* ------ Prepare "ins" ------ */
   std::map<std::string,
...
@@ -621,8 +605,6 @@ static bool CollectInformationFromOpInfo(
   if (operators_with_attrs.count(op_type)) {
     VLOG(6) << "Found operator " << op_type << " using special AttributeMap";
     attrs = operators_with_attrs[op_type];
-    // default_attrs.insert(operators_with_attrs[op_type].begin(),
-    //                      operators_with_attrs[op_type].end());
   }
 
   VLOG(6) << "Prepared Default Attributes Map, size = " << default_attrs.size();
...
@@ -655,8 +637,8 @@ static bool CollectInformationFromOpInfo(
   /* ------ Run GradOpMaker ------ */
   if (!op_info.dygraph_grad_op_maker_) {
-    VLOG(6) << op_type << " has no GradOpMaker, skip it";
-    skipped_operators.insert(op_type);
+    VLOG(6) << op_type << " has no GradOpMaker";
+    *generate_forward_only = true;
     return false;
   }
...
@@ -666,17 +648,19 @@ static bool CollectInformationFromOpInfo(
   if (!grad_node) {
     VLOG(6) << "Got nullptr GradOpNode for " << op_type
-            << " likely registered EmptyGradOpMaker, skip it";
-    skipped_operators.insert(op_type);
+            << " likely registered EmptyGradOpMaker";
+    *generate_forward_only = true;
     return false;
   }
 
+  /*
   if (grad_node->size() > 1) {
     // Backward attributes can be super complicated
     VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type;
     skipped_operators.insert(op_type);
     return false;
   }
+  */
 
   VLOG(6) << "Prepared GradOpNode";
...
@@ -901,6 +885,7 @@ static std::string GenerateGradNodeCreationContent(
 /* -------------------------------- */
 /* --------- CodeGen: Forward ----- */
 /* -------------------------------- */
 static std::pair<std::string, std::string> GenerateForwardFunctionContents(
+    bool generate_forward_only,
     const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
     const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
     const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
...
@@ -1044,7 +1029,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   // [Generation] Get Attrs
   dygraph_function_args_str +=
       ", const paddle::framework::AttributeMap& attr_map";
-  generated_function_body += "\n";
 
   // [Generation] Get TraceOp
   const char* FWD_TRACE_OP_TEMPLATE =
...
@@ -1092,16 +1076,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
 
   // [Generation] ComputeRequireGrad -> GradNodeCreation
+  if (!generate_forward_only) {
     std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
         fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
         grad_ins_fwd_slotname_map, op_type, in_vars, out_vars);
     generated_function_body += grad_node_creation_body_str;
     generated_function_body += "\n";
     VLOG(6) << "Generated GradNode Creation codes";
+  }
 
   // [Generation] Handle return: Tuple/Vector/Tensor
   generated_function_body += "\n";
-  std::string return_str;
+  std::string return_str = "";
   std::string return_type_str = "";
   std::string function_proto_return_type_str = "";
   if (return_contents.size() > 1) {
...
@@ -1124,14 +1110,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     const char* FWD_FUNCTION_PROTO_RETURN_TEMPLATE = "std::tuple<%s>";
     function_proto_return_type_str = paddle::string::Sprintf(
         FWD_FUNCTION_PROTO_RETURN_TEMPLATE, return_type_str);
-  } else {
+  } else if (return_contents.size() == 1) {
     // Return vector<Tensor> or Tensor
     return_type_str = return_types[0];
     const char* FWD_TENSOR_RETURN_TEMPLATE = "  return %s;";
     return_str =
         paddle::string::Sprintf(FWD_TENSOR_RETURN_TEMPLATE, return_contents[0]);
     function_proto_return_type_str = return_type_str;
+  } else {
+    return_str = "return nullptr;";
+    function_proto_return_type_str = "void*";
   }
 
   generated_function_body += return_str;
   generated_function_body += "\n";
   VLOG(6) << "Generated return codes";
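The new else branch covers operators whose purified output list is empty: the generated function now returns nullptr with a void* prototype instead of failing. A sketch of the template-fill step, substituting std::snprintf for paddle::string::Sprintf (an assumption; the real helper is also printf-style):

#include <cstdio>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> return_contents;  // empty: op has no outputs
  std::string return_str, function_proto_return_type_str;
  if (return_contents.size() == 1) {
    char buf[64];
    std::snprintf(buf, sizeof(buf), "  return %s;", return_contents[0].c_str());
    return_str = buf;
  } else if (return_contents.empty()) {
    return_str = "return nullptr;";          // forward-only, outputless op
    function_proto_return_type_str = "void*";
  }
  std::printf("%s %s\n", function_proto_return_type_str.c_str(), return_str.c_str());
}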
...
@@ -1139,6 +1131,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
   // [Generation] Get Full Function
   std::string function_name = op_type + "_dygraph_function";
 
+  if (dygraph_function_args_str.size() > 0) {
+    auto iter = dygraph_function_args_str.begin();
+    if ((*iter) == ',') dygraph_function_args_str.erase(iter);
+  }
+
   const char* FWD_FUNCTION_TEMPLATE = "%s %s(%s) {\n\n%s\n}\n\n";
   std::string fwd_function_str = paddle::string::Sprintf(
       FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name,
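The argument-string trim added here exists because each parameter fragment is appended as ", type name", so when no tensor inputs precede the AttributeMap the string starts with a stray comma. Minimal repro of the fix (note it leaves a harmless leading space):

#include <iostream>
#include <string>

int main() {
  std::string dygraph_function_args_str;
  dygraph_function_args_str += ", const paddle::framework::AttributeMap& attr_map";
  if (dygraph_function_args_str.size() > 0) {
    auto iter = dygraph_function_args_str.begin();
    if ((*iter) == ',') dygraph_function_args_str.erase(iter);  // drop comma only
  }
  std::cout << "(" << dygraph_function_args_str << ")\n";
  // prints: ( const paddle::framework::AttributeMap& attr_map)
}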
...
@@ -1601,11 +1598,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
     /* ---- Collect Information ---- */
     /* ----------------------------- */
     std::vector<std::string> grad_op_types;
+    std::vector<proto::OpProto::Var> in_vars;
+    std::vector<proto::OpProto::Var> out_vars;
     std::map<std::string, std::string> grad_outs_slotname_map;
     std::map<std::string, std::string> grad_ins_fwd_slotname_map;
     std::map<std::string, std::string> grad_ins_grad_slotname_map;
-    std::vector<proto::OpProto::Var> in_vars;
-    std::vector<proto::OpProto::Var> out_vars;
     std::map<std::string,
              std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
         grad_ins;
...
@@ -1614,20 +1611,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
         grad_outs;
 
     VLOG(6) << "-------- CollectInformationFromOpInfo -------";
-    bool is_available = CollectInformationFromOpInfo(
-        op_info, &grad_op_types, &grad_outs_slotname_map,
-        &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, &in_vars,
-        &out_vars, &grad_ins, &grad_outs);
+    CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars);
 
-    if (!is_available) continue;
+    bool generate_forward_only = false;
+    bool is_available = CollectGradInformationFromOpInfo(
+        op_info, &generate_forward_only, &grad_op_types,
+        &grad_outs_slotname_map, &grad_ins_fwd_slotname_map,
+        &grad_ins_grad_slotname_map, &grad_ins, &grad_outs);
+
+    if (!is_available && !generate_forward_only) {
+      VLOG(6) << "Skipped operator: " << op_type;
+      continue;
+    }
 
     VLOG(6) << "-------- PurifyOpProto -------";
     std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map;
     std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map;
-    PurifyOpProto(*op_proto, &fwd_inputs_name_pos_map,
-                  &fwd_outputs_name_pos_map, &grad_outs_slotname_map,
-                  &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
-                  &in_vars, &out_vars, &grad_ins, &grad_outs);
+    PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map,
+                         &fwd_outputs_name_pos_map, &in_vars, &out_vars);
+
+    if (!generate_forward_only) {
+      PurifyGradOpProto(*op_proto, &grad_outs_slotname_map,
+                        &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
+                        &grad_ins, &grad_outs);
+    }
 
     /* --------------------------- */
     /* --------- CodeGen --------- */
...
@@ -1636,16 +1644,19 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
     VLOG(6) << "-------- GenerateForwardFunctionContents -------";
     std::pair<std::string, std::string> body_and_declaration =
         GenerateForwardFunctionContents(
-            fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
-            grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map,
-            grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars,
-            out_vars);
+            generate_forward_only, fwd_inputs_name_pos_map,
+            fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
+            grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins,
+            grad_outs, op_type, in_vars, out_vars);
 
     fwd_function_str += body_and_declaration.first + "\n";
 
     /* ---- dygraph_forward_api.h ---- */
     std::string fwd_function_declare_str = body_and_declaration.second;
     dygraph_forward_api_str += fwd_function_declare_str;
 
+    if (generate_forward_only) continue;
+
     /* ---- nodes.h ---- */
     VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
     grad_node_h_str +=
...
@@ -1681,6 +1692,52 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
   GenerateNodeCCFile(output_dir, grad_node_cc_str);
 }
 
+static void PrepareAttrMapForOps() {
+  // Handle "fused_elemwise_add_activation"
+  std::vector<std::string> functor_list = {"a", "b"};
+  operators_with_attrs["fused_elemwise_add_activation"] = {};
+  operators_with_attrs["fused_elemwise_add_activation"]["functor_list"] =
+      functor_list;
+
+  // Handle "fused_elemwise_activation"
+  operators_with_attrs["fused_elemwise_activation"] = {};
+  operators_with_attrs["fused_elemwise_activation"]["functor_list"] =
+      functor_list;
+
+  // Handle "reverse"
+  std::vector<int> axis = {0};
+  operators_with_attrs["reverse"] = {};
+  operators_with_attrs["reverse"]["axis"] = axis;
+
+  // Handle "flip"
+  operators_with_attrs["flip"] = {};
+  operators_with_attrs["flip"]["axis"] = axis;
+
+  // Handle "cast"
+  operators_with_attrs["cast"] = {};
+  operators_with_attrs["cast"]["out_dtype"] = 5;
+  operators_with_attrs["cast"]["in_dtype"] = 5;
+
+  // Handle "transfer_dtype"
+  operators_with_attrs["transfer_dtype"] = {};
+  operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
+  operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
+}
+
+static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
+  std::string line;
+  std::ifstream op_list_file(op_list_path);
+  if (op_list_file.is_open()) {
+    while (getline(op_list_file, line)) {
+      operators_to_codegen.insert(line);
+    }
+    op_list_file.close();
+  } else {
+    PADDLE_THROW(
+        paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
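PrepareAttrMapForOps, now moved inside the namespace, pre-seeds attributes the dry-run tracer cannot infer; the integer 5 assigned to in_dtype/out_dtype is, to the best of my knowledge, the FP32 value of the VarType enum in framework.proto. A rough illustration of why heterogeneous values fit one map, assuming AttributeMap is essentially name -> variant (Paddle's real Attribute is a larger variant type; this is not its actual definition):

#include <iostream>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

// Simplified stand-ins for paddle::framework::Attribute / AttributeMap.
using Attribute = std::variant<int, std::vector<int>, std::vector<std::string>>;
using AttributeMap = std::unordered_map<std::string, Attribute>;

int main() {
  std::unordered_map<std::string, AttributeMap> operators_with_attrs;
  operators_with_attrs["cast"]["out_dtype"] = 5;  // dtype enum value as int
  operators_with_attrs["reverse"]["axis"] = std::vector<int>{0};
  std::cout << std::get<int>(operators_with_attrs["cast"]["out_dtype"]) << "\n";
}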
...
@@ -1693,8 +1750,8 @@ int main(int argc, char* argv[]) {
   std::string eager_root = argv[1];
   std::string op_list_path = argv[2];
 
-  CollectOperatorsToCodeGen(op_list_path);
-  PrepareAttrMapForOps();
+  paddle::framework::CollectOperatorsToCodeGen(op_list_path);
+  paddle::framework::PrepareAttrMapForOps();
 
   paddle::framework::DygraphCodeGeneration(eager_root);
...
paddle/fluid/eager/auto_code_generator/op_list.txt

...
@@ -215,7 +215,6 @@ spp
 floor
 gelu
 retinanet_detection_output
-minus
 push_dense
 silu
 sequence_erase
...