Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9474d140
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9474d140
编写于
3月 31, 2020
作者:
A
Aurelius84
提交者:
GitHub
3月 31, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support Parameter type determination in StaticAnalysis (#23302)
* Support Parameter type determination test=develop
上级
20eed540
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
69 addition
and
19 deletion
+69
-19
python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py
...paddle/fluid/dygraph/dygraph_to_static/static_analysis.py
+40
-13
python/paddle/fluid/tests/unittests/dygraph_to_static/test_static_analysis.py
...tests/unittests/dygraph_to_static/test_static_analysis.py
+29
-6
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py
浏览文件 @
9474d140
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
gast
import
warnings
from
.utils
import
is_paddle_api
,
is_dygraph_api
,
is_numpy_api
from
.utils
import
is_paddle_api
,
is_dygraph_api
,
is_numpy_api
,
index_in_list
__all__
=
[
'AstNodeWrapper'
,
'NodeVarType'
,
'StaticAnalysisVisitor'
]
...
...
@@ -260,9 +260,9 @@ class StaticAnalysisVisitor(object):
def
get_var_env
(
self
):
return
self
.
var_env
def
_get_
node_var_type
(
self
,
cur_wrapper
):
node
=
cur_wrapper
.
node
if
isinstance
(
node
,
gast
.
Constant
):
def
_get_
constant_node_type
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Constant
),
\
"Type of input node should be gast.Constant, but received %s"
%
type
(
node
)
# singleton: None, True or False
if
node
.
value
is
None
:
return
{
NodeVarType
.
NONE
}
...
...
@@ -275,6 +275,13 @@ class StaticAnalysisVisitor(object):
if
isinstance
(
node
.
value
,
str
):
return
{
NodeVarType
.
STRING
}
return
{
NodeVarType
.
UNKNOWN
}
def
_get_node_var_type
(
self
,
cur_wrapper
):
node
=
cur_wrapper
.
node
if
isinstance
(
node
,
gast
.
Constant
):
return
self
.
_get_constant_node_type
(
node
)
if
isinstance
(
node
,
gast
.
BoolOp
):
return
{
NodeVarType
.
BOOLEAN
}
if
isinstance
(
node
,
gast
.
Compare
):
...
...
@@ -308,8 +315,28 @@ class StaticAnalysisVisitor(object):
if
isinstance
(
node
,
gast
.
Name
):
if
node
.
id
==
"None"
:
return
{
NodeVarType
.
NONE
}
if
node
.
id
==
"True"
or
node
.
id
==
"False"
:
if
node
.
id
in
{
"True"
,
"False"
}
:
return
{
NodeVarType
.
BOOLEAN
}
# If node is child of functionDef.arguments
parent_node_wrapper
=
cur_wrapper
.
parent
if
parent_node_wrapper
and
isinstance
(
parent_node_wrapper
.
node
,
gast
.
arguments
):
parent_node
=
parent_node_wrapper
.
node
var_type
=
{
NodeVarType
.
UNKNOWN
}
if
parent_node
.
defaults
:
index
=
index_in_list
(
parent_node
.
args
,
node
)
args_len
=
len
(
parent_node
.
args
)
if
index
!=
-
1
and
args_len
-
index
<=
len
(
parent_node
.
defaults
):
defaults_node
=
parent_node
.
defaults
[
index
-
args_len
]
if
isinstance
(
defaults_node
,
gast
.
Constant
):
var_type
=
self
.
_get_constant_node_type
(
defaults_node
)
# Add node with identified type into cur_env.
self
.
var_env
.
set_var_type
(
node
.
id
,
var_type
)
return
var_type
return
self
.
var_env
.
get_var_type
(
node
.
id
)
if
isinstance
(
node
,
gast
.
Return
):
...
...
python/paddle/fluid/tests/unittests/
test_ast_transformer
_static_analysis.py
→
python/paddle/fluid/tests/unittests/
dygraph_to_static/test
_static_analysis.py
浏览文件 @
9474d140
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -20,7 +20,7 @@ import numpy as np
import
paddle.fluid
as
fluid
import
unittest
from
paddle.fluid.dygraph.dygraph_to_static
import
AstNodeWrapper
,
NodeVarType
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static
import
NodeVarType
,
StaticAnalysisVisitor
def
func_to_test1
(
a
,
b
):
...
...
@@ -117,12 +117,34 @@ result_var_type5 = {
'inner_unknown_func'
:
{
NodeVarType
.
UNKNOWN
},
}
def
func_to_test6
(
x
,
y
=
1
):
i
=
fluid
.
dygraph
.
to_variable
(
x
)
def
add
(
x
,
y
):
return
x
+
y
while
x
<
10
:
i
=
add
(
i
,
x
)
x
=
x
+
y
return
i
result_var_type6
=
{
'i'
:
{
NodeVarType
.
INT
},
'x'
:
{
NodeVarType
.
INT
},
'y'
:
{
NodeVarType
.
INT
},
'add'
:
{
NodeVarType
.
INT
}
}
test_funcs
=
[
func_to_test1
,
func_to_test2
,
func_to_test3
,
func_to_test4
,
func_to_test5
func_to_test1
,
func_to_test2
,
func_to_test3
,
func_to_test4
,
func_to_test5
,
func_to_test6
]
result_var_type
=
[
result_var_type1
,
result_var_type2
,
result_var_type3
,
result_var_type4
,
result_var_type5
result_var_type5
,
result_var_type6
]
...
...
@@ -150,8 +172,8 @@ class TestStaticAnalysis(unittest.TestCase):
self
.
_check_wrapper
(
wrapper_root
,
node_to_wrapper_map
)
def
test_var_env
(
self
):
for
i
in
range
(
5
):
func
=
test_funcs
[
i
]
for
i
,
func
in
enumerate
(
test_funcs
):
var_type
=
result_var_type
[
i
]
test_source_code
=
inspect
.
getsource
(
func
)
ast_root
=
gast
.
parse
(
test_source_code
)
...
...
@@ -164,6 +186,7 @@ class TestStaticAnalysis(unittest.TestCase):
var_env
.
cur_scope
=
var_env
.
cur_scope
.
sub_scopes
[
0
]
scope_var_type
=
var_env
.
get_scope_var_type
()
print
(
scope_var_type
)
self
.
assertEqual
(
len
(
scope_var_type
),
len
(
var_type
))
for
name
in
scope_var_type
:
print
(
"Test var name %s"
%
(
name
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录