Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
de874cdd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
de874cdd
编写于
12月 07, 2021
作者:
Z
Zhanlue Yang
提交者:
GitHub
12月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enabled generation for special operators, the GradNode/Inputs/Outputs of which are empty (#37837)
上级
27d1f811
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
200 addition
and
144 deletion
+200
-144
paddle/fluid/eager/auto_code_generator/eager_generator.cc
paddle/fluid/eager/auto_code_generator/eager_generator.cc
+200
-143
paddle/fluid/eager/auto_code_generator/op_list.txt
paddle/fluid/eager/auto_code_generator/op_list.txt
+0
-1
未找到文件。
paddle/fluid/eager/auto_code_generator/eager_generator.cc
浏览文件 @
de874cdd
...
...
@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <algorithm>
#include <fstream>
#include <iostream>
...
...
@@ -27,69 +26,21 @@
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h"
DEFINE_bool
(
generate_all
,
false
,
"Generate all operators currently registered in Paddle"
);
namespace
paddle
{
namespace
framework
{
static
std
::
unordered_map
<
std
::
string
,
paddle
::
framework
::
AttributeMap
>
operators_with_attrs
=
{};
static
std
::
unordered_set
<
std
::
string
>
operators_to_skip
=
{
"pull_sparse"
,
"pull_box_extended_sparse"
,
"pull_sparse_v2"
,
"pull_box_sparse"
,
"fused_attention"
,
"diag_v2"
,
"c_split"
};
"chunk_eval"
,
// Stupid tensor name
"minus"
,
"pull_sparse"
,
"pull_box_extended_sparse"
,
"pull_sparse_v2"
,
"pull_box_sparse"
,
"fused_attention"
,
"diag_v2"
,
"c_split"
};
static
std
::
unordered_set
<
std
::
string
>
operators_to_codegen
=
{};
static
std
::
unordered_set
<
std
::
string
>
skipped_operators
=
{};
static
void
PrepareAttrMapForOps
()
{
// Handle "fused_elemwise_add_activation"
std
::
vector
<
std
::
string
>
functor_list
=
{
"a"
,
"b"
};
operators_with_attrs
[
"fused_elemwise_add_activation"
]
=
{};
operators_with_attrs
[
"fused_elemwise_add_activation"
][
"functor_list"
]
=
functor_list
;
// Handle "fused_elemwise_activation"
operators_with_attrs
[
"fused_elemwise_activation"
]
=
{};
operators_with_attrs
[
"fused_elemwise_activation"
][
"functor_list"
]
=
functor_list
;
// Handle "reverse"
std
::
vector
<
int
>
axis
=
{
0
};
operators_with_attrs
[
"reverse"
]
=
{};
operators_with_attrs
[
"reverse"
][
"axis"
]
=
axis
;
// Handle "flip"
operators_with_attrs
[
"flip"
]
=
{};
operators_with_attrs
[
"flip"
][
"axis"
]
=
axis
;
// Handle "cast"
operators_with_attrs
[
"cast"
]
=
{};
operators_with_attrs
[
"cast"
][
"out_dtype"
]
=
5
;
operators_with_attrs
[
"cast"
][
"in_dtype"
]
=
5
;
// Handle "transfer_dtype"
operators_with_attrs
[
"transfer_dtype"
]
=
{};
operators_with_attrs
[
"transfer_dtype"
][
"out_dtype"
]
=
5
;
operators_with_attrs
[
"transfer_dtype"
][
"in_dtype"
]
=
5
;
}
static
void
CollectOperatorsToCodeGen
(
const
std
::
string
&
op_list_path
)
{
std
::
string
line
;
std
::
ifstream
op_list_file
(
op_list_path
);
if
(
op_list_file
.
is_open
())
{
while
(
getline
(
op_list_file
,
line
))
{
operators_to_codegen
.
insert
(
line
);
}
op_list_file
.
close
();
}
else
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Unable to open op_list.txt file"
));
}
}
namespace
paddle
{
namespace
framework
{
static
std
::
string
AttrTypeToString
(
const
proto
::
AttrType
&
type
)
{
std
::
string
ret
;
switch
(
type
)
{
...
...
@@ -392,10 +343,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now
VLOG
(
1
)
<<
"------ Analyzing Op ------: "
<<
op_type
;
if
(
!
FLAGS_generate_all
)
{
if
(
!
operators_to_codegen
.
count
(
op_type
))
return
false
;
}
if
(
!
operators_to_codegen
.
count
(
op_type
))
return
false
;
if
(
operators_to_skip
.
count
(
op_type
))
return
false
;
return
true
;
...
...
@@ -404,21 +352,12 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
/* --------------------------------------- */
/* --------- Preprocess Ins/Outs --------- */
/* --------------------------------------- */
static
void
PurifyOpProto
(
static
void
Purify
Forward
OpProto
(
const
proto
::
OpProto
&
op_proto
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
fwd_inputs_name_pos_map
,
std
::
unordered_map
<
std
::
string
,
size_t
>*
fwd_outputs_name_pos_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
in_vars
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
out_vars
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_ins
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_outs
)
{
std
::
vector
<
proto
::
OpProto
::
Var
>*
out_vars
)
{
// Op Name
const
std
::
string
op_name
=
op_proto
.
type
();
...
...
@@ -440,6 +379,72 @@ static void PurifyOpProto(
}
}
in_vars
->
erase
(
iter
);
}
}
}
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
std
::
string
output_name
=
output
.
name
();
// Delete dispensable tensor unless specified in op_outs_map
if
(
output
.
dispensable
())
{
if
(
!
op_outs_map
.
count
(
op_name
)
||
!
op_outs_map
[
op_name
].
count
(
output_name
))
{
VLOG
(
6
)
<<
"Removing Dispensable Output: "
<<
output_name
;
// out_vars
auto
iter
=
out_vars
->
begin
();
for
(
iter
=
out_vars
->
begin
();
iter
!=
out_vars
->
end
();
iter
++
)
{
if
(
iter
->
name
()
==
output_name
)
{
break
;
}
}
out_vars
->
erase
(
iter
);
}
}
}
/* ------ Maping forward slot name to fwd position ------ */
size_t
in_pos
=
0
;
for
(
const
auto
&
var
:
*
in_vars
)
{
VLOG
(
6
)
<<
"Mapping input tensor: "
<<
var
.
name
()
<<
" To position: "
<<
in_pos
;
(
*
fwd_inputs_name_pos_map
)[
var
.
name
()]
=
in_pos
;
in_pos
++
;
}
size_t
out_pos
=
0
;
for
(
const
auto
&
var
:
*
out_vars
)
{
VLOG
(
6
)
<<
"Mapping output tensor: "
<<
var
.
name
()
<<
" To position: "
<<
out_pos
;
(
*
fwd_outputs_name_pos_map
)[
var
.
name
()]
=
out_pos
;
out_pos
++
;
}
}
static
void
PurifyGradOpProto
(
const
proto
::
OpProto
&
op_proto
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_ins
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_outs
)
{
// Op Name
const
std
::
string
op_name
=
op_proto
.
type
();
// Handle dispensable inputs
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
())
{
std
::
string
input_name
=
input
.
name
();
// Delete dispensable tensor unless specified in op_ins_map
if
(
input
.
dispensable
())
{
if
(
!
op_ins_map
.
count
(
op_name
)
||
!
op_ins_map
[
op_name
].
count
(
input_name
))
{
VLOG
(
6
)
<<
"Removing Dispensable Input: "
<<
input_name
;
// grad_outs_slotname_map
auto
grad_outs_slotname_map_purified
=
*
grad_outs_slotname_map
;
...
...
@@ -478,15 +483,6 @@ static void PurifyOpProto(
!
op_outs_map
[
op_name
].
count
(
output_name
))
{
VLOG
(
6
)
<<
"Removing Dispensable Output: "
<<
output_name
;
// out_vars
auto
iter
=
out_vars
->
begin
();
for
(
iter
=
out_vars
->
begin
();
iter
!=
out_vars
->
end
();
iter
++
)
{
if
(
iter
->
name
()
==
output_name
)
{
break
;
}
}
out_vars
->
erase
(
iter
);
// grad_ins_grad_slotname_map
auto
grad_ins_grad_slotname_map_purified
=
*
grad_ins_grad_slotname_map
;
for
(
const
auto
&
iter
:
*
grad_ins_grad_slotname_map
)
{
...
...
@@ -514,52 +510,40 @@ static void PurifyOpProto(
}
}
}
/* ------ Maping forward slot name to fwd position ------ */
size_t
in_pos
=
0
;
for
(
const
auto
&
var
:
*
in_vars
)
{
VLOG
(
6
)
<<
"Mapping input tensor: "
<<
var
.
name
()
<<
" To position: "
<<
in_pos
;
(
*
fwd_inputs_name_pos_map
)[
var
.
name
()]
=
in_pos
;
in_pos
++
;
}
size_t
out_pos
=
0
;
for
(
const
auto
&
var
:
*
out_vars
)
{
VLOG
(
6
)
<<
"Mapping output tensor: "
<<
var
.
name
()
<<
" To position: "
<<
out_pos
;
(
*
fwd_outputs_name_pos_map
)[
var
.
name
()]
=
out_pos
;
out_pos
++
;
}
}
/* -------------------------------- */
/* --------- Collect Info --------- */
/* -------------------------------- */
static
bool
Collect
InformationFromOpInfo
(
static
void
CollectForward
InformationFromOpInfo
(
const
paddle
::
framework
::
OpInfo
&
op_info
,
std
::
vector
<
std
::
string
>*
grad_op_types
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
in_vars
,
std
::
vector
<
proto
::
OpProto
::
Var
>*
out_vars
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_ins
,
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_outs
)
{
std
::
vector
<
proto
::
OpProto
::
Var
>*
out_vars
)
{
const
proto
::
OpProto
&
op_proto
=
*
op_info
.
proto_
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
vector
<
int64_t
>
dims
=
{
1
,
1
,
1
,
1
};
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
())
{
in_vars
->
push_back
(
input
);
}
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
out_vars
->
push_back
(
output
);
}
}
static
bool
CollectGradInformationFromOpInfo
(
const
paddle
::
framework
::
OpInfo
&
op_info
,
bool
*
generate_forward_only
,
std
::
vector
<
std
::
string
>*
grad_op_types
,
// grad
std
::
map
<
std
::
string
,
std
::
string
>*
grad_outs_slotname_map
,
// grad
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_fwd_slotname_map
,
// grad
std
::
map
<
std
::
string
,
std
::
string
>*
grad_ins_grad_slotname_map
,
// grad
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_ins
,
// grad
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>*
grad_outs
// grad
)
{
const
proto
::
OpProto
&
op_proto
=
*
op_info
.
proto_
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
vector
<
int64_t
>
dims
=
{
1
,
1
,
1
,
1
};
/* ------ Prepare "ins" ------ */
std
::
map
<
std
::
string
,
...
...
@@ -621,8 +605,6 @@ static bool CollectInformationFromOpInfo(
if
(
operators_with_attrs
.
count
(
op_type
))
{
VLOG
(
6
)
<<
"Found operator "
<<
op_type
<<
" using special AttributeMap"
;
attrs
=
operators_with_attrs
[
op_type
];
// default_attrs.insert(operators_with_attrs[op_type].begin(),
// operators_with_attrs[op_type].end());
}
VLOG
(
6
)
<<
"Prepared Default Attributes Map, size = "
<<
default_attrs
.
size
();
...
...
@@ -655,8 +637,8 @@ static bool CollectInformationFromOpInfo(
/* ------ Run GradOpMaker ------ */
if
(
!
op_info
.
dygraph_grad_op_maker_
)
{
VLOG
(
6
)
<<
op_type
<<
" has no GradOpMaker
, skip it
"
;
skipped_operators
.
insert
(
op_type
)
;
VLOG
(
6
)
<<
op_type
<<
" has no GradOpMaker"
;
*
generate_forward_only
=
true
;
return
false
;
}
...
...
@@ -666,17 +648,19 @@ static bool CollectInformationFromOpInfo(
if
(
!
grad_node
)
{
VLOG
(
6
)
<<
"Got nullptr GradOpNode for "
<<
op_type
<<
" likely registered EmptyGradOpMaker
, skip it
"
;
skipped_operators
.
insert
(
op_type
)
;
<<
" likely registered EmptyGradOpMaker"
;
*
generate_forward_only
=
true
;
return
false
;
}
/*
if (grad_node->size() > 1) {
// Backward attributes can be super complicated
VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type;
skipped_operators.insert(op_type);
return false;
}
*/
VLOG
(
6
)
<<
"Prepared GradOpNode"
;
...
...
@@ -901,6 +885,7 @@ static std::string GenerateGradNodeCreationContent(
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
static
std
::
pair
<
std
::
string
,
std
::
string
>
GenerateForwardFunctionContents
(
bool
generate_forward_only
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
...
...
@@ -1044,7 +1029,6 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Attrs
dygraph_function_args_str
+=
", const paddle::framework::AttributeMap& attr_map"
;
generated_function_body
+=
"
\n
"
;
// [Generation] Get TraceOp
const
char
*
FWD_TRACE_OP_TEMPLATE
=
...
...
@@ -1092,16 +1076,18 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG
(
6
)
<<
"Converted Output VarBase to EagerTensor(s)"
;
// [Generation] ComputeRequireGrad -> GradNodeCreation
std
::
string
grad_node_creation_body_str
=
GenerateGradNodeCreationContent
(
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
op_type
,
in_vars
,
out_vars
);
generated_function_body
+=
grad_node_creation_body_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated GradNode Creation codes"
;
if
(
!
generate_forward_only
)
{
std
::
string
grad_node_creation_body_str
=
GenerateGradNodeCreationContent
(
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
op_type
,
in_vars
,
out_vars
);
generated_function_body
+=
grad_node_creation_body_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated GradNode Creation codes"
;
}
// [Generation] Handle return: Tuple/Vector/Tensor
generated_function_body
+=
"
\n
"
;
std
::
string
return_str
;
std
::
string
return_str
=
""
;
std
::
string
return_type_str
=
""
;
std
::
string
function_proto_return_type_str
=
""
;
if
(
return_contents
.
size
()
>
1
)
{
...
...
@@ -1124,14 +1110,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const
char
*
FWD_FUNCTION_PROTO_RETURN_TEMPLATE
=
"std::tuple<%s>"
;
function_proto_return_type_str
=
paddle
::
string
::
Sprintf
(
FWD_FUNCTION_PROTO_RETURN_TEMPLATE
,
return_type_str
);
}
else
{
}
else
if
(
return_contents
.
size
()
==
1
)
{
// Return vector<Tensor> or Tensor
return_type_str
=
return_types
[
0
];
const
char
*
FWD_TENSOR_RETURN_TEMPLATE
=
" return %s;"
;
return_str
=
paddle
::
string
::
Sprintf
(
FWD_TENSOR_RETURN_TEMPLATE
,
return_contents
[
0
]);
function_proto_return_type_str
=
return_type_str
;
}
else
{
return_str
=
"return nullptr;"
;
function_proto_return_type_str
=
"void*"
;
}
generated_function_body
+=
return_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated return codes"
;
...
...
@@ -1139,6 +1131,11 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
// [Generation] Get Full Function
std
::
string
function_name
=
op_type
+
"_dygraph_function"
;
if
(
dygraph_function_args_str
.
size
()
>
0
)
{
auto
iter
=
dygraph_function_args_str
.
begin
();
if
((
*
iter
)
==
','
)
dygraph_function_args_str
.
erase
(
iter
);
}
const
char
*
FWD_FUNCTION_TEMPLATE
=
"%s %s(%s) {
\n\n
%s
\n
}
\n\n
"
;
std
::
string
fwd_function_str
=
paddle
::
string
::
Sprintf
(
FWD_FUNCTION_TEMPLATE
,
function_proto_return_type_str
,
function_name
,
...
...
@@ -1601,11 +1598,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* ---- Collect Information ---- */
/* ----------------------------- */
std
::
vector
<
std
::
string
>
grad_op_types
;
std
::
vector
<
proto
::
OpProto
::
Var
>
in_vars
;
std
::
vector
<
proto
::
OpProto
::
Var
>
out_vars
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_outs_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_fwd_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_grad_slotname_map
;
std
::
vector
<
proto
::
OpProto
::
Var
>
in_vars
;
std
::
vector
<
proto
::
OpProto
::
Var
>
out_vars
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>
grad_ins
;
...
...
@@ -1614,20 +1611,31 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
grad_outs
;
VLOG
(
6
)
<<
"-------- CollectInformationFromOpInfo -------"
;
bool
is_available
=
CollectInformationFromOpInfo
(
op_info
,
&
grad_op_types
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
in_vars
,
&
out_vars
,
&
grad_ins
,
&
grad_outs
);
if
(
!
is_available
)
continue
;
CollectForwardInformationFromOpInfo
(
op_info
,
&
in_vars
,
&
out_vars
);
bool
generate_forward_only
=
false
;
bool
is_available
=
CollectGradInformationFromOpInfo
(
op_info
,
&
generate_forward_only
,
&
grad_op_types
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
grad_ins
,
&
grad_outs
);
if
(
!
is_available
&&
!
generate_forward_only
)
{
VLOG
(
6
)
<<
"Skipped operator: "
<<
op_type
;
continue
;
}
VLOG
(
6
)
<<
"-------- PurifyOpProto -------"
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_inputs_name_pos_map
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_outputs_name_pos_map
;
PurifyOpProto
(
*
op_proto
,
&
fwd_inputs_name_pos_map
,
&
fwd_outputs_name_pos_map
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
in_vars
,
&
out_vars
,
&
grad_ins
,
&
grad_outs
);
PurifyForwardOpProto
(
*
op_proto
,
&
fwd_inputs_name_pos_map
,
&
fwd_outputs_name_pos_map
,
&
in_vars
,
&
out_vars
);
if
(
!
generate_forward_only
)
{
PurifyGradOpProto
(
*
op_proto
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
grad_ins
,
&
grad_outs
);
}
/* --------------------------- */
/* --------- CodeGen --------- */
...
...
@@ -1636,16 +1644,19 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
VLOG
(
6
)
<<
"-------- GenerateForwardFunctionContents -------"
;
std
::
pair
<
std
::
string
,
std
::
string
>
body_and_declaration
=
GenerateForwardFunctionContents
(
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
grad_ins_grad_slotname_map
,
grad_outs_slotname_map
,
grad_ins
,
grad_outs
,
op_type
,
in_vars
,
out_vars
);
generate_forward_only
,
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
grad_ins_grad_slotname_map
,
grad_outs_slotname_map
,
grad_ins
,
grad_outs
,
op_type
,
in_vars
,
out_vars
);
fwd_function_str
+=
body_and_declaration
.
first
+
"
\n
"
;
/* ---- dygraph_forward_api.h ---- */
std
::
string
fwd_function_declare_str
=
body_and_declaration
.
second
;
dygraph_forward_api_str
+=
fwd_function_declare_str
;
if
(
generate_forward_only
)
continue
;
/* ---- nodes.h ---- */
VLOG
(
6
)
<<
"-------- GenerateGradNodeHeaderContents -------"
;
grad_node_h_str
+=
...
...
@@ -1681,6 +1692,52 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateNodeCCFile
(
output_dir
,
grad_node_cc_str
);
}
static
void
PrepareAttrMapForOps
()
{
// Handle "fused_elemwise_add_activation"
std
::
vector
<
std
::
string
>
functor_list
=
{
"a"
,
"b"
};
operators_with_attrs
[
"fused_elemwise_add_activation"
]
=
{};
operators_with_attrs
[
"fused_elemwise_add_activation"
][
"functor_list"
]
=
functor_list
;
// Handle "fused_elemwise_activation"
operators_with_attrs
[
"fused_elemwise_activation"
]
=
{};
operators_with_attrs
[
"fused_elemwise_activation"
][
"functor_list"
]
=
functor_list
;
// Handle "reverse"
std
::
vector
<
int
>
axis
=
{
0
};
operators_with_attrs
[
"reverse"
]
=
{};
operators_with_attrs
[
"reverse"
][
"axis"
]
=
axis
;
// Handle "flip"
operators_with_attrs
[
"flip"
]
=
{};
operators_with_attrs
[
"flip"
][
"axis"
]
=
axis
;
// Handle "cast"
operators_with_attrs
[
"cast"
]
=
{};
operators_with_attrs
[
"cast"
][
"out_dtype"
]
=
5
;
operators_with_attrs
[
"cast"
][
"in_dtype"
]
=
5
;
// Handle "transfer_dtype"
operators_with_attrs
[
"transfer_dtype"
]
=
{};
operators_with_attrs
[
"transfer_dtype"
][
"out_dtype"
]
=
5
;
operators_with_attrs
[
"transfer_dtype"
][
"in_dtype"
]
=
5
;
}
static
void
CollectOperatorsToCodeGen
(
const
std
::
string
&
op_list_path
)
{
std
::
string
line
;
std
::
ifstream
op_list_file
(
op_list_path
);
if
(
op_list_file
.
is_open
())
{
while
(
getline
(
op_list_file
,
line
))
{
operators_to_codegen
.
insert
(
line
);
}
op_list_file
.
close
();
}
else
{
PADDLE_THROW
(
paddle
::
platform
::
errors
::
Fatal
(
"Unable to open op_list.txt file"
));
}
}
}
// namespace framework
}
// namespace paddle
...
...
@@ -1693,8 +1750,8 @@ int main(int argc, char* argv[]) {
std
::
string
eager_root
=
argv
[
1
];
std
::
string
op_list_path
=
argv
[
2
];
CollectOperatorsToCodeGen
(
op_list_path
);
PrepareAttrMapForOps
();
paddle
::
framework
::
CollectOperatorsToCodeGen
(
op_list_path
);
paddle
::
framework
::
PrepareAttrMapForOps
();
paddle
::
framework
::
DygraphCodeGeneration
(
eager_root
);
...
...
paddle/fluid/eager/auto_code_generator/op_list.txt
浏览文件 @
de874cdd
...
...
@@ -215,7 +215,6 @@ spp
floor
gelu
retinanet_detection_output
minus
push_dense
silu
sequence_erase
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录