Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f027b2ad
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f027b2ad
编写于
3月 25, 2022
作者:
Z
Zhanlue Yang
提交者:
GitHub
3月 25, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Refactor] refactored eager_gen.py PR #2 (#40907)
上级
5f6038ff
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
78 addition
and
41 deletion
+78
-41
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
...uto_code_generator/final_state_generator/codegen_utils.py
+7
-3
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
...er/auto_code_generator/final_state_generator/eager_gen.py
+71
-38
未找到文件。
paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
浏览文件 @
f027b2ad
...
...
@@ -50,6 +50,10 @@ yaml_types_mapping = {
#############################
### File Reader Helpers ###
#############################
def
AssertMessage
(
lhs_str
,
rhs_str
):
return
f
"lhs:
{
lhs_str
}
, rhs:
{
rhs_str
}
"
def
ReadFwdFile
(
filepath
):
f
=
open
(
filepath
,
'r'
)
contents
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
...
...
@@ -62,10 +66,10 @@ def ReadBwdFile(filepath):
contents
=
yaml
.
load
(
f
,
Loader
=
yaml
.
FullLoader
)
ret
=
{}
for
content
in
contents
:
assert
'backward_api'
in
content
.
keys
(),
AssertMessage
(
'backward_api'
,
content
.
keys
())
if
'backward_api'
in
content
.
keys
():
api_name
=
content
[
'backward_api'
]
else
:
assert
False
ret
[
api_name
]
=
content
f
.
close
()
...
...
@@ -225,7 +229,7 @@ def ParseYamlReturns(string):
),
f
"The return type
{
ret_type
}
in yaml config is not supported in yaml_types_mapping."
ret_type
=
yaml_types_mapping
[
ret_type
]
assert
"Tensor"
in
ret_type
assert
"Tensor"
in
ret_type
,
AssertMessage
(
"Tensor"
,
ret_type
)
ret_name
=
RemoveSpecialSymbolsInName
(
ret_name
)
returns_list
.
append
([
ret_name
,
ret_type
,
i
])
...
...
paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
浏览文件 @
f027b2ad
...
...
@@ -16,6 +16,7 @@ import yaml
import
re
import
argparse
import
os
import
logging
from
codegen_utils
import
core_ops_returns_info
,
core_ops_args_info
,
core_ops_args_type_info
from
codegen_utils
import
yaml_types_mapping
from
codegen_utils
import
ReadFwdFile
,
ReadBwdFile
...
...
@@ -30,6 +31,7 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
from
codegen_utils
import
ParseYamlForward
,
ParseYamlBackward
from
codegen_utils
import
FunctionGeneratorBase
,
YamlGeneratorBase
from
codegen_utils
import
ops_to_fill_zero_for_empty_grads
from
codegen_utils
import
AssertMessage
###########
...
...
@@ -398,14 +400,21 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_api_contents
=
self
.
forward_api_contents
grad_api_contents
=
self
.
grad_api_contents
assert
'api'
in
forward_api_contents
.
keys
()
assert
'args'
in
forward_api_contents
.
keys
()
assert
'output'
in
forward_api_contents
.
keys
()
assert
'backward'
in
forward_api_contents
.
keys
()
assert
'args'
in
grad_api_contents
.
keys
()
assert
'output'
in
grad_api_contents
.
keys
()
assert
'forward'
in
grad_api_contents
.
keys
()
assert
'api'
in
forward_api_contents
.
keys
(
),
"Unable to find
\"
api
\"
in api.yaml"
assert
'args'
in
forward_api_contents
.
keys
(
),
"Unable to find
\"
args
\"
in api.yaml"
assert
'output'
in
forward_api_contents
.
keys
(
),
"Unable to find
\"
output
\"
in api.yaml"
assert
'backward'
in
forward_api_contents
.
keys
(
),
"Unable to find
\"
backward
\"
in api.yaml"
assert
'args'
in
grad_api_contents
.
keys
(
),
"Unable to find
\"
args
\"
in backward.yaml"
assert
'output'
in
grad_api_contents
.
keys
(
),
"Unable to find
\"
output
\"
in backward.yaml"
assert
'forward'
in
grad_api_contents
.
keys
(
),
"Unable to find
\"
forward
\"
in backward.yaml"
def
ForwardsValidationCheck
(
self
):
forward_inputs_list
=
self
.
forward_inputs_list
...
...
@@ -424,8 +433,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
orig_input_type
=
orig_forward_inputs_list
[
i
][
1
]
orig_input_pos
=
orig_forward_inputs_list
[
i
][
2
]
assert
forward_input_type
==
orig_input_type
assert
forward_input_pos
==
orig_input_pos
assert
forward_input_type
==
orig_input_type
,
AssertMessage
(
forward_input_type
,
orig_input_type
)
assert
forward_input_pos
==
orig_input_pos
,
AssertMessage
(
forward_input_pos
,
orig_input_pos
)
for
i
in
range
(
len
(
forward_attrs_list
)):
orig_attr_name
=
orig_forward_attrs_list
[
i
][
0
]
...
...
@@ -436,9 +447,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_attr_type
=
forward_attrs_list
[
i
][
1
]
forward_attr_default
=
forward_attrs_list
[
i
][
2
]
forward_attr_pos
=
forward_attrs_list
[
i
][
3
]
assert
orig_attr_type
==
forward_attr_type
assert
orig_attr_default
==
forward_attr_default
assert
orig_attr_pos
==
forward_attr_pos
assert
orig_attr_type
==
forward_attr_type
,
AssertMessage
(
orig_attr_type
,
forward_attr_type
)
assert
orig_attr_default
==
forward_attr_default
,
AssertMessage
(
orig_attr_default
,
forward_attr_default
)
assert
orig_attr_pos
==
forward_attr_pos
,
AssertMessage
(
orig_attr_pos
,
forward_attr_pos
)
for
i
in
range
(
len
(
forward_returns_list
)):
orig_return_type
=
orig_forward_returns_list
[
i
][
1
]
...
...
@@ -446,8 +460,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
forward_return_type
=
forward_returns_list
[
i
][
1
]
forward_return_pos
=
forward_returns_list
[
i
][
2
]
assert
orig_return_type
==
forward_return_type
assert
orig_return_pos
==
forward_return_pos
assert
orig_return_type
==
forward_return_type
,
AssertMessage
(
orig_return_type
,
forward_return_type
)
assert
orig_return_pos
==
forward_return_pos
,
AssertMessage
(
orig_return_pos
,
forward_return_pos
)
# Check Order: Inputs, Attributes
max_input_position
=
-
1
...
...
@@ -456,7 +472,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
max_attr_position
=
-
1
for
_
,
_
,
_
,
pos
in
forward_attrs_list
:
assert
pos
>
max_input_position
assert
pos
>
max_input_position
,
AssertMessage
(
pos
,
max_input_position
)
max_attr_position
=
max
(
max_attr_position
,
pos
)
def
BackwardValidationCheck
(
self
):
...
...
@@ -471,12 +488,14 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
max_grad_tensor_position
=
-
1
for
_
,
(
_
,
_
,
pos
)
in
backward_grad_inputs_map
.
items
():
assert
pos
>
max_fwd_input_position
assert
pos
>
max_fwd_input_position
,
AssertMessage
(
pos
,
max_grad_tensor_position
)
max_grad_tensor_position
=
max
(
max_grad_tensor_position
,
pos
)
max_attr_position
=
-
1
for
_
,
_
,
_
,
pos
in
backward_attrs_list
:
assert
pos
>
max_grad_tensor_position
assert
pos
>
max_grad_tensor_position
,
AssertMessage
(
pos
,
max_grad_tensor_position
)
max_attr_position
=
max
(
max_attr_position
,
pos
)
def
IntermediateValidationCheck
(
self
):
...
...
@@ -491,7 +510,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
len
(
forward_returns_list
))
for
ret_name
,
_
,
pos
in
forward_returns_list
:
if
ret_name
in
intermediate_outputs
:
assert
pos
in
intermediate_positions
assert
pos
in
intermediate_positions
,
AssertMessage
(
pos
,
intermediate_positions
)
def
CollectBackwardInfo
(
self
):
forward_api_contents
=
self
.
forward_api_contents
...
...
@@ -505,9 +525,12 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
self
.
backward_inputs_list
,
self
.
backward_attrs_list
,
self
.
backward_returns_list
=
ParseYamlBackward
(
backward_args_str
,
backward_returns_str
)
print
(
"Parsed Backward Inputs List: "
,
self
.
backward_inputs_list
)
print
(
"Prased Backward Attrs List: "
,
self
.
backward_attrs_list
)
print
(
"Parsed Backward Returns List: "
,
self
.
backward_returns_list
)
logging
.
info
(
f
"Parsed Backward Inputs List:
{
self
.
backward_inputs_list
}
"
)
logging
.
info
(
f
"Prased Backward Attrs List:
{
self
.
backward_attrs_list
}
"
)
logging
.
info
(
f
"Parsed Backward Returns List:
{
self
.
backward_returns_list
}
"
)
def
CollectForwardInfoFromBackwardContents
(
self
):
...
...
@@ -530,7 +553,9 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_fwd_name
=
FindForwardName
(
backward_input_name
)
if
backward_fwd_name
:
# Grad Input
assert
backward_fwd_name
in
forward_outputs_position_map
.
keys
()
assert
backward_fwd_name
in
forward_outputs_position_map
.
keys
(
),
AssertMessage
(
backward_fwd_name
,
forward_outputs_position_map
.
keys
())
matched_forward_output_type
=
forward_outputs_position_map
[
backward_fwd_name
][
0
]
matched_forward_output_pos
=
forward_outputs_position_map
[
...
...
@@ -556,7 +581,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_input_type
,
False
,
backward_input_pos
]
else
:
assert
False
,
backward_input_name
assert
False
,
f
"Cannot find
{
backward_input_name
}
in forward position map"
for
backward_output
in
backward_returns_list
:
backward_output_name
=
backward_output
[
0
]
...
...
@@ -564,9 +589,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_output_pos
=
backward_output
[
2
]
backward_fwd_name
=
FindForwardName
(
backward_output_name
)
assert
backward_fwd_name
is
not
None
assert
backward_fwd_name
is
not
None
,
f
"Detected
{
backward_fwd_name
}
= None"
assert
backward_fwd_name
in
forward_inputs_position_map
.
keys
(
),
f
"Unable to find
{
backward_fwd_name
}
in forward inputs"
),
AssertMessage
(
backward_fwd_name
,
forward_inputs_position_map
.
keys
())
matched_forward_input_type
=
forward_inputs_position_map
[
backward_fwd_name
][
0
]
...
...
@@ -577,12 +603,15 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
backward_output_type
,
matched_forward_input_pos
,
backward_output_pos
]
print
(
"Generated Backward Fwd Input Map: "
,
self
.
backward_forward_inputs_map
)
print
(
"Generated Backward Grad Input Map: "
,
self
.
backward_grad_inputs_map
)
print
(
"Generated Backward Grad Output Map: "
,
self
.
backward_grad_outputs_map
)
logging
.
info
(
f
"Generated Backward Fwd Input Map:
{
self
.
backward_forward_inputs_map
}
"
)
logging
.
info
(
f
"Generated Backward Grad Input Map:
{
self
.
backward_grad_inputs_map
}
"
)
logging
.
info
(
f
"Generated Backward Grad Output Map:
{
self
.
backward_grad_outputs_map
}
"
)
def
GenerateNodeDeclaration
(
self
):
forward_op_name
=
self
.
forward_api_name
...
...
@@ -642,7 +671,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
set_tensor_wrapper_methods_str
,
set_attribute_methods_str
,
tensor_wrapper_members_str
,
attribute_members_str
)
print
(
"Generated Node Declaration: "
,
self
.
node_declaration_str
)
logging
.
info
(
f
"Generated Node Declaration:
{
self
.
node_declaration_str
}
"
)
def
GenerateNodeDefinition
(
self
):
namespace
=
self
.
namespace
...
...
@@ -710,7 +739,7 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
grad_node_name
,
fill_zero_str
,
grad_node_name
,
grad_api_namespace
,
backward_api_name
,
grad_api_args_str
,
returns_str
)
print
(
"Generated Node Definition: "
,
self
.
node_definition_str
)
logging
.
info
(
f
"Generated Node Definition:
{
self
.
node_definition_str
}
"
)
def
GenerateForwardDefinition
(
self
,
is_inplaced
):
namespace
=
self
.
namespace
...
...
@@ -813,8 +842,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
dygraph_event_str
,
node_creation_str
,
returns_str
)
self
.
forward_declaration_str
+=
f
"
{
returns_type_str
}
{
forward_function_name
}
(
{
inputs_args_declaration_str
}
);
\n
"
print
(
"Generated Forward Definition: "
,
self
.
forward_definition_str
)
print
(
"Generated Forward Declaration: "
,
self
.
forward_declaration_str
)
logging
.
info
(
f
"Generated Forward Definition:
{
self
.
forward_definition_str
}
"
)
logging
.
info
(
f
"Generated Forward Declaration:
{
self
.
forward_declaration_str
}
"
)
def
GenerateNodeCreationCodes
(
self
,
forward_call_str
):
forward_api_name
=
self
.
forward_api_name
...
...
@@ -921,7 +952,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
else
:
if
num_fwd_outputs
>
1
:
# Aligned with forward output position
assert
name
in
forward_outputs_position_map
.
keys
()
assert
name
in
forward_outputs_position_map
.
keys
(
),
AssertMessage
(
name
,
forward_outputs_position_map
.
keys
())
fwd_output_pos
=
forward_outputs_position_map
[
name
][
1
]
tw_name
=
f
"std::get<
{
fwd_output_pos
}
>(api_result)"
else
:
...
...
@@ -1114,7 +1146,8 @@ class DygraphYamlGenerator(YamlGeneratorBase):
if
'backward'
not
in
forward_api_contents
.
keys
():
return
None
backward_api_name
=
forward_api_contents
[
'backward'
]
assert
backward_api_name
in
grad_api_dict
.
keys
()
assert
backward_api_name
in
grad_api_dict
.
keys
(),
AssertMessage
(
backward_api_name
,
grad_api_dict
.
keys
())
backward_api_contents
=
grad_api_dict
[
backward_api_name
]
return
backward_api_contents
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录