Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6cb24967
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6cb24967
编写于
6月 28, 2022
作者:
A
Aurelius84
提交者:
GitHub
6月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat]Unify all API name in_jst import path to improve readablity (#43868)
* [Dy2Stat]Polish all API name of _jst
上级
13451615
变更
19
隐藏空白更改
内联
并排
Showing
19 changed file
with
50 addition
and
151 deletion
+50
-151
python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py
...dle/fluid/dygraph/dygraph_to_static/assert_transformer.py
+3
-5
python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
...addle/fluid/dygraph/dygraph_to_static/call_transformer.py
+1
-1
python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py
...addle/fluid/dygraph/dygraph_to_static/cast_transformer.py
+1
-2
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
...dle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
+3
-3
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
...addle/fluid/dygraph/dygraph_to_static/list_transformer.py
+1
-1
python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py
...le/fluid/dygraph/dygraph_to_static/logical_transformer.py
+5
-5
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
...addle/fluid/dygraph/dygraph_to_static/loop_transformer.py
+4
-4
python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
...ddle/fluid/dygraph/dygraph_to_static/print_transformer.py
+1
-1
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
...uid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+2
-8
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+1
-1
python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py
...le/fluid/dygraph/dygraph_to_static/variable_trans_func.py
+3
-61
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
...d/tests/unittests/dygraph_to_static/ifelse_simple_func.py
+2
-4
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py
...id/tests/unittests/dygraph_to_static/test_convert_call.py
+2
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_list.py
+1
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
...ts/unittests/dygraph_to_static/test_program_translator.py
+6
-6
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
...id/tests/unittests/dygraph_to_static/test_tensor_shape.py
+0
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py
...s/unittests/dygraph_to_static/test_variable_trans_func.py
+0
-24
python/paddle/jit/dy2static/__init__.py
python/paddle/jit/dy2static/__init__.py
+14
-18
python/paddle/jit/dy2static/variable_trans_func.py
python/paddle/jit/dy2static/variable_trans_func.py
+0
-4
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/assert_transformer.py
浏览文件 @
6cb24967
...
@@ -36,10 +36,8 @@ class AssertTransformer(gast.NodeTransformer):
...
@@ -36,10 +36,8 @@ class AssertTransformer(gast.NodeTransformer):
self
.
visit
(
self
.
root
)
self
.
visit
(
self
.
root
)
def
visit_Assert
(
self
,
node
):
def
visit_Assert
(
self
,
node
):
convert_assert_node
=
gast
.
parse
(
convert_assert_node
=
gast
.
parse
(
'_jst.Assert({test}, {msg})'
.
format
(
'_jst.convert_assert({test}, {msg})'
.
format
(
test
=
ast_to_source_code
(
node
.
test
),
test
=
ast_to_source_code
(
node
.
test
),
msg
=
ast_to_source_code
(
node
.
msg
)
if
node
.
msg
else
""
)).
body
[
0
].
value
msg
=
ast_to_source_code
(
node
.
msg
)
if
node
.
msg
else
""
)).
body
[
0
].
value
return
gast
.
Expr
(
value
=
convert_assert_node
)
return
gast
.
Expr
(
value
=
convert_assert_node
)
python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py
浏览文件 @
6cb24967
...
@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer):
...
@@ -71,7 +71,7 @@ class CallTransformer(gast.NodeTransformer):
if
PDB_SET
in
func_str
:
if
PDB_SET
in
func_str
:
return
node
return
node
new_func_str
=
"_jst.
convert_c
all({})"
.
format
(
func_str
)
new_func_str
=
"_jst.
C
all({})"
.
format
(
func_str
)
new_func_ast
=
gast
.
parse
(
new_func_str
).
body
[
0
].
value
new_func_ast
=
gast
.
parse
(
new_func_str
).
body
[
0
].
value
node
.
func
=
new_func_ast
node
.
func
=
new_func_ast
...
...
python/paddle/fluid/dygraph/dygraph_to_static/cast_transformer.py
浏览文件 @
6cb24967
...
@@ -39,8 +39,7 @@ class CastTransformer(gast.NodeTransformer):
...
@@ -39,8 +39,7 @@ class CastTransformer(gast.NodeTransformer):
func_str
=
ast_to_source_code
(
node
.
func
).
strip
()
func_str
=
ast_to_source_code
(
node
.
func
).
strip
()
if
func_str
in
self
.
_castable_type
and
len
(
node
.
args
)
>
0
:
if
func_str
in
self
.
_castable_type
and
len
(
node
.
args
)
>
0
:
args_str
=
ast_to_source_code
(
node
.
args
[
0
]).
strip
()
args_str
=
ast_to_source_code
(
node
.
args
[
0
]).
strip
()
new_func_str
=
"_jst.convert_var_dtype({}, '{}')"
.
format
(
new_func_str
=
"_jst.AsDtype({}, '{}')"
.
format
(
args_str
,
func_str
)
args_str
,
func_str
)
new_node
=
gast
.
parse
(
new_func_str
).
body
[
0
].
value
new_node
=
gast
.
parse
(
new_func_str
).
body
[
0
].
value
return
new_node
return
new_node
...
...
python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
浏览文件 @
6cb24967
...
@@ -361,8 +361,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
...
@@ -361,8 +361,8 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
After transformed, q and z are created in parent scope. For example,
After transformed, q and z are created in parent scope. For example,
x, y = 5, 10
x, y = 5, 10
q = paddle.jit.dy2static.
data_layer_not_check(name='q', shape=[-1], dtype='float32
')
q = paddle.jit.dy2static.
UndefindVar('q
')
z = paddle.jit.dy2static.
data_layer_not_check(name='z', shape=[-1], dtype='float32
')
z = paddle.jit.dy2static.
UndefindVar('z
')
def true_func(x, y, q):
def true_func(x, y, q):
x = x+1
x = x+1
...
@@ -647,7 +647,7 @@ def create_convert_ifelse_node(return_name_ids,
...
@@ -647,7 +647,7 @@ def create_convert_ifelse_node(return_name_ids,
false_func_source
=
false_func
.
name
false_func_source
=
false_func
.
name
convert_ifelse_layer
=
gast
.
parse
(
convert_ifelse_layer
=
gast
.
parse
(
'_jst.
convert_ife
lse('
'_jst.
IfE
lse('
'{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})'
'{pred}, {true_fn}, {false_fn}, {get_args}, {set_args}, {return_name_ids})'
.
format
(
.
format
(
pred
=
ast_to_source_code
(
pred
),
pred
=
ast_to_source_code
(
pred
),
...
...
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
浏览文件 @
6cb24967
...
@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer):
...
@@ -252,7 +252,7 @@ class ListTransformer(gast.NodeTransformer):
# 2. pop stmt for a list or dict if len(args_str) == 1
# 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2
# 3. pop stmt for a dict if len(args_str) == 2
if
len
(
args_str
)
<=
2
:
if
len
(
args_str
)
<=
2
:
new_pop_str
=
"_jst.
convert_p
op({}, {})"
\
new_pop_str
=
"_jst.
P
op({}, {})"
\
.
format
(
target_str
,
","
.
join
(
args_str
))
.
format
(
target_str
,
","
.
join
(
args_str
))
new_pop_node
=
gast
.
parse
(
new_pop_str
).
body
[
0
].
value
new_pop_node
=
gast
.
parse
(
new_pop_str
).
body
[
0
].
value
return
new_pop_node
return
new_pop_node
...
...
python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py
浏览文件 @
6cb24967
...
@@ -43,7 +43,7 @@ class LogicalTransformer(gast.NodeTransformer):
...
@@ -43,7 +43,7 @@ class LogicalTransformer(gast.NodeTransformer):
a = x > 1 and y < 1
a = x > 1 and y < 1
Transformed code:
Transformed code:
a =
paddle.jit.dy2static.convert_logical_a
nd(lambda:x>1, lambda:y<1)
a =
_jst.A
nd(lambda:x>1, lambda:y<1)
"""
"""
def
__init__
(
self
,
wrapper_root
):
def
__init__
(
self
,
wrapper_root
):
...
@@ -57,7 +57,7 @@ class LogicalTransformer(gast.NodeTransformer):
...
@@ -57,7 +57,7 @@ class LogicalTransformer(gast.NodeTransformer):
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
if
isinstance
(
node
.
op
,
gast
.
Not
):
if
isinstance
(
node
.
op
,
gast
.
Not
):
arg
=
ast_to_source_code
(
node
.
operand
)
arg
=
ast_to_source_code
(
node
.
operand
)
new_node_str
=
"_jst.
convert_logical_n
ot({})"
.
format
(
arg
)
new_node_str
=
"_jst.
N
ot({})"
.
format
(
arg
)
# NOTE: gast.parse returns Module(body=[expr(value=...)])
# NOTE: gast.parse returns Module(body=[expr(value=...)])
new_node
=
gast
.
parse
(
new_node_str
).
body
[
0
].
value
new_node
=
gast
.
parse
(
new_node_str
).
body
[
0
].
value
return
new_node
return
new_node
...
@@ -66,9 +66,9 @@ class LogicalTransformer(gast.NodeTransformer):
...
@@ -66,9 +66,9 @@ class LogicalTransformer(gast.NodeTransformer):
def
visit_BoolOp
(
self
,
node
):
def
visit_BoolOp
(
self
,
node
):
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
if
isinstance
(
node
.
op
,
gast
.
And
):
if
isinstance
(
node
.
op
,
gast
.
And
):
new_node
=
self
.
_create_bool_op_node
(
node
.
values
,
'
a
nd'
)
new_node
=
self
.
_create_bool_op_node
(
node
.
values
,
'
A
nd'
)
elif
isinstance
(
node
.
op
,
gast
.
Or
):
elif
isinstance
(
node
.
op
,
gast
.
Or
):
new_node
=
self
.
_create_bool_op_node
(
node
.
values
,
'
o
r'
)
new_node
=
self
.
_create_bool_op_node
(
node
.
values
,
'
O
r'
)
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"Only supports and/or syntax in control flow if statement."
)
"Only supports and/or syntax in control flow if statement."
)
...
@@ -95,7 +95,7 @@ class LogicalTransformer(gast.NodeTransformer):
...
@@ -95,7 +95,7 @@ class LogicalTransformer(gast.NodeTransformer):
nodes
=
[
pre_logic_node
]
+
[
post_logic_node
]
nodes
=
[
pre_logic_node
]
+
[
post_logic_node
]
args
=
[
ast_to_source_code
(
child
)
for
child
in
nodes
]
args
=
[
ast_to_source_code
(
child
)
for
child
in
nodes
]
new_node_str
=
"_jst.
convert_logical_
{}(lambda:{}, lambda:{})"
.
format
(
new_node_str
=
"_jst.{}(lambda:{}, lambda:{})"
.
format
(
api_type
,
args
[
0
],
args
[
1
])
api_type
,
args
[
0
],
args
[
1
])
# NOTE: gast.parse return Module(body=[expr(...)])
# NOTE: gast.parse return Module(body=[expr(...)])
new_node
=
gast
.
parse
(
new_node_str
).
body
[
0
].
value
new_node
=
gast
.
parse
(
new_node_str
).
body
[
0
].
value
...
...
python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py
浏览文件 @
6cb24967
...
@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
...
@@ -28,7 +28,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ForLoopTuplePreTransformer
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ForLoopTuplePreTransformer
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ForNodeVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ForNodeVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
RenameTransformer
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
RenameTransformer
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_
static_variable_gas
t_node
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_
fill_constan
t_node
__all__
=
[
'LoopTransformer'
,
'NameVisitor'
]
__all__
=
[
'LoopTransformer'
,
'NameVisitor'
]
...
@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
...
@@ -89,7 +89,7 @@ def create_while_nodes(condition_name, body_name, loop_var_names):
else
:
else
:
assign_loop_var_names
.
append
(
name
)
assign_loop_var_names
.
append
(
name
)
while_func_name
=
"_jst.
convert_while_loop
"
while_func_name
=
"_jst.
While
"
while_node_str
=
"[{}] = {}({}, {}, [{}])"
.
format
(
while_node_str
=
"[{}] = {}({}, {}, [{}])"
.
format
(
","
.
join
(
assign_loop_var_names
),
while_func_name
,
condition_name
,
","
.
join
(
assign_loop_var_names
),
while_func_name
,
condition_name
,
body_name
,
","
.
join
(
loop_var_names
))
body_name
,
","
.
join
(
loop_var_names
))
...
@@ -672,7 +672,7 @@ class LoopTransformer(gast.NodeTransformer):
...
@@ -672,7 +672,7 @@ class LoopTransformer(gast.NodeTransformer):
# We need to create static variable for those variables
# We need to create static variable for those variables
for
name
in
create_var_names
:
for
name
in
create_var_names
:
if
"."
not
in
name
:
if
"."
not
in
name
:
new_stmts
.
append
(
create_
static_variable_gas
t_node
(
name
))
new_stmts
.
append
(
create_
fill_constan
t_node
(
name
))
# 4. append init statements
# 4. append init statements
new_stmts
.
extend
(
init_stmts
)
new_stmts
.
extend
(
init_stmts
)
...
@@ -756,7 +756,7 @@ class LoopTransformer(gast.NodeTransformer):
...
@@ -756,7 +756,7 @@ class LoopTransformer(gast.NodeTransformer):
# We need to create static variable for those variables
# We need to create static variable for those variables
for
name
in
create_var_names
:
for
name
in
create_var_names
:
if
"."
not
in
name
:
if
"."
not
in
name
:
new_stmts
.
append
(
create_
static_variable_gas
t_node
(
name
))
new_stmts
.
append
(
create_
fill_constan
t_node
(
name
))
condition_func_node
=
gast
.
FunctionDef
(
condition_func_node
=
gast
.
FunctionDef
(
name
=
unique_name
.
generate
(
WHILE_CONDITION_PREFIX
),
name
=
unique_name
.
generate
(
WHILE_CONDITION_PREFIX
),
...
...
python/paddle/fluid/dygraph/dygraph_to_static/print_transformer.py
浏览文件 @
6cb24967
...
@@ -50,5 +50,5 @@ class PrintTransformer(gast.NodeTransformer):
...
@@ -50,5 +50,5 @@ class PrintTransformer(gast.NodeTransformer):
return
gast
.
Expr
(
value
=
convert_print_node
)
return
gast
.
Expr
(
value
=
convert_print_node
)
def
_create_print_node
(
self
,
print_args
):
def
_create_print_node
(
self
,
print_args
):
convert_print_func
=
gast
.
parse
(
'_jst.
convert_p
rint'
).
body
[
0
].
value
convert_print_func
=
gast
.
parse
(
'_jst.
P
rint'
).
body
[
0
].
value
return
gast
.
Call
(
func
=
convert_print_func
,
args
=
print_args
,
keywords
=
[])
return
gast
.
Call
(
func
=
convert_print_func
,
args
=
print_args
,
keywords
=
[])
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
浏览文件 @
6cb24967
...
@@ -14,22 +14,16 @@
...
@@ -14,22 +14,16 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
copy
from
paddle.utils
import
gast
from
paddle.utils
import
gast
from
paddle.fluid
import
unique_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
slice_is_num
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
StaticAnalysisVisitor
class
TensorShapeTransformer
(
gast
.
NodeTransformer
):
class
TensorShapeTransformer
(
gast
.
NodeTransformer
):
"""
"""
This class transforms variable.shape into Static Graph Ast.
This class transforms variable.shape into Static Graph Ast.
All 'xxx.shape' will be converted int '_jst.
convert_s
hape(x)'.
All 'xxx.shape' will be converted int '_jst.
S
hape(x)'.
"""
"""
def
__init__
(
self
,
wrapper_root
):
def
__init__
(
self
,
wrapper_root
):
...
@@ -48,7 +42,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -48,7 +42,7 @@ class TensorShapeTransformer(gast.NodeTransformer):
# NOTE(dev): we can deal with paddle.shape in this case, but it's
# NOTE(dev): we can deal with paddle.shape in this case, but it's
# not pretty to modify into 'convert_shape(paddle)(x)[0]'.
# not pretty to modify into 'convert_shape(paddle)(x)[0]'.
if
args
!=
'paddle'
:
if
args
!=
'paddle'
:
convert_shape_func
=
"_jst.
convert_s
hape({})"
.
format
(
args
)
convert_shape_func
=
"_jst.
S
hape({})"
.
format
(
args
)
shape_node
=
gast
.
parse
(
convert_shape_func
).
body
[
0
].
value
shape_node
=
gast
.
parse
(
convert_shape_func
).
body
[
0
].
value
return
shape_node
return
shape_node
return
node
return
node
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
6cb24967
...
@@ -1178,7 +1178,7 @@ class ForNodeVisitor(object):
...
@@ -1178,7 +1178,7 @@ class ForNodeVisitor(object):
else
:
else
:
iter_var_name
=
ast_to_source_code
(
self
.
iter_node
).
strip
()
iter_var_name
=
ast_to_source_code
(
self
.
iter_node
).
strip
()
convert_len_node_source_str
=
'{} = _jst.
convert_l
en({})'
.
format
(
convert_len_node_source_str
=
'{} = _jst.
L
en({})'
.
format
(
self
.
iter_var_len_name
,
iter_var_name
)
self
.
iter_var_len_name
,
iter_var_name
)
convert_len_node
=
gast
.
parse
(
convert_len_node_source_str
).
body
[
0
]
convert_len_node
=
gast
.
parse
(
convert_len_node_source_str
).
body
[
0
]
...
...
python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py
浏览文件 @
6cb24967
...
@@ -23,57 +23,11 @@ from paddle.fluid.framework import Variable
...
@@ -23,57 +23,11 @@ from paddle.fluid.framework import Variable
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.layer_helper
import
LayerHelper
__all__
=
[
__all__
=
[
'create_bool_as_type'
,
'create_fill_constant_node'
,
'create_bool_as_type'
,
'create_fill_constant_node'
,
'to_static_variable'
,
'create_static_variable_gast_node'
,
'data_layer_not_check'
,
'create_undefined_var'
'to_static_variable'
,
'to_static_variable_gast_node'
,
'create_undefined_var'
]
]
def
data_layer_not_check
(
name
,
shape
,
dtype
=
'float32'
,
lod_level
=
0
):
"""
This function creates a Tensor on the global block. The created Tensor
doesn't check the dtype and the shape of feed data because dygraph input
data can be various-length. This API is used in translating dygraph into
static graph.
Note:
The default :code:`stop_gradient` attribute of the Tensor created by
this API is true, which means the gradient won't be passed backward
through the data Tensor. Set :code:`var.stop_gradient = False` If
user would like to pass backward gradient.
Args:
name (str): The name/alias of the Tensor, see :ref:`api_guide_Name`
for more details.
shape (list|tuple): List|Tuple of integers declaring the shape. You can
set "None" at a dimension to indicate the dimension can be of any
size. For example, it is useful to set changeable batch size as "None"
dtype (np.dtype|VarType|str, optional): The type of the data. Supported
dtype: bool, float16, float32, float64, int8, int16, int32, int64,
uint8. Default: float32
lod_level (int, optional): The LoD level of the LoDTensor. Usually users
don't have to set this value. For more details about when and how to
use LoD level, see :ref:`user_guide_lod_tensor` . Default: 0
Returns:
Tensor: The global Tensor that gives access to the data.
"""
helper
=
LayerHelper
(
'data'
,
**
locals
())
shape
=
list
(
shape
)
for
i
in
six
.
moves
.
range
(
len
(
shape
)):
if
shape
[
i
]
is
None
:
shape
[
i
]
=
-
1
return
helper
.
create_global_variable
(
name
=
name
,
shape
=
shape
,
dtype
=
dtype
,
type
=
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
stop_gradient
=
True
,
lod_level
=
lod_level
,
is_data
=
True
,
need_check_feed
=
False
)
def
create_undefined_var
(
name
):
def
create_undefined_var
(
name
):
func_code
=
"{} = _jst.UndefinedVar('{}')"
.
format
(
name
,
name
)
func_code
=
"{} = _jst.UndefinedVar('{}')"
.
format
(
name
,
name
)
return
gast
.
parse
(
func_code
).
body
[
0
]
return
gast
.
parse
(
func_code
).
body
[
0
]
...
@@ -85,18 +39,7 @@ def create_nonlocal_stmt_node(names):
...
@@ -85,18 +39,7 @@ def create_nonlocal_stmt_node(names):
return
gast
.
parse
(
func_code
).
body
[
0
]
return
gast
.
parse
(
func_code
).
body
[
0
]
def
to_static_variable_gast_node
(
name
):
def
create_fill_constant_node
(
name
,
value
=
0
):
func_code
=
"{} = _jst.to_static_variable({})"
.
format
(
name
,
name
)
return
gast
.
parse
(
func_code
).
body
[
0
]
def
create_static_variable_gast_node
(
name
):
func_code
=
"{} = _jst.data_layer_not_check(name='{}', shape=[-1], dtype='float32')"
.
format
(
name
,
unique_name
.
generate
(
name
))
return
gast
.
parse
(
func_code
).
body
[
0
]
def
create_fill_constant_node
(
name
,
value
):
func_code
=
"{} = paddle.full(shape=[1], "
.
format
(
name
)
func_code
=
"{} = paddle.full(shape=[1], "
.
format
(
name
)
if
isinstance
(
value
,
bool
):
if
isinstance
(
value
,
bool
):
func_code
+=
"dtype='bool', fill_value={}, name='{}')"
.
format
(
func_code
+=
"dtype='bool', fill_value={}, name='{}')"
.
format
(
...
@@ -121,7 +64,6 @@ def to_static_variable(x):
...
@@ -121,7 +64,6 @@ def to_static_variable(x):
return
paddle
.
full
(
shape
=
[
1
],
dtype
=
'bool'
,
fill_value
=
x
)
return
paddle
.
full
(
shape
=
[
1
],
dtype
=
'bool'
,
fill_value
=
x
)
if
isinstance
(
x
,
float
):
if
isinstance
(
x
,
float
):
return
paddle
.
full
(
shape
=
[
1
],
dtype
=
'float64'
,
fill_value
=
x
)
return
paddle
.
full
(
shape
=
[
1
],
dtype
=
'float64'
,
fill_value
=
x
)
if
isinstance
(
x
,
six
.
integer_types
):
if
isinstance
(
x
,
six
.
integer_types
):
return
paddle
.
full
(
shape
=
[
1
],
dtype
=
'int64'
,
fill_value
=
x
)
return
paddle
.
full
(
shape
=
[
1
],
dtype
=
'int64'
,
fill_value
=
x
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py
浏览文件 @
6cb24967
...
@@ -72,10 +72,8 @@ def dyfunc_with_if_else3(x):
...
@@ -72,10 +72,8 @@ def dyfunc_with_if_else3(x):
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The transformed code:
# The transformed code:
"""
"""
q = paddle.jit.dy2static.
q = paddle.jit.dy2static.UndefinedVar('q')
data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = paddle.jit.dy2static.UndefinedVar('z')
z = paddle.jit.dy2static.
data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_fn_0(q, x, y):
def true_fn_0(q, x, y):
x = x + 1
x = x + 1
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py
浏览文件 @
6cb24967
...
@@ -266,7 +266,7 @@ class TestDynamicToStaticCode(unittest.TestCase):
...
@@ -266,7 +266,7 @@ class TestDynamicToStaticCode(unittest.TestCase):
return
get_source_code
(
self
.
answer_func
)
return
get_source_code
(
self
.
answer_func
)
def
_get_transformed_code
(
self
):
def
_get_transformed_code
(
self
):
transformed_func
=
_jst
.
convert_c
all
(
self
.
func
)
transformed_func
=
_jst
.
C
all
(
self
.
func
)
return
get_source_code
(
transformed_func
)
return
get_source_code
(
transformed_func
)
def
test_code
(
self
):
def
test_code
(
self
):
...
@@ -289,7 +289,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
...
@@ -289,7 +289,7 @@ class TestDynamicToStaticCode2(TestDynamicToStaticCode):
class
StaticCode
():
class
StaticCode
():
def
func_convert_then_not_to_static
(
x
):
def
func_convert_then_not_to_static
(
x
):
y
=
_jst
.
convert_c
all
(
func_not_to_static
)(
x
)
y
=
_jst
.
C
all
(
func_not_to_static
)(
x
)
return
y
return
y
self
.
answer_func
=
StaticCode
.
func_convert_then_not_to_static
self
.
answer_func
=
StaticCode
.
func_convert_then_not_to_static
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
浏览文件 @
6cb24967
...
@@ -277,6 +277,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow):
...
@@ -277,6 +277,7 @@ class TestListInWhileLoop(TestListWithoutControlFlow):
with
fluid
.
dygraph
.
guard
():
with
fluid
.
dygraph
.
guard
():
if
to_static
:
if
to_static
:
print
(
declarative
(
self
.
dygraph_func
).
code
)
res
=
declarative
(
self
.
dygraph_func
)(
self
.
input
,
self
.
iter_num
)
res
=
declarative
(
self
.
dygraph_func
)(
self
.
input
,
self
.
iter_num
)
else
:
else
:
res
=
self
.
dygraph_func
(
self
.
input
,
self
.
iter_num
)
res
=
self
.
dygraph_func
(
self
.
input
,
self
.
iter_num
)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py
浏览文件 @
6cb24967
...
@@ -90,7 +90,7 @@ class StaticCode1():
...
@@ -90,7 +90,7 @@ class StaticCode1():
x_v
=
x_v
+
1
x_v
=
x_v
+
1
return
x_v
return
x_v
_jst
.
convert_ife
lse
(
_jst
.
IfE
lse
(
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_0
,
false_fn_0
,
get_args_0
,
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_0
,
false_fn_0
,
get_args_0
,
set_args_0
,
(
'x_v'
,
))
set_args_0
,
(
'x_v'
,
))
...
@@ -115,8 +115,8 @@ class StaticCode1():
...
@@ -115,8 +115,8 @@ class StaticCode1():
__return_value_0
=
x_v
__return_value_0
=
x_v
return
__return_value_0
return
__return_value_0
_jst
.
convert_ifelse
(
label
is
not
None
,
true_fn_1
,
false_fn
_1
,
_jst
.
IfElse
(
label
is
not
None
,
true_fn_1
,
false_fn_1
,
get_args
_1
,
get_args_1
,
set_args_1
,
(
'__return_value_0'
,
))
set_args_1
,
(
'__return_value_0'
,
))
return
__return_value_0
return
__return_value_0
...
@@ -147,7 +147,7 @@ class StaticCode2():
...
@@ -147,7 +147,7 @@ class StaticCode2():
x_v
=
x_v
+
1
x_v
=
x_v
+
1
return
x_v
return
x_v
_jst
.
convert_ife
lse
(
_jst
.
IfE
lse
(
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_2
,
false_fn_2
,
get_args_2
,
fluid
.
layers
.
mean
(
x_v
)[
0
]
>
5
,
true_fn_2
,
false_fn_2
,
get_args_2
,
set_args_2
,
(
'x_v'
,
))
set_args_2
,
(
'x_v'
,
))
...
@@ -172,8 +172,8 @@ class StaticCode2():
...
@@ -172,8 +172,8 @@ class StaticCode2():
__return_value_1
=
x_v
__return_value_1
=
x_v
return
__return_value_1
return
__return_value_1
_jst
.
convert_ifelse
(
label
is
not
None
,
true_fn_3
,
false_fn
_3
,
_jst
.
IfElse
(
label
is
not
None
,
true_fn_3
,
false_fn_3
,
get_args
_3
,
get_args_3
,
set_args_3
,
(
'__return_value_1'
,
))
set_args_3
,
(
'__return_value_1'
,
))
return
__return_value_1
return
__return_value_1
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
浏览文件 @
6cb24967
...
@@ -275,7 +275,6 @@ class TestTensorShapeBasic(unittest.TestCase):
...
@@ -275,7 +275,6 @@ class TestTensorShapeBasic(unittest.TestCase):
self
.
expected_slice_op_num
=
0
self
.
expected_slice_op_num
=
0
def
_compute_op_num
(
self
,
program
):
def
_compute_op_num
(
self
,
program
):
print
(
program
)
self
.
op_num
=
sum
([
len
(
block
.
ops
)
for
block
in
program
.
blocks
])
self
.
op_num
=
sum
([
len
(
block
.
ops
)
for
block
in
program
.
blocks
])
self
.
shape_op_num
=
0
self
.
shape_op_num
=
0
self
.
slice_op_num
=
0
self
.
slice_op_num
=
0
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_variable_trans_func.py
浏览文件 @
6cb24967
...
@@ -22,30 +22,6 @@ import paddle.fluid as fluid
...
@@ -22,30 +22,6 @@ import paddle.fluid as fluid
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_fill_constant_node
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_fill_constant_node
from
paddle.fluid.dygraph.dygraph_to_static.variable_trans_func
import
data_layer_not_check
class
TestDataLayerNotCheck
(
unittest
.
TestCase
):
def
test_create_none_shape
(
self
):
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
d
=
data_layer_not_check
(
name
=
"d"
,
shape
=
(
None
,
-
1
,
3
))
self
.
assertEqual
(
d
.
shape
,
(
-
1
,
-
1
,
3
))
self
.
assertEqual
(
d
.
name
,
"d"
)
def
test_feed_mismatch_shape
(
self
):
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
d
=
data_layer_not_check
(
name
=
"d"
,
shape
=
(
1
,
2
,
3
))
feed_in_data
=
np
.
random
.
uniform
(
size
=
[
1
,
2
,
4
]).
astype
(
np
.
float32
)
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
()
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
ret
=
exe
.
run
(
main_program
,
feed
=
{
d
.
name
:
feed_in_data
},
fetch_list
=
[
d
.
name
])
self
.
assertTrue
(
np
.
allclose
(
ret
,
feed_in_data
))
class
TestVariableTransFunc
(
unittest
.
TestCase
):
class
TestVariableTransFunc
(
unittest
.
TestCase
):
...
...
python/paddle/jit/dy2static/__init__.py
浏览文件 @
6cb24967
...
@@ -14,25 +14,21 @@
...
@@ -14,25 +14,21 @@
from
.base
import
saw
from
.base
import
saw
from
.base
import
UndefinedVar
from
.base
import
UndefinedVar
from
.convert_call_func
import
convert_call
# noqa: F401
from
.convert_operators
import
convert_logical_and
as
And
# noqa: F401
from
.convert_operators
import
cast_bool_if_necessary
# noqa: F401
from
.convert_operators
import
convert_var_dtype
as
AsDtype
# noqa: F401
from
.convert_operators
import
convert_assert
# noqa: F401
from
.convert_operators
import
convert_assert
as
Assert
# noqa: F401
from
.convert_operators
import
convert_ifelse
# noqa: F401
from
.convert_call_func
import
convert_call
as
Call
# noqa: F401
from
.convert_operators
import
convert_len
# noqa: F401
from
.convert_operators
import
convert_ifelse
as
IfElse
# noqa: F401
from
.convert_operators
import
convert_logical_and
# noqa: F401
from
.convert_operators
import
convert_len
as
Len
# noqa: F401
from
.convert_operators
import
convert_logical_not
# noqa: F401
from
.convert_operators
import
convert_logical_not
as
Not
# noqa: F401
from
.convert_operators
import
convert_logical_or
# noqa: F401
from
.convert_operators
import
convert_logical_or
as
Or
# noqa: F401
from
.convert_operators
import
convert_pop
# noqa: F401
from
.convert_operators
import
convert_pop
as
Pop
# noqa: F401
from
.convert_operators
import
convert_print
# noqa: F401
from
.convert_operators
import
convert_print
as
Print
# noqa: F401
from
.convert_operators
import
convert_shape_compare
# noqa: F401
from
.convert_operators
import
convert_shape
as
Shape
# noqa: F401
from
.convert_operators
import
convert_var_dtype
# noqa: F401
from
.convert_operators
import
convert_while_loop
as
While
# noqa: F401
from
.convert_operators
import
convert_shape
# noqa: F401
from
.convert_operators
import
convert_while_loop
# noqa: F401
from
.variable_trans_func
import
create_bool_as_type
# noqa: F401
from
.variable_trans_func
import
create_bool_as_type
# noqa: F401
from
.variable_trans_func
import
create_fill_constant_node
# noqa: F401
from
.variable_trans_func
import
create_static_variable_gast_node
# noqa: F401
from
.variable_trans_func
import
data_layer_not_check
# noqa: F401
from
.variable_trans_func
import
to_static_variable
# noqa: F401
from
.variable_trans_func
import
to_static_variable
# noqa: F401
from
.
variable_trans_func
import
to_static_variable_gast_nod
e
# noqa: F401
from
.
convert_operators
import
convert_shape_compar
e
# noqa: F401
__all__
=
[]
__all__
=
[]
python/paddle/jit/dy2static/variable_trans_func.py
浏览文件 @
6cb24967
...
@@ -15,10 +15,6 @@
...
@@ -15,10 +15,6 @@
from
__future__
import
print_function
from
__future__
import
print_function
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_bool_as_type
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_bool_as_type
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_fill_constant_node
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
create_static_variable_gast_node
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
data_layer_not_check
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
to_static_variable
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
to_static_variable
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.variable_trans_func
import
to_static_variable_gast_node
# noqa: F401
__all__
=
[]
__all__
=
[]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录