Unverified commit 2e238c6e, authored by liym27, committed by GitHub

[Dy2Static] Convert for stmt and support variables loaded and created in loop (#24901)

* Move function `convert_len` to file convert_operators.py.

* Support transforming `for` statements into `while` statements.

* Fix bug: raise None -> return None.

* Support variables loaded and created in a loop.

* Use int64 for Python ints in both Py2 and Py3 in function to_static_variable.
Parent 9a2c1aed
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid.layers import nn
from paddle.fluid.layers import control_flow
def convert_len(var):
"""
Returns a variable (the length) from shape ops, based on var.type.
Note: In addition to some AST transformations, the `len` transformation
also adds block-related operations, such as appending a `shape_op` to
var.block.
"""
if isinstance(var, framework.Variable):
if var.type in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS
]:
# Note: The length of var may be known ahead of time in dygraph,
# but it probably represents the batch size, which can vary,
# so we return a variable dynamically inferred from var.shape.
return nn.shape(var)[0]
elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return control_flow.array_length(var)
else:
raise TypeError(
'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
% type(var))
else:
return len(var)
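A minimal usage sketch of `convert_len` (after the move to convert_operators.py), assuming a Paddle 1.x static-graph program; `fluid.data` is used here only as a convenient way to get a Variable with a dynamic first dimension:

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len

# Variable with a dynamic first dimension: convert_len appends a shape op
# to the block and returns a 1-D variable holding that dimension.
x = fluid.data(name='x', shape=[-1, 10], dtype='float32')
length = convert_len(x)  # equivalent to fluid.layers.shape(x)[0]

# Anything that is not a Variable falls through to the builtin len().
assert convert_len([1, 2, 3]) == 3
```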
......@@ -29,7 +29,7 @@ import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.convert_builtins_func import convert_len
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator()
......
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.framework import Variable
from paddle.fluid.layers import control_flow, logical_and, logical_or, logical_not, cast
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
def convert_while_loop(cond, body, loop_vars):
......@@ -175,6 +175,33 @@ def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn()
def convert_len(var):
"""
Returns a variable (the length) from shape ops, based on var.type.
Note: In addition to some AST transformations, the `len` transformation
also adds block-related operations, such as appending a `shape_op` to
var.block.
"""
if isinstance(var, Variable):
if var.type in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS
]:
# Note: The length of var may be known ahead of time in dygraph,
# but it probably represents the batch size, which can vary,
# so we return a variable dynamically inferred from var.shape.
return nn.shape(var)[0]
elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return control_flow.array_length(var)
else:
raise TypeError(
'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
% type(var))
else:
return len(var)
def cast_bool_if_necessary(var):
assert isinstance(var, Variable)
if convert_dtype(var.dtype) not in ['bool']:
......
......@@ -22,14 +22,11 @@ from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
__all__ = ['LoopTransformer', 'NameVisitor']
......@@ -89,7 +86,8 @@ class NameVisitor(gast.NodeVisitor):
# Mapping from gast.While/gast.For to variable nodes
self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set)
# NOTE: Use an ordered list as the dict value so that access order is kept
self.in_loop_vars = defaultdict(list)
# Mapping from gast.While/gast.For to variable nodes that are the loop
# condition or are modified during the loop
......@@ -103,11 +101,6 @@ class NameVisitor(gast.NodeVisitor):
self.visit(root_node)
def is_control_flow_loop(self, node):
need_transform = is_control_flow_to_transform(
node, self.static_analysis_visitor)
return need_transform
def get_loop_var_names(self, node):
assert isinstance(
node, (gast.While, gast.For)), "Input node is not gast loop node"
......@@ -115,7 +108,15 @@ class NameVisitor(gast.NodeVisitor):
create_var_names = set()
read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node]
in_loop_vars_list = self.in_loop_vars[node]
# Build dict `var_name_to_ctxs`: var name -> access contexts in source order
var_name_to_ctxs = defaultdict(list)
for var_node in in_loop_vars_list:
var_name_to_ctxs[self._var_node_to_name(var_node)].append(
var_node.ctx)
in_loop_vars = set(in_loop_vars_list)
in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
before_loop_body_vars = self.before_loop_body_vars[node]
......@@ -160,6 +161,22 @@ class NameVisitor(gast.NodeVisitor):
# vars out
loop_var_names.add(name)
create_var_names.add(name)
else:
# If a variable is both created and used in the loop but is read
# before its first assignment, it should be in loop_var_names and
# we should create it, as `var_a` is below:
#
# res = 0
# for i, x in enumerate(x_array):
# if i > 2:
# x = func1(var_a)
# var_a = func2(x)
#
if isinstance(var_name_to_ctxs[name][0], gast.Load):
loop_var_names.add(name)
create_var_names.add(name)
return loop_var_names, create_var_names
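The rule above depends on seeing access contexts in source order, which is why `in_loop_vars` now keeps an ordered list instead of a set. A standalone sketch of the idea with `gast` (the class and variable names are illustrative, not part of Paddle):

```python
import gast

code = '''
res = 0
for i, x in enumerate(x_array):
    if i > 2:
        x = func1(var_a)
    var_a = func2(x)
'''

class FirstCtxRecorder(gast.NodeVisitor):
    """Record the first access context of every name, in source order."""
    def __init__(self):
        self.first_ctx = {}

    def visit_Name(self, node):
        self.first_ctx.setdefault(node.id, type(node.ctx).__name__)

recorder = FirstCtxRecorder()
recorder.visit(gast.parse(code).body[1])  # visit the for-loop node only

# var_a is read before it is written inside the loop, so get_loop_var_names
# must put it in both loop_var_names and create_var_names.
print(recorder.first_ctx['var_a'])  # -> 'Load'
```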
......@@ -176,7 +193,7 @@ class NameVisitor(gast.NodeVisitor):
type(gast.Store()), type(gast.AugStore()), type(gast.Del())
}
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
self.in_loop_vars[loop_node].append(node)
if type(node.ctx) in write_context:
self.write_in_loop[loop_node].add(node)
if self.in_condition:
......@@ -219,7 +236,7 @@ class NameVisitor(gast.NodeVisitor):
self.current_seen_vars.add(node)
for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node)
self.in_loop_vars[loop_node].append(node)
# sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again
......@@ -367,27 +384,25 @@ class LoopTransformer(gast.NodeTransformer):
def get_for_stmt_nodes(self, node):
# TODO: consider for - else in python
# 1. check whether need to transform
# NOTE: Current need transform cases:
# 1). for x in range(VarBase[0]|VarBase.numpy()[0])
# 2). for x in VarBase|VarBase.numpy()
# 3). for i, x in enumerate(VarBase|VarBase.numpy())
if not self.name_visitor.is_control_flow_loop(node):
return [node]
# 2. get key statements for different cases
# NOTE: three key statements:
# 1. get key statements for different cases
# NOTE 1: three key statements:
# 1). init_stmts: list[node], nodes that prepare the for loop; there may be more than one
# 2). cond_stmt: node, condition node that judges whether to continue the loop
# 3). body_stmts: list[node], updated loop body; sometimes we should change
# the original statements in the body, not just append new ones
#
# NOTE 2: The following `for` statements will be transformed to `while` statements:
# 1). for x in range(*)
# 2). for x in iter_var
# 3). for i, x in enumerate(*)
current_for_node_parser = ForNodeVisitor(node)
stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None:
return [node]
init_stmts, cond_stmt, body_stmts = stmts_tuple
# 3. get original loop vars
# 2. get original loop vars
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node)
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
......@@ -402,7 +417,7 @@ class LoopTransformer(gast.NodeTransformer):
if iter_var_name not in create_var_names:
loop_var_names.remove(iter_var_name)
# 4. prepare result statement list
# 3. prepare result statement list
new_stmts = []
# Python can create a variable in a loop and use it outside the loop, e.g.
#
......@@ -415,13 +430,10 @@ class LoopTransformer(gast.NodeTransformer):
if "." not in name:
new_stmts.append(create_static_variable_gast_node(name))
# 5. append init statements
# 4. append init statements
new_stmts.extend(init_stmts)
# for x in range(10) in dygraph should be converted into static tensor + 1 <= 10
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
# 6. create & append condition function node
# 5. create & append condition function node
condition_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments(
......@@ -449,7 +461,7 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node)
# 7. create & append loop body function node
# 6. create & append loop body function node
# append return values for loop body
body_stmts.append(
gast.Return(value=generate_name_node(
......@@ -481,7 +493,7 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node)
# 8. create & append while loop node
# 7. create & append while loop node
while_loop_node = create_while_node(condition_func_node.name,
body_func_node.name, loop_var_names)
new_stmts.append(while_loop_node)
......
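Putting steps 1-7 together, the transformation has roughly the following pure-Python shape (an illustrative sketch: `_py_while_loop` stands in for `convert_while_loop`, and the generated unique names are spelled out by hand):

```python
import numpy as np

def _py_while_loop(cond, body, loop_vars):
    # Stand-in for the plain-Python branch of convert_while_loop.
    while cond(*loop_vars):
        loop_vars = body(*loop_vars)
    return loop_vars

var = np.array([1, 2, 3])

# Roughly what `for x in var: total += x` becomes:
idx = 0            # init_stmts: __for_loop_var_index_0 = 0
length = len(var)  # init_stmts: __for_loop_var_len_0 = convert_len(var)
total = 0

def for_loop_condition_0(idx, total, x):  # cond_stmt
    return idx < length

def for_loop_body_0(idx, total, x):       # body_stmts + index update
    x = var[idx]
    total += x
    idx += 1
    return idx, total, x

idx, total, x = _py_while_loop(for_loop_condition_0, for_loop_body_0,
                               [idx, total, 0])
print(total)  # 6
```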
......@@ -38,7 +38,7 @@ dygraph_class_to_static_api = {
}
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_VAR_SHAPE_PREFIX = '__for_loop_var_shape'
FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
def _is_api_in_module_helper(obj, module_prefix):
......@@ -668,7 +668,7 @@ class ForNodeVisitor(object):
In this process, the semantics of the for statement do not change.
Now only can parse 3 type statements (Here var is VarBase(Tensor)):
Currently only 3 types of statements can be parsed (here var is VarBase(Tensor) or a Python variable):
1). for x in range(var[*]|var.numpy()[*])
2). for x in var|var.numpy()
3). for i, x in enumerate(var|var.numpy())
......@@ -700,12 +700,11 @@ class ForNodeVisitor(object):
# - for i, x in enumerate(var|var.numpy())
self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX)
# - created shape var to build loop condition: __for_loop_var_shape_0
# - created shape var to build loop condition: __for_loop_var_len_0
# - for x in var|var.numpy()
# - for i, x in enumerate(var|var.numpy())
# - for x in var
self.iter_var_shape_name = unique_name.generate(
FOR_ITER_VAR_SHAPE_PREFIX)
self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX)
# - var.numpy()/var
# - for x in var|var.numpy()
......@@ -728,7 +727,7 @@ class ForNodeVisitor(object):
elif self.is_for_enumerate_iter():
return self._parse_for_enumerate_stmts()
else:
raise None
return None
def is_for_range_iter(self):
return isinstance(self.node.iter, gast.Call) and isinstance(
......@@ -736,7 +735,7 @@ class ForNodeVisitor(object):
gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self):
if isinstance(self.node.iter, gast.Name):
if isinstance(self.node.iter, (gast.Name, gast.Attribute)):
return True
elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func,
......@@ -776,7 +775,7 @@ class ForNodeVisitor(object):
def _parse_for_stmts(self):
init_stmts = []
init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_shape_assign_node())
init_stmts.append(self._build_var_len_assign_node())
compare_node = self._build_compare_node()
step_node = self._build_step_node()
......@@ -794,7 +793,7 @@ class ForNodeVisitor(object):
def _parse_for_enumerate_stmts(self):
init_stmts = []
init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_shape_assign_node())
init_stmts.append(self._build_var_len_assign_node())
init_stmts.append(self._build_enum_init_node())
compare_node = self._build_compare_node()
......@@ -814,51 +813,49 @@ class ForNodeVisitor(object):
def _build_index_init_node(self):
if self.is_for_range_iter():
if self.args_length == 1:
index_init_node = get_constant_variable_node(self.iter_var_name,
0)
index_init_value_str = '0'
else:
index_init_node = gast.Assign(
targets=[
gast.Name(
id=self.iter_var_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=self.iter_args[0])
index_init_value_str = ast_to_source_code(self.iter_args[
0]).strip()
index_init_var_name = self.iter_var_name
else:
index_init_node = get_constant_variable_node(self.iter_idx_name, 0)
index_init_value_str = '0'
index_init_var_name = self.iter_idx_name
index_init_node_source_str = "{target} = {value}".format(
target=index_init_var_name, value=index_init_value_str)
index_init_node = gast.parse(index_init_node_source_str).body[0]
return index_init_node
def _build_var_shape_assign_node(self):
# get variable shape as iter length
if isinstance(self.iter_node, gast.Call):
iter_var = self.iter_node.func
def _build_var_len_assign_node(self):
# get the length of the iterable variable
if isinstance(self.iter_node, gast.Call) and isinstance(
self.iter_node.func,
gast.Attribute) and self.iter_node.func.attr == 'numpy':
iter_var_name = ast_to_source_code(self.iter_node.func.value).strip(
)
else:
iter_var = self.iter_node
return gast.Assign(
targets=[
gast.Name(
id=self.iter_var_shape_name,
ctx=gast.Load(),
annotation=None,
type_comment=None)
],
value=create_api_shape_node(iter_var))
iter_var_name = ast_to_source_code(self.iter_node).strip()
convert_len_node_source_str = '{} = fluid.dygraph.dygraph_to_static.convert_operators.convert_len({})'.format(
self.iter_var_len_name, iter_var_name)
convert_len_node = gast.parse(convert_len_node_source_str).body[0]
return convert_len_node
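A self-contained sketch of the `.numpy()` unwrapping in `_build_var_len_assign_node`, using the stdlib `ast` module in place of `gast`/`ast_to_source_code` (requires Python 3.9+ for `ast.unparse`; the fixed name `__for_loop_var_len_0` is illustrative):

```python
import ast

iter_node = ast.parse("var.numpy()", mode="eval").body

# For `for x in var.numpy()`, measure the tensor itself, not the numpy() call.
if (isinstance(iter_node, ast.Call)
        and isinstance(iter_node.func, ast.Attribute)
        and iter_node.func.attr == 'numpy'):
    iter_var_name = ast.unparse(iter_node.func.value).strip()
else:
    iter_var_name = ast.unparse(iter_node).strip()

stmt = '{} = fluid.dygraph.dygraph_to_static.convert_operators.convert_len({})'.format(
    '__for_loop_var_len_0', iter_var_name)
print(stmt)
# __for_loop_var_len_0 = fluid.dygraph.dygraph_to_static.convert_operators.convert_len(var)
```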
def _build_enum_init_node(self):
enum_init_node = get_constant_variable_node(
name=self.enum_idx_name, value=0)
if self.is_for_enumerate_iter() and self.args_length != 1:
enum_init_node = gast.Assign(
targets=[
gast.Name(
id=self.enum_idx_name,
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=self.iter_args[1])
init_value_str = ast_to_source_code(self.iter_args[1]).strip()
else:
init_value_str = '0'
enum_init_node_source_str = "{} = {}".format(self.enum_idx_name,
init_value_str)
enum_init_node = gast.parse(enum_init_node_source_str).body[0]
return enum_init_node
def _build_compare_node(self):
......@@ -866,15 +863,11 @@ class ForNodeVisitor(object):
compare_node = self.iter_args[
0] if self.args_length == 1 else self.iter_args[1]
else:
compare_node = gast.Subscript(
value=gast.Name(
id=self.iter_var_shape_name,
compare_node = gast.Name(
id=self.iter_var_len_name,
ctx=gast.Load(),
annotation=None,
type_comment=None),
slice=gast.Index(value=gast.Constant(
value=0, kind=None)),
ctx=gast.Load())
type_comment=None)
return compare_node
def _build_step_node(self):
......
......@@ -118,11 +118,10 @@ def to_static_variable(x):
return fill_constant(shape=[1], dtype='float64', value=x)
if six.PY2:
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int32', value=x)
if isinstance(x, long):
if isinstance(x, (int, long)):
return fill_constant(shape=[1], dtype='int64', value=x)
else:
if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x)
return x
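After this change, a Python int always maps to an int64 tensor on both Py2 and Py3 (previously Py2 ints produced int32). A minimal usage sketch, assuming a static-graph context:

```python
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable

v_int = to_static_variable(3)      # 1-D tensor, dtype int64 on Py2 and Py3 alike
v_float = to_static_variable(2.0)  # 1-D tensor, dtype float64
v_other = to_static_variable('s')  # non-numeric values pass through unchanged
```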
......@@ -84,6 +84,16 @@ def for_loop_dyfunc(max_len):
return ret
def for_loop_dyfunc2(max_len):
# Test case: a variable is used and created in a loop, but used before it is created
for i in range(max_len):
if i > 1:
s = a
a = 1
ret = fluid.layers.fill_constant(shape=[1], dtype="int32", value=s)
return ret
def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x)
......@@ -131,7 +141,6 @@ def for_loop_class_var(max_len):
foo = Foo()
# Use `to_variable` so that static analysis can determine that the type of x is Tensor
# TODO(liym27): Delete it if the type of parameter x can be resolved
max_len = fluid.layers.fill_constant(
shape=[1], value=max_len, dtype="int32")
......@@ -298,6 +307,11 @@ class TestTransformForLoop(unittest.TestCase):
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformForLoop2(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc2
class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_class_var
......
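A sketch of what `TestTransformForLoop2` exercises: run `for_loop_dyfunc2` eagerly and through dygraph-to-static, then compare the outputs (simplified from the harness above; the `declarative` wrapping is illustrative):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative

with fluid.dygraph.guard():
    dygraph_res = for_loop_dyfunc2(10).numpy()               # eager execution
    static_res = declarative(for_loop_dyfunc2)(10).numpy()   # for -> while_loop
np.testing.assert_allclose(dygraph_res, static_res)
```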