Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e7bda1dd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e7bda1dd
编写于
11月 27, 2021
作者:
Z
Zhanlue Yang
提交者:
GitHub
11月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added Eager Dygraph AutoCodeGen dependencies #2 (#37575)
上级
d2934a70
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
804 addition
and
7 deletion
+804
-7
paddle/fluid/eager/auto_code_generator/eager_generator.cc
paddle/fluid/eager/auto_code_generator/eager_generator.cc
+804
-7
未找到文件。
paddle/fluid/eager/auto_code_generator/eager_generator.cc
浏览文件 @
e7bda1dd
...
@@ -40,11 +40,10 @@ static std::unordered_set<std::string> operators_to_skip = {
...
@@ -40,11 +40,10 @@ static std::unordered_set<std::string> operators_to_skip = {
"fused_attention"
,
"fused_attention"
,
"diag_v2"
,
"diag_v2"
,
};
};
/*
static
std
::
unordered_set
<
std
::
string
>
operators_to_codegen
=
{
static
std
::
unordered_set
<
std
::
string
>
operators_to_codegen
=
{
"sigmoid"
,
"matmul_v2"
,
"reduce_sum"
,
"elementwise_add"
,
"sigmoid"
,
"matmul_v2"
,
"reduce_sum"
,
"elementwise_add"
,
"share_buffer"
,
"var_conv_2d"
,
"split"
};
"share_buffer"
,
"var_conv_2d"
,
"split"
};
*/
static
std
::
unordered_set
<
std
::
string
>
skipped_operators
=
{};
static
std
::
unordered_set
<
std
::
string
>
skipped_operators
=
{};
...
@@ -107,8 +106,10 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
...
@@ -107,8 +106,10 @@ static std::string AttrTypeToString(const proto::AttrType& type) {
break
;
break
;
}
}
default:
{
default:
{
PADDLE_THROW
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
platform
::
errors
::
Fatal
(
"Unable to recognize AttrType: %d"
,
type
));
"AttrType of type boost::variant only supports specific data types."
"However, detected unrecognized AttrType: %d"
,
type
));
}
}
}
}
return
ret
;
return
ret
;
...
@@ -214,8 +215,10 @@ static std::pair<std::string, std::string> GetAttrType(
...
@@ -214,8 +215,10 @@ static std::pair<std::string, std::string> GetAttrType(
break
;
break
;
}
}
default:
{
default:
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Unable to recognize AttrType: %d"
,
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
variant_pos
));
"AttrType of type boost::variant only supports specific data types."
"However, detected unrecognized AttrType: %d"
,
variant_pos
));
}
}
}
}
return
{
ret
,
val
};
return
{
ret
,
val
};
...
@@ -259,6 +262,7 @@ static void SlotNameMatching(
...
@@ -259,6 +262,7 @@ static void SlotNameMatching(
if
(
grad_fwd_slotname_map
.
count
(
grad_slot_name
)
&&
if
(
grad_fwd_slotname_map
.
count
(
grad_slot_name
)
&&
grad_fwd_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
grad_fwd_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names."
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
grad_slot_name
,
grad_fwd_slotname_map
[
grad_slot_name
],
grad_slot_name
,
grad_fwd_slotname_map
[
grad_slot_name
],
fwd_slot_name
));
fwd_slot_name
));
...
@@ -271,6 +275,7 @@ static void SlotNameMatching(
...
@@ -271,6 +275,7 @@ static void SlotNameMatching(
if
(
grad_grad_slotname_map
.
count
(
grad_slot_name
)
&&
if
(
grad_grad_slotname_map
.
count
(
grad_slot_name
)
&&
grad_grad_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
grad_grad_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names."
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
grad_slot_name
,
grad_grad_slotname_map
[
grad_slot_name
],
grad_slot_name
,
grad_grad_slotname_map
[
grad_slot_name
],
fwd_slot_name
));
fwd_slot_name
));
...
@@ -290,6 +295,7 @@ static void SlotNameMatching(
...
@@ -290,6 +295,7 @@ static void SlotNameMatching(
if
(
grad_fwd_slotname_map
.
count
(
grad_slot_name
)
&&
if
(
grad_fwd_slotname_map
.
count
(
grad_slot_name
)
&&
grad_fwd_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
grad_fwd_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names"
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
grad_slot_name
,
grad_fwd_slotname_map
[
grad_slot_name
],
grad_slot_name
,
grad_fwd_slotname_map
[
grad_slot_name
],
fwd_slot_name
));
fwd_slot_name
));
...
@@ -302,6 +308,7 @@ static void SlotNameMatching(
...
@@ -302,6 +308,7 @@ static void SlotNameMatching(
if
(
grad_grad_slotname_map
.
count
(
grad_slot_name
)
&&
if
(
grad_grad_slotname_map
.
count
(
grad_slot_name
)
&&
grad_grad_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
grad_grad_slotname_map
[
grad_slot_name
]
!=
fwd_slot_name
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names."
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
"grad_slot_name %s matches both %s and %s fwd_slot_name"
,
grad_slot_name
,
grad_grad_slotname_map
[
grad_slot_name
],
grad_slot_name
,
grad_grad_slotname_map
[
grad_slot_name
],
fwd_slot_name
));
fwd_slot_name
));
...
@@ -315,6 +322,7 @@ static void SlotNameMatching(
...
@@ -315,6 +322,7 @@ static void SlotNameMatching(
if
(
!
found_matching
)
{
if
(
!
found_matching
)
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names."
"Found no matching fwd_slot_name for grad_slot_name: %s"
,
"Found no matching fwd_slot_name for grad_slot_name: %s"
,
grad_slot_name
));
grad_slot_name
));
...
@@ -344,7 +352,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
...
@@ -344,7 +352,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now
// Only handle matmul_v2 for now
VLOG
(
1
)
<<
"------ Analyzing Op ------: "
<<
op_type
;
VLOG
(
1
)
<<
"------ Analyzing Op ------: "
<<
op_type
;
//
if (!operators_to_codegen.count(op_type)) return false;
if
(
!
operators_to_codegen
.
count
(
op_type
))
return
false
;
if
(
operators_to_skip
.
count
(
op_type
))
return
false
;
if
(
operators_to_skip
.
count
(
op_type
))
return
false
;
return
true
;
return
true
;
...
@@ -741,5 +749,794 @@ static std::string AppendUseOp(const std::string& op_type) {
...
@@ -741,5 +749,794 @@ static std::string AppendUseOp(const std::string& op_type) {
return
return_str
;
return
return_str
;
}
}
/* -------------------------------- */
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
static
std
::
pair
<
std
::
string
,
std
::
string
>
GenerateForwardFunctionContents
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_grad_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_outs_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
grad_ins
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
grad_outs
,
const
proto
::
OpProto
&
op_proto
)
{
/*
// Forward Function Example:
std::tuple<vector<Tensor>, Tensor, vector<Tensor>>
kernel_function(vector<Tensor>& X, Tensor& Y, const paddle::AttributeMap&
attr_map, size_t
Out0Num, size_t Out1Num) {
// Forward Function Body
// According to fwd_inputs_name_pos_map
std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>>
ins =
{ {"X" , SyncToVars(X)}, { "Y" , SyncToVars(Y)} };
std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>>
outs =
{
{"Out0" , ConstructDuplicableOutput(Out0Num)}, {"Out1"
,ConstructDuplicableOutput(Out1Num)} };
// According to op_proto->attrs()
egr::RunOp("op_type", ins, outs, attr_map,
Controller.Instance().GetExpectedPlace(), {});
// According to fwd_outputs_names
std::vector<egr::EagerTensor> Out0 = GetOutputs(outs["Out0"]);
egr::EagerTensor Out1 = GetOutputs(outs["Out1"][0]);
std::vector<egr::EagerTensor> Out2 = GetOutputs(outs["Out2"]);
// Grad Node Generation Codes
...
return std::make_tuple(Out0, Out1, Out2);
}
*/
VLOG
(
6
)
<<
"Generating Dygraph Forward Function"
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
string
generated_function_body
=
""
;
std
::
string
dygraph_function_args_str
=
""
;
/* ------ Dygraph forward function generation ------ */
generated_function_body
+=
" // Dygraph Forward Pass
\n
"
;
generated_function_body
+=
"
\n
"
;
// [Generation] Get Ins Map
std
::
string
ins_contents_str
=
""
;
std
::
vector
<
std
::
string
>
input_args_str_list
(
op_proto
.
inputs
().
size
());
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
())
{
const
std
::
string
&
input_name
=
input
.
name
();
size_t
input_position
=
fwd_inputs_name_pos_map
.
at
(
input_name
);
if
(
input
.
duplicable
())
{
const
char
*
FWD_INS_ARG_TEMPLATE
=
"const std::vector<egr::EagerTensor>& %s"
;
input_args_str_list
[
input_position
]
=
paddle
::
string
::
Sprintf
(
FWD_INS_ARG_TEMPLATE
,
input_name
);
}
else
{
const
char
*
FWD_INS_ARG_TEMPLATE
=
"const egr::EagerTensor& %s"
;
input_args_str_list
[
input_position
]
=
paddle
::
string
::
Sprintf
(
FWD_INS_ARG_TEMPLATE
,
input_name
);
}
const
char
*
FWD_INS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, egr::SyncToVars(%s) },"
;
ins_contents_str
+=
paddle
::
string
::
Sprintf
(
FWD_INS_CONTENT_TEMPLATE
,
input_name
,
input_name
);
}
if
(
ins_contents_str
.
size
()
>
0
)
ins_contents_str
.
pop_back
();
// // Remove trailing ","
for
(
const
std
::
string
&
arg
:
input_args_str_list
)
{
dygraph_function_args_str
+=
arg
;
dygraph_function_args_str
+=
","
;
}
if
(
dygraph_function_args_str
.
size
()
>
0
)
dygraph_function_args_str
.
pop_back
();
const
char
*
FWD_INS_MAP_TEMPLATE
=
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> ins = { "
"%s };
\n
"
;
std
::
string
ins_map_str
=
paddle
::
string
::
Sprintf
(
FWD_INS_MAP_TEMPLATE
,
ins_contents_str
);
generated_function_body
+=
ins_map_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated Ins Map"
;
// [Generation] Get Outs Map
std
::
string
outs_contents_str
=
""
;
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
const
std
::
string
&
output_name
=
output
.
name
();
std
::
string
outnum
=
"1"
;
if
(
output
.
duplicable
())
{
outnum
=
output_name
+
"Num"
;
const
char
*
FWD_NUM_ARG_TEMPLATE
=
", size_t %s"
;
std
::
string
arg_str
=
paddle
::
string
::
Sprintf
(
FWD_NUM_ARG_TEMPLATE
,
outnum
);
dygraph_function_args_str
+=
arg_str
;
const
char
*
FWD_OUTS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, egr::ConstructDuplicableOutput(%s) },"
;
outs_contents_str
+=
paddle
::
string
::
Sprintf
(
FWD_OUTS_CONTENT_TEMPLATE
,
output_name
,
outnum
);
}
else
{
const
char
*
FWD_OUTS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},"
;
outs_contents_str
+=
paddle
::
string
::
Sprintf
(
FWD_OUTS_CONTENT_TEMPLATE
,
output_name
);
}
}
if
(
outs_contents_str
.
size
()
>
0
)
outs_contents_str
.
pop_back
();
// Remove trailing ","
const
char
*
FWD_OUTS_MAP_TEMPLATE
=
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { "
"%s };
\n
"
;
std
::
string
outs_map_str
=
paddle
::
string
::
Sprintf
(
FWD_OUTS_MAP_TEMPLATE
,
outs_contents_str
);
generated_function_body
+=
outs_map_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated Outs Map"
;
// [Generation] Get Attrs
dygraph_function_args_str
+=
", const paddle::framework::AttributeMap& attr_map"
;
generated_function_body
+=
"
\n
"
;
// [Generation] Get TraceOp
const
char
*
FWD_TRACE_OP_TEMPLATE
=
" paddle::framework::AttributeMap attrs = attr_map;
\n
"
" paddle::framework::AttributeMap default_attrs;
\n
"
" egr::RunOp(
\"
%s
\"
, ins, outs, attrs,
\n
"
" egr::Controller::Instance().GetExpectedPlace(),
\n
"
" &default_attrs, true, {});
\n
"
;
std
::
string
trace_op_str
=
paddle
::
string
::
Sprintf
(
FWD_TRACE_OP_TEMPLATE
,
op_proto
.
type
());
generated_function_body
+=
trace_op_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated AttrMap & TraceOp"
;
// [Generation] Convert output VarBase to Vector/Tensor
size_t
output_size
=
op_proto
.
outputs
().
size
();
std
::
vector
<
std
::
string
>
return_contents
(
output_size
);
std
::
vector
<
std
::
string
>
return_types
(
output_size
);
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
const
std
::
string
&
output_name
=
output
.
name
();
std
::
string
out_tensor_str
;
size_t
return_position
=
fwd_outputs_name_pos_map
.
at
(
output_name
);
if
(
output
.
duplicable
())
{
const
char
*
FWD_OUT_TENSORS_TEMPLATE
=
" std::vector<egr::EagerTensor> %s = "
"egr::GetOutputs(outs[
\"
%s
\"
]);
\n
"
;
out_tensor_str
=
paddle
::
string
::
Sprintf
(
FWD_OUT_TENSORS_TEMPLATE
,
output_name
,
output_name
);
return_types
[
return_position
]
=
"std::vector<egr::EagerTensor>"
;
}
else
{
const
char
*
FWD_OUT_TENSOR_TEMPLATE
=
" egr::EagerTensor %s = "
"egr::GetOutput(outs[
\"
%s
\"
][0]);
\n
"
;
out_tensor_str
=
paddle
::
string
::
Sprintf
(
FWD_OUT_TENSOR_TEMPLATE
,
output_name
,
output_name
);
return_types
[
return_position
]
=
"egr::EagerTensor"
;
}
return_contents
[
return_position
]
=
output_name
;
generated_function_body
+=
out_tensor_str
;
}
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Converted Output VarBase to EagerTensor(s)"
;
// [Generation] ComputeRequireGrad -> GradNodeCreation
std
::
string
grad_node_creation_body_str
=
GenerateGradNodeCreationContent
(
grad_node_default_attr_maps
,
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
op_proto
);
generated_function_body
+=
grad_node_creation_body_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated GradNode Creation codes"
;
// [Generation] Handle return: Tuple/Vector/Tensor
generated_function_body
+=
"
\n
"
;
std
::
string
return_str
;
std
::
string
return_type_str
=
""
;
std
::
string
function_proto_return_type_str
=
""
;
if
(
return_contents
.
size
()
>
1
)
{
// Return tuple
std
::
string
return_content_str
=
""
;
for
(
const
std
::
string
&
s
:
return_contents
)
{
return_content_str
+=
s
+
","
;
}
return_content_str
.
pop_back
();
// Remove trailing ","
for
(
const
std
::
string
&
s
:
return_types
)
{
return_type_str
+=
s
+
","
;
}
return_type_str
.
pop_back
();
// Remove trailing ","
const
char
*
FWD_TUPLE_RETURN_TEMPLATE
=
" return std::make_tuple(%s);"
;
return_str
=
paddle
::
string
::
Sprintf
(
FWD_TUPLE_RETURN_TEMPLATE
,
return_content_str
);
const
char
*
FWD_FUNCTION_PROTO_RETURN_TEMPLATE
=
"std::tuple<%s>"
;
function_proto_return_type_str
=
paddle
::
string
::
Sprintf
(
FWD_FUNCTION_PROTO_RETURN_TEMPLATE
,
return_type_str
);
}
else
{
// Return vector<Tensor> or Tensor
return_type_str
=
return_types
[
0
];
const
char
*
FWD_TENSOR_RETURN_TEMPLATE
=
" return %s;"
;
return_str
=
paddle
::
string
::
Sprintf
(
FWD_TENSOR_RETURN_TEMPLATE
,
return_contents
[
0
]);
function_proto_return_type_str
=
return_type_str
;
}
generated_function_body
+=
return_str
;
generated_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated return codes"
;
// [Generation] Get Full Function
std
::
string
function_name
=
op_type
+
"_dygraph_function"
;
const
char
*
FWD_FUNCTION_TEMPLATE
=
"%s %s(%s) {
\n\n
%s
\n
}
\n\n
"
;
std
::
string
fwd_function_str
=
paddle
::
string
::
Sprintf
(
FWD_FUNCTION_TEMPLATE
,
function_proto_return_type_str
,
function_name
,
dygraph_function_args_str
,
generated_function_body
);
// [Generation] Append USE_OP
fwd_function_str
+=
AppendUseOp
(
op_type
);
// [Generation] Generate forward functions header
const
char
*
FWD_HEADER_TEMPLATE
=
"%s %s(%s);
\n
"
;
std
::
string
dygraph_function_declaration_str
=
paddle
::
string
::
Sprintf
(
FWD_HEADER_TEMPLATE
,
function_proto_return_type_str
,
function_name
,
dygraph_function_args_str
);
return
{
fwd_function_str
,
dygraph_function_declaration_str
};
}
/* ---------------------------------------------- */
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static
std
::
string
GenerateGradNodeCCContents
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
vector
<
std
::
string
>&
grad_op_types
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_inputs_name_pos_map
,
const
std
::
unordered_map
<
std
::
string
,
size_t
>&
fwd_outputs_name_pos_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_grad_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_outs_slotname_map
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
grad_ins
,
const
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>&
grad_outs
,
const
proto
::
OpProto
&
op_proto
)
{
VLOG
(
6
)
<<
"Generating Grad Node CC"
;
/* [Outline]
vector<vector<Tensor>> GradNodeXXX::operator()(vector<vector<Tensor>>& grads)
{
const std::shared_ptr<Tracer>& tracer = imperative::GetCurrentTracer();
// Comes from "grad_ins"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> ins =
{
"X" : this->"X", "Y" : this->"Y",
"Out0@Grad":
SyncToVars(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out0@Grad"]]"]),
"Out1@Grad":
TensorsToVarBases(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out1@Grad"]]"])
};
// Comes from "grad_outs"
std::map<std::string, std::vector<std::shared_ptr<VarBase>>> outs =
{
"X@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["X@Grad"]]"].Size()),
"Y@Grad" :
ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["Y@Grad"]]"].Size())
};
// Visit each OpBase
for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) {
// Simply pass entire attribute map to kernels
egr::RunOp("iter->Type()", ins, outs, this->attr_map_,
egr::Controller::Instance().ExpectedPlace(), false, {});
}
vector<vector<egr::EagerTensor>> outputs(outs.size());
for(auto& kv : outs) {
outputs["fwd_inputs_name_pos_map[grad_outs_slotname_map[kv.first]]"] =
GetOutputs(outs["kv.first"]);
}
return outputs;
}
*/
const
std
::
string
&
op_type
=
op_proto
.
type
();
std
::
string
generated_grad_function_body
=
""
;
// [Generation] Get Tracer
generated_grad_function_body
+=
"
\n
"
;
generated_grad_function_body
+=
"
\n
"
;
// [Generation] Get Ins Map
std
::
string
ins_contents_str
=
""
;
for
(
auto
iter
:
grad_ins
)
{
const
std
::
string
&
grad_input_name
=
iter
.
first
;
if
(
grad_ins_fwd_slotname_map
.
count
(
grad_input_name
))
{
// Fwd Tensor
std
::
string
struct_fwd_input_name
=
grad_ins_fwd_slotname_map
.
at
(
grad_input_name
)
+
"_"
;
const
char
*
GRAD_INS_FWD_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, "
"egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&this->%s, "
"nullptr)) },"
;
ins_contents_str
+=
paddle
::
string
::
Sprintf
(
GRAD_INS_FWD_CONTENT_TEMPLATE
,
grad_input_name
,
struct_fwd_input_name
);
}
else
if
(
grad_ins_grad_slotname_map
.
count
(
grad_input_name
))
{
// Fwd Tensor's Grad
size_t
fwd_output_position
=
fwd_outputs_name_pos_map
.
at
(
grad_ins_grad_slotname_map
.
at
(
grad_input_name
));
const
char
*
GRAD_INS_GRAD_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, egr::SyncToVars(grads[%d]) },"
;
ins_contents_str
+=
paddle
::
string
::
Sprintf
(
GRAD_INS_GRAD_CONTENT_TEMPLATE
,
grad_input_name
,
fwd_output_position
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s"
,
grad_input_name
));
}
}
if
(
ins_contents_str
.
size
()
>
0
)
ins_contents_str
.
pop_back
();
// // Remove trailing ","
const
char
*
BWD_INS_MAP_TEMPLATE
=
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> ins = { "
"%s };
\n
"
;
std
::
string
ins_map_str
=
paddle
::
string
::
Sprintf
(
BWD_INS_MAP_TEMPLATE
,
ins_contents_str
);
generated_grad_function_body
+=
ins_map_str
;
VLOG
(
6
)
<<
"Generated Ins Map"
;
// [Generation] Get Outs Map
std
::
unordered_set
<
std
::
string
>
duplicable_input_name_set
;
for
(
const
auto
&
in
:
op_proto
.
inputs
())
{
if
(
in
.
duplicable
())
duplicable_input_name_set
.
insert
(
in
.
name
());
}
std
::
string
outs_contents_str
=
""
;
for
(
auto
iter
:
grad_outs
)
{
const
std
::
string
&
grad_output_name
=
iter
.
first
;
if
(
grad_outs_slotname_map
.
count
(
grad_output_name
))
{
// Fwd Tensor
const
std
::
string
&
fwd_input_name
=
grad_outs_slotname_map
.
at
(
grad_output_name
);
size_t
fwd_input_position
=
fwd_inputs_name_pos_map
.
at
(
fwd_input_name
);
if
(
duplicable_input_name_set
.
count
(
fwd_input_name
))
{
const
char
*
GRAD_OUTS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, egr::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },"
;
outs_contents_str
+=
paddle
::
string
::
Sprintf
(
GRAD_OUTS_CONTENT_TEMPLATE
,
grad_output_name
,
fwd_input_position
);
}
else
{
const
char
*
GRAD_OUTS_CONTENT_TEMPLATE
=
"{
\"
%s
\"
, "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},"
;
outs_contents_str
+=
paddle
::
string
::
Sprintf
(
GRAD_OUTS_CONTENT_TEMPLATE
,
grad_output_name
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Fatal
(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s"
,
grad_output_name
));
}
}
if
(
outs_contents_str
.
size
()
>
0
)
outs_contents_str
.
pop_back
();
// // Remove trailing ","
const
char
*
BWD_OUTS_MAP_TEMPLATE
=
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { "
"%s };
\n
"
;
std
::
string
outs_map_str
=
paddle
::
string
::
Sprintf
(
BWD_OUTS_MAP_TEMPLATE
,
outs_contents_str
);
generated_grad_function_body
+=
outs_map_str
;
generated_grad_function_body
+=
"
\n
"
;
VLOG
(
6
)
<<
"Generated Outs Map"
;
// [Generation] Get Attrs Map
std
::
string
trace_opbase_str
=
""
;
for
(
size_t
i
=
0
;
i
<
grad_node_default_attr_maps
.
size
();
i
++
)
{
const
std
::
string
&
op_base_type
=
grad_op_types
[
i
];
const
char
*
TRACE_OP_TEMPLATE
=
" // Pass the entire attribute map to TraceOp
\n
"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime
\n
"
" egr::RunOp(
\"
%s
\"
, ins, outs, this->attr_map_,
\n
"
" egr::Controller::Instance().GetExpectedPlace(),
\n
"
" &this->default_attr_map_, false, {});
\n
"
;
trace_opbase_str
=
paddle
::
string
::
Sprintf
(
TRACE_OP_TEMPLATE
,
op_base_type
);
}
generated_grad_function_body
+=
trace_opbase_str
;
VLOG
(
6
)
<<
"Generated Attrs Map"
;
// [Generation] Get Return
std
::
string
outputs_str
=
""
;
for
(
auto
iter
:
grad_outs
)
{
const
std
::
string
&
grad_out_name
=
iter
.
first
;
size_t
fwd_input_position
=
fwd_inputs_name_pos_map
.
at
(
grad_outs_slotname_map
.
at
(
grad_out_name
));
const
char
*
BWD_OUTPUT_TEMPLATE
=
" outputs[%d] = GetOutputs(outs[
\"
%s
\"
]);
\n
"
;
outputs_str
+=
paddle
::
string
::
Sprintf
(
BWD_OUTPUT_TEMPLATE
,
fwd_input_position
,
grad_out_name
);
}
const
char
*
BWD_RETURN_TEMPLATE
=
" std::vector<std::vector<egr::EagerTensor>> "
"outputs(outs.size());
\n
%s
\n
"
"return outputs;"
;
std
::
string
return_str
=
paddle
::
string
::
Sprintf
(
BWD_RETURN_TEMPLATE
,
outputs_str
);
generated_grad_function_body
+=
"
\n
"
;
generated_grad_function_body
+=
return_str
;
// [Generation] Get Full Grad Function
const
char
*
GRAD_FUNCTION_TEMPLATE
=
"std::vector<std::vector<egr::EagerTensor>> "
"GradNode%s::operator()(const "
"std::vector<std::vector<egr::EagerTensor>>& grads) {
\n
%s
\n
}"
;
std
::
string
grad_function_str
=
paddle
::
string
::
Sprintf
(
GRAD_FUNCTION_TEMPLATE
,
op_type
,
generated_grad_function_body
);
VLOG
(
6
)
<<
"Generated returns"
;
return
grad_function_str
;
}
/* ----------------------------------------- */
/* --------- CodeGen: GradNode Header ------ */
/* ----------------------------------------- */
static
std
::
string
GenerateGradNodeHeaderContents
(
const
std
::
vector
<
paddle
::
framework
::
AttributeMap
>&
grad_node_default_attr_maps
,
const
std
::
map
<
std
::
string
,
std
::
string
>&
grad_ins_fwd_slotname_map
,
const
proto
::
OpProto
&
op_proto
)
{
VLOG
(
6
)
<<
"Generating Grad Node Header"
;
const
char
*
GRAD_NODE_TEMPLATE
=
"class GradNode%s : public egr::GradNodeBase {
\n
"
" public:
\n
"
" GradNode%s() : egr::GradNodeBase() {}
\n
"
" GradNode%s(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : "
"egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}
\n
"
" ~GradNode%s() override = default;
\n
"
"
\n
"
" virtual std::vector<std::vector<egr::EagerTensor>> "
"operator()(const "
"std::vector<std::vector<egr::EagerTensor>>& grads) "
"override;
\n
"
"
\n
"
" // SetX, SetY, ...
\n
"
"%s
\n
"
" // SetAttrMap
\n
"
"%s
\n
"
"
\n
"
" private:
\n
"
" // TensorWrappers
\n
"
"%s
\n
"
" // Attribute Map
\n
"
"%s
\n
"
"};"
;
const
std
::
string
&
op_type
=
op_proto
.
type
();
// [Generation] Handle Attributes
std
::
string
set_attr_map_str
=
" void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {
\n
"
"attr_map_ = std::move(attr_map);
\n
}
\n
"
;
set_attr_map_str
+=
" void SetDefaultAttrMap(paddle::framework::AttributeMap&& "
"default_attr_map) {
\n
default_attr_map_ = "
"std::move(default_attr_map);
\n
}
\n
"
;
std
::
string
attr_members_str
=
" paddle::framework::AttributeMap attr_map_;
\n
"
;
attr_members_str
+=
" paddle::framework::AttributeMap default_attr_map_;"
;
VLOG
(
6
)
<<
"Generated SetAttr"
;
// [Generation] Handle TensorWrappers
std
::
unordered_set
<
std
::
string
>
duplicable_tensors
;
for
(
const
proto
::
OpProto
::
Var
&
input
:
op_proto
.
inputs
())
{
if
(
input
.
duplicable
())
{
duplicable_tensors
.
insert
(
input
.
name
());
}
}
for
(
const
proto
::
OpProto
::
Var
&
output
:
op_proto
.
outputs
())
{
if
(
output
.
duplicable
())
{
duplicable_tensors
.
insert
(
output
.
name
());
}
}
std
::
string
set_tensor_wrappers_str
=
""
;
std
::
string
tensor_wrapper_members_str
=
""
;
for
(
const
auto
&
kv
:
grad_ins_fwd_slotname_map
)
{
const
std
::
string
&
tensor_wrapper_name
=
kv
.
second
;
const
std
::
string
&
struct_tensor_wrapper_name
=
kv
.
second
+
"_"
;
std
::
string
tensor_wrapper_arg_str
;
std
::
string
tensor_wrapper_body_str
;
if
(
duplicable_tensors
.
count
(
tensor_wrapper_name
))
{
const
char
*
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE
=
"const std::vector<egr::EagerTensor>& %s"
;
tensor_wrapper_arg_str
=
paddle
::
string
::
Sprintf
(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE
,
tensor_wrapper_name
);
const
char
*
TENSOR_WRAPPER_MEMBER_TEMPLATE
=
" std::vector<egr::TensorWrapper> %s;
\n
"
;
tensor_wrapper_members_str
+=
paddle
::
string
::
Sprintf
(
TENSOR_WRAPPER_MEMBER_TEMPLATE
,
struct_tensor_wrapper_name
);
const
char
*
SET_TENSOR_WRAPPER_BODY_TEMPLATE
=
"for(const auto& eager_tensor : %s) {
\n
"
" %s.emplace_back( egr::TensorWrapper(eager_tensor, true "
"/*full_reserved*/) );
\n
"
" }
\n
"
;
tensor_wrapper_body_str
=
paddle
::
string
::
Sprintf
(
SET_TENSOR_WRAPPER_BODY_TEMPLATE
,
tensor_wrapper_name
,
struct_tensor_wrapper_name
);
}
else
{
const
char
*
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE
=
"const egr::EagerTensor& %s"
;
tensor_wrapper_arg_str
=
paddle
::
string
::
Sprintf
(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE
,
tensor_wrapper_name
);
const
char
*
TENSOR_WRAPPER_MEMBER_TEMPLATE
=
" egr::TensorWrapper %s;
\n
"
;
tensor_wrapper_members_str
+=
paddle
::
string
::
Sprintf
(
TENSOR_WRAPPER_MEMBER_TEMPLATE
,
struct_tensor_wrapper_name
);
const
char
*
SET_TENSOR_WRAPPER_BODY_TEMPLATE
=
"%s = egr::TensorWrapper(%s, true /*full_reserved*/);"
;
tensor_wrapper_body_str
=
paddle
::
string
::
Sprintf
(
SET_TENSOR_WRAPPER_BODY_TEMPLATE
,
struct_tensor_wrapper_name
,
tensor_wrapper_name
);
}
const
char
*
SET_TENSOR_WRAPPER_TEMPLATE
=
" void SetTensorWrapper%s(%s) {
\n
%s
\n
}
\n
"
;
set_tensor_wrappers_str
+=
paddle
::
string
::
Sprintf
(
SET_TENSOR_WRAPPER_TEMPLATE
,
tensor_wrapper_name
,
tensor_wrapper_arg_str
,
tensor_wrapper_body_str
);
}
VLOG
(
6
)
<<
"Generated TensorWrapper"
;
std
::
string
grad_node_str
=
paddle
::
string
::
Sprintf
(
GRAD_NODE_TEMPLATE
,
op_type
,
op_type
,
op_type
,
op_type
,
set_tensor_wrappers_str
,
set_attr_map_str
,
tensor_wrapper_members_str
,
attr_members_str
);
return
grad_node_str
;
}
/* --------------------------------- */
/* --------- FileGeneration --------- */
/* ---------------------------------- */
static
void
GenerateForwardHFile
(
const
std
::
string
&
output_dir
,
const
std
::
string
&
dygraph_forward_api_str
)
{
std
::
string
dygraph_forward_api_path
=
output_dir
+
"/dygraph_forward_api.h"
;
std
::
ofstream
forward_header_stream
(
dygraph_forward_api_path
,
std
::
ios
::
out
);
forward_header_stream
<<
dygraph_forward_api_str
;
forward_header_stream
.
close
();
}
static
void
GenerateForwardDygraphFile
(
const
std
::
string
&
op_type
,
const
std
::
string
&
output_dir
,
const
std
::
string
&
fwd_function_str
)
{
std
::
string
forwards_dir
=
output_dir
+
"/forwards/"
;
std
::
string
node_h_filename
=
op_type
+
"_node.h"
;
std
::
string
forward_cc_filename
=
op_type
+
"_dygraph.cc"
;
std
::
string
forward_cc_path
=
forwards_dir
+
forward_cc_filename
;
const
char
*
FORWARD_INCLUDE_TEMPLATE
=
"#include "
"
\"
paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h
\"\n
"
"#include "
"
\"
paddle/fluid/eager/api/generated/fluid_generated/nodes/%s
\"\n\n
"
"#include
\"
paddle/fluid/eager/api/utils/global_utils.h
\"\n
"
"#include
\"
paddle/fluid/eager/legacy/op_runner.h
\"\n
"
;
std
::
string
forward_cc_include_str
=
paddle
::
string
::
Sprintf
(
FORWARD_INCLUDE_TEMPLATE
,
node_h_filename
);
std
::
ofstream
forward_cc_stream
(
forward_cc_path
,
std
::
ios
::
out
);
forward_cc_stream
<<
forward_cc_include_str
;
forward_cc_stream
<<
fwd_function_str
;
forward_cc_stream
.
close
();
}
static
void
GenerateNodeHFile
(
const
std
::
string
&
op_type
,
const
std
::
string
&
output_dir
,
const
std
::
string
&
grad_node_str
)
{
std
::
string
nodes_dir
=
output_dir
+
"/nodes/"
;
std
::
string
node_h_filename
=
op_type
+
"_node.h"
;
std
::
string
node_h_path
=
nodes_dir
+
node_h_filename
;
std
::
string
node_h_include_str
=
"#pragma once
\n
"
"#include
\"
paddle/fluid/eager/tensor_wrapper.h
\"\n
"
"#include
\"
paddle/fluid/eager/legacy/op_runner.h
\"\n
"
"#include
\"
paddle/fluid/eager/grad_node_info.h
\"\n\n
"
;
std
::
ofstream
node_h_stream
(
node_h_path
,
std
::
ios
::
out
);
node_h_stream
<<
node_h_include_str
;
node_h_stream
<<
grad_node_str
;
node_h_stream
.
close
();
}
static
void
GenerateNodeCCFile
(
const
std
::
string
&
op_type
,
const
std
::
string
&
output_dir
,
const
std
::
string
&
grad_function_str
)
{
std
::
string
nodes_dir
=
output_dir
+
"/nodes/"
;
std
::
string
node_h_filename
=
op_type
+
"_node.h"
;
std
::
string
node_cc_filename
=
op_type
+
"_node.cc"
;
std
::
string
node_cc_path
=
nodes_dir
+
node_cc_filename
;
const
char
*
NODE_CC_INCLUDE_TEMPLATE
=
"#include
\"
glog/logging.h
\"\n
"
"#include
\"
paddle/pten/api/all.h
\"\n
"
"#include
\"
paddle/fluid/imperative/tracer.h
\"\n
"
"#include
\"
paddle/fluid/framework/op_registry.h
\"\n
"
"#include
\"
paddle/fluid/eager/utils.h
\"\n
"
"#include
\"
paddle/fluid/eager/api/utils/global_utils.h
\"\n
"
"#include "
"
\"
paddle/fluid/eager/api/generated/fluid_generated/nodes/%s
\"\n\n
"
;
std
::
string
node_cc_include_str
=
paddle
::
string
::
Sprintf
(
NODE_CC_INCLUDE_TEMPLATE
,
node_h_filename
);
std
::
ofstream
node_cc_stream
(
node_cc_path
,
std
::
ios
::
out
);
node_cc_stream
<<
node_cc_include_str
;
node_cc_stream
<<
grad_function_str
;
node_cc_stream
.
close
();
}
static
std
::
string
GenerateDygraphHFileIncludes
()
{
std
::
string
dygraph_forward_api_includes_str
=
"#pragma once
\n
"
"#include
\"
glog/logging.h
\"\n
"
"#include
\"
paddle/fluid/eager/autograd_meta.h
\"\n
"
"#include
\"
paddle/pten/api/all.h
\"\n
"
"#include
\"
paddle/fluid/eager/utils.h
\"\n
"
"#include
\"
paddle/fluid/framework/op_registry.h
\"\n\n
"
;
return
dygraph_forward_api_includes_str
;
}
static
void
DygraphCodeGeneration
(
const
std
::
string
&
output_dir
)
{
std
::
string
dygraph_forward_api_str
=
GenerateDygraphHFileIncludes
();
auto
&
op_info_map
=
paddle
::
framework
::
OpInfoMap
::
Instance
().
map
();
for
(
auto
&
pair
:
op_info_map
)
{
const
OpInfo
&
op_info
=
pair
.
second
;
proto
::
OpProto
*
op_proto
=
op_info
.
proto_
;
if
(
!
CheckOpProto
(
op_proto
))
continue
;
const
std
::
string
&
op_type
=
op_proto
->
type
();
/* ----------------------------- */
/* ---- Collect Information ---- */
/* ----------------------------- */
std
::
vector
<
paddle
::
framework
::
AttributeMap
>
grad_node_default_attr_maps
;
std
::
vector
<
std
::
string
>
grad_op_types
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_inputs_name_pos_map
;
std
::
unordered_map
<
std
::
string
,
size_t
>
fwd_outputs_name_pos_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_outs_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_fwd_slotname_map
;
std
::
map
<
std
::
string
,
std
::
string
>
grad_ins_grad_slotname_map
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>
grad_ins
;
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
shared_ptr
<
paddle
::
imperative
::
VariableWrapper
>>>
grad_outs
;
VLOG
(
6
)
<<
"-------- CollectInformationFromOpInfo -------"
;
bool
is_available
=
CollectInformationFromOpInfo
(
op_info
,
&
grad_node_default_attr_maps
,
&
grad_op_types
,
&
fwd_inputs_name_pos_map
,
&
fwd_outputs_name_pos_map
,
&
grad_outs_slotname_map
,
&
grad_ins_fwd_slotname_map
,
&
grad_ins_grad_slotname_map
,
&
grad_ins
,
&
grad_outs
);
if
(
!
is_available
)
continue
;
/* --------------------------- */
/* --------- CodeGen --------- */
/* --------------------------- */
/* ---- xxx_dygraph.cc ---- */
VLOG
(
6
)
<<
"-------- GenerateForwardFunctionContents -------"
;
std
::
pair
<
std
::
string
,
std
::
string
>
body_and_declaration
=
GenerateForwardFunctionContents
(
grad_node_default_attr_maps
,
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
grad_ins_grad_slotname_map
,
grad_outs_slotname_map
,
grad_ins
,
grad_outs
,
*
op_proto
);
std
::
string
fwd_function_str
=
body_and_declaration
.
first
;
GenerateForwardDygraphFile
(
op_type
,
output_dir
,
fwd_function_str
);
/* ---- dygraph_forward_api.h ---- */
std
::
string
fwd_function_declare_str
=
body_and_declaration
.
second
;
dygraph_forward_api_str
+=
fwd_function_declare_str
;
/* ---- xxx_node.h ---- */
VLOG
(
6
)
<<
"-------- GenerateGradNodeHeaderContents -------"
;
std
::
string
grad_node_h_str
=
GenerateGradNodeHeaderContents
(
grad_node_default_attr_maps
,
grad_ins_fwd_slotname_map
,
*
op_proto
);
GenerateNodeHFile
(
op_type
,
output_dir
,
grad_node_h_str
);
/* ---- xxx_node.cc ---- */
VLOG
(
6
)
<<
"-------- GenerateGradNodeCCContents -------"
;
std
::
string
grad_node_cc_str
=
GenerateGradNodeCCContents
(
grad_node_default_attr_maps
,
grad_op_types
,
fwd_inputs_name_pos_map
,
fwd_outputs_name_pos_map
,
grad_ins_fwd_slotname_map
,
grad_ins_grad_slotname_map
,
grad_outs_slotname_map
,
grad_ins
,
grad_outs
,
*
op_proto
);
GenerateNodeCCFile
(
op_type
,
output_dir
,
grad_node_cc_str
);
VLOG
(
6
)
<<
op_type
<<
": Finished Generation"
;
}
/* ---- dygraph_forward_api.h ---- */
VLOG
(
6
)
<<
"-------- GenerateForwardHFile -------"
;
GenerateForwardHFile
(
output_dir
,
dygraph_forward_api_str
);
}
int
main
(
int
argc
,
char
*
argv
[])
{
if
(
argc
!=
2
)
{
std
::
cerr
<<
"argc must be 2"
<<
std
::
endl
;
return
-
1
;
}
std
::
string
eager_root
=
argv
[
1
];
paddle
::
framework
::
DygraphCodeGeneration
(
eager_root
);
return
0
;
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录