未验证 提交 2e238c6e 编写于 作者: L liym27 提交者: GitHub

[Dy2Static] Convert for stmt and support variable loaded and created in loop (#24901)

* Move function 'convert_len' to file convert_operators.py 

* Support that for statements are transformed to while statements.

* Fix bug: raise None -> return None. 

* Support variable loaded and created in loop.

* Use int64 in Py2 and Py3 in function to_static_variable.
上级 9a2c1aed
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from paddle.fluid import framework
from paddle.fluid import core
from paddle.fluid.layers import nn
from paddle.fluid.layers import control_flow
def convert_len(var):
    """
    Return the length of ``var``.

    For a framework ``Variable`` the length is generally not known until
    runtime, so this appends a block-level op (``shape_op`` or
    ``array_length``) and returns a variable holding the length. For any
    other Python object the built-in ``len`` is used directly.

    Note: In addition to some ast transformations, some block-related
    operations are added in `len` transformation, such as appending
    `shape_op` in var.block.
    """
    if not isinstance(var, framework.Variable):
        # Plain Python object: defer to the built-in.
        return len(var)

    dense_types = (
        core.VarDesc.VarType.LOD_TENSOR,
        core.VarDesc.VarType.SELECTED_ROWS,
    )
    if var.type in dense_types:
        # Note: Length of var may be known ahead of time in dygraph,
        # but it probably represents batch size which can be variant.
        # so we return a variable dynamically inferred from var.shape.
        return nn.shape(var)[0]
    if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
        return control_flow.array_length(var)
    raise TypeError(
        'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
        % type(var))
...@@ -29,7 +29,7 @@ import six ...@@ -29,7 +29,7 @@ import six
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.layers import Layer from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.dygraph.dygraph_to_static.convert_builtins_func import convert_len from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func'] DECORATOR_NAMES = ['declarative', 'dygraph_to_static_func']
program_translator = ProgramTranslator() program_translator = ProgramTranslator()
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.fluid.framework import Variable
from paddle.fluid.layers import control_flow, logical_and, logical_or, logical_not, cast
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.data_feeder import convert_dtype from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
def convert_while_loop(cond, body, loop_vars): def convert_while_loop(cond, body, loop_vars):
...@@ -175,6 +175,33 @@ def _run_py_ifelse(pred, true_fn, false_fn): ...@@ -175,6 +175,33 @@ def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn() return true_fn() if pred else false_fn()
def convert_len(var):
"""
Returns variable(length) from shape ops based on var.type
Note: In addition to some ast transformations, some block-related
operations are added in `len` transformation, such as appending
`shape_op` in var.block.
"""
if isinstance(var, Variable):
if var.type in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS
]:
# Note: Length of var may be known ahead of time in dygraph,
# but it probably represents batch size which can be variant.
# so we return a variable dynamically inferred from var.shape.
return nn.shape(var)[0]
elif var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
return control_flow.array_length(var)
else:
raise TypeError(
'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
% type(var))
else:
return len(var)
def cast_bool_if_necessary(var): def cast_bool_if_necessary(var):
assert isinstance(var, Variable) assert isinstance(var, Variable)
if convert_dtype(var.dtype) not in ['bool']: if convert_dtype(var.dtype) not in ['bool']:
......
...@@ -22,14 +22,11 @@ from paddle.fluid import unique_name ...@@ -22,14 +22,11 @@ from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer from paddle.fluid.dygraph.dygraph_to_static.utils import RenameTransformer
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable_gast_node
__all__ = ['LoopTransformer', 'NameVisitor'] __all__ = ['LoopTransformer', 'NameVisitor']
...@@ -89,7 +86,8 @@ class NameVisitor(gast.NodeVisitor): ...@@ -89,7 +86,8 @@ class NameVisitor(gast.NodeVisitor):
# Mapping from gast.While/gast.For to variable nodes # Mapping from gast.While/gast.For to variable nodes
self.before_loop_body_vars = defaultdict(set) self.before_loop_body_vars = defaultdict(set)
self.in_loop_vars = defaultdict(set) # NOTE: Use ordered list as dict value
self.in_loop_vars = defaultdict(list)
# Mapping from gast.While/gast.For to variable nodes which is condition # Mapping from gast.While/gast.For to variable nodes which is condition
# of loop or being modified during the loop # of loop or being modified during the loop
...@@ -103,11 +101,6 @@ class NameVisitor(gast.NodeVisitor): ...@@ -103,11 +101,6 @@ class NameVisitor(gast.NodeVisitor):
self.visit(root_node) self.visit(root_node)
def is_control_flow_loop(self, node):
need_transform = is_control_flow_to_transform(
node, self.static_analysis_visitor)
return need_transform
def get_loop_var_names(self, node): def get_loop_var_names(self, node):
assert isinstance( assert isinstance(
node, (gast.While, gast.For)), "Input node is not gast loop node" node, (gast.While, gast.For)), "Input node is not gast loop node"
...@@ -115,7 +108,15 @@ class NameVisitor(gast.NodeVisitor): ...@@ -115,7 +108,15 @@ class NameVisitor(gast.NodeVisitor):
create_var_names = set() create_var_names = set()
read_context = {type(gast.Load()), type(gast.AugLoad())} read_context = {type(gast.Load()), type(gast.AugLoad())}
in_loop_vars = self.in_loop_vars[node] in_loop_vars_list = self.in_loop_vars[node]
# get dict `var_name_to_ctxs`
var_name_to_ctxs = defaultdict(list)
for var_node in in_loop_vars_list:
var_name_to_ctxs[self._var_node_to_name(var_node)].append(
var_node.ctx)
in_loop_vars = set(in_loop_vars_list)
in_loop_name_strs = self._var_nodes_to_names(in_loop_vars) in_loop_name_strs = self._var_nodes_to_names(in_loop_vars)
before_loop_body_vars = self.before_loop_body_vars[node] before_loop_body_vars = self.before_loop_body_vars[node]
...@@ -160,6 +161,22 @@ class NameVisitor(gast.NodeVisitor): ...@@ -160,6 +161,22 @@ class NameVisitor(gast.NodeVisitor):
# vars out # vars out
loop_var_names.add(name) loop_var_names.add(name)
create_var_names.add(name) create_var_names.add(name)
else:
# If a variable is used and created in loop, but used before created,
# it should be in loop_var and we should create it.
# For example, `var_a` should be in loop_var and we should create it.
#
# res = 0
# for i, x in enumerate(x_array):
# if i > 2:
# x = func1(var_a)
# var_a = func2(x)
#
if isinstance(var_name_to_ctxs[name][0], gast.Load):
loop_var_names.add(name)
create_var_names.add(name)
return loop_var_names, create_var_names return loop_var_names, create_var_names
...@@ -176,7 +193,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -176,7 +193,7 @@ class NameVisitor(gast.NodeVisitor):
type(gast.Store()), type(gast.AugStore()), type(gast.Del()) type(gast.Store()), type(gast.AugStore()), type(gast.Del())
} }
for loop_node in self.current_loop: for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node) self.in_loop_vars[loop_node].append(node)
if type(node.ctx) in write_context: if type(node.ctx) in write_context:
self.write_in_loop[loop_node].add(node) self.write_in_loop[loop_node].add(node)
if self.in_condition: if self.in_condition:
...@@ -219,7 +236,7 @@ class NameVisitor(gast.NodeVisitor): ...@@ -219,7 +236,7 @@ class NameVisitor(gast.NodeVisitor):
self.current_seen_vars.add(node) self.current_seen_vars.add(node)
for loop_node in self.current_loop: for loop_node in self.current_loop:
self.in_loop_vars[loop_node].add(node) self.in_loop_vars[loop_node].append(node)
# sub-nodes are visited during get_attribute_full_name and we shouldn't # sub-nodes are visited during get_attribute_full_name and we shouldn't
# visit again # visit again
...@@ -367,27 +384,25 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -367,27 +384,25 @@ class LoopTransformer(gast.NodeTransformer):
def get_for_stmt_nodes(self, node): def get_for_stmt_nodes(self, node):
# TODO: consider for - else in python # TODO: consider for - else in python
# 1. check whether need to transform # 1. get key statements for different cases
# NOTE: Current need transform cases: # NOTE 1: three key statements:
# 1). for x in range(VarBase[0]|VarBase.numpy()[0])
# 2). for x in VarBase|VarBase.numpy()
# 3). for i, x in enumerate(VarBase|VarBase.numpy())
if not self.name_visitor.is_control_flow_loop(node):
return [node]
# 2. get key statements for different cases
# NOTE: three key statements:
# 1). init_stmts: list[node], prepare nodes of for loop, may not only one # 1). init_stmts: list[node], prepare nodes of for loop, may not only one
# 2). cond_stmt: node, condition node to judge whether continue loop # 2). cond_stmt: node, condition node to judge whether continue loop
# 3). body_stmts: list[node], updated loop body, sometimes we should change # 3). body_stmts: list[node], updated loop body, sometimes we should change
# the original statement in body, not just append new statement # the original statement in body, not just append new statement
#
# NOTE 2: The following `for` statements will be transformed to `while` statements:
# 1). for x in range(*)
# 2). for x in iter_var
# 3). for i, x in enumerate(*)
current_for_node_parser = ForNodeVisitor(node) current_for_node_parser = ForNodeVisitor(node)
stmts_tuple = current_for_node_parser.parse() stmts_tuple = current_for_node_parser.parse()
if stmts_tuple is None: if stmts_tuple is None:
return [node] return [node]
init_stmts, cond_stmt, body_stmts = stmts_tuple init_stmts, cond_stmt, body_stmts = stmts_tuple
# 3. get original loop vars # 2. get original loop vars
loop_var_names, create_var_names = self.name_visitor.get_loop_var_names( loop_var_names, create_var_names = self.name_visitor.get_loop_var_names(
node) node)
# NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases, # NOTE: in 'for x in var' or 'for i, x in enumerate(var)' cases,
...@@ -402,7 +417,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -402,7 +417,7 @@ class LoopTransformer(gast.NodeTransformer):
if iter_var_name not in create_var_names: if iter_var_name not in create_var_names:
loop_var_names.remove(iter_var_name) loop_var_names.remove(iter_var_name)
# 4. prepare result statement list # 3. prepare result statement list
new_stmts = [] new_stmts = []
# Python can create variable in loop and use it out of loop, E.g. # Python can create variable in loop and use it out of loop, E.g.
# #
...@@ -415,13 +430,10 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -415,13 +430,10 @@ class LoopTransformer(gast.NodeTransformer):
if "." not in name: if "." not in name:
new_stmts.append(create_static_variable_gast_node(name)) new_stmts.append(create_static_variable_gast_node(name))
# 5. append init statements # 4. append init statements
new_stmts.extend(init_stmts) new_stmts.extend(init_stmts)
# for x in range(10) in dygraph should be convert into static tensor + 1 <= 10
for name in loop_var_names:
new_stmts.append(to_static_variable_gast_node(name))
# 6. create & append condition function node # 5. create & append condition function node
condition_func_node = gast.FunctionDef( condition_func_node = gast.FunctionDef(
name=unique_name.generate(FOR_CONDITION_PREFIX), name=unique_name.generate(FOR_CONDITION_PREFIX),
args=gast.arguments( args=gast.arguments(
...@@ -449,7 +461,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -449,7 +461,7 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(condition_func_node) new_stmts.append(condition_func_node)
# 7. create & append loop body function node # 6. create & append loop body function node
# append return values for loop body # append return values for loop body
body_stmts.append( body_stmts.append(
gast.Return(value=generate_name_node( gast.Return(value=generate_name_node(
...@@ -481,7 +493,7 @@ class LoopTransformer(gast.NodeTransformer): ...@@ -481,7 +493,7 @@ class LoopTransformer(gast.NodeTransformer):
name, unique_name.generate(GENERATE_VARIABLE_PREFIX)) name, unique_name.generate(GENERATE_VARIABLE_PREFIX))
new_stmts.append(body_func_node) new_stmts.append(body_func_node)
# 8. create & append while loop node # 7. create & append while loop node
while_loop_node = create_while_node(condition_func_node.name, while_loop_node = create_while_node(condition_func_node.name,
body_func_node.name, loop_var_names) body_func_node.name, loop_var_names)
new_stmts.append(while_loop_node) new_stmts.append(while_loop_node)
......
...@@ -38,7 +38,7 @@ dygraph_class_to_static_api = { ...@@ -38,7 +38,7 @@ dygraph_class_to_static_api = {
} }
FOR_ITER_INDEX_PREFIX = '__for_loop_var_index' FOR_ITER_INDEX_PREFIX = '__for_loop_var_index'
FOR_ITER_VAR_SHAPE_PREFIX = '__for_loop_var_shape' FOR_ITER_VAR_LEN_PREFIX = '__for_loop_var_len'
def _is_api_in_module_helper(obj, module_prefix): def _is_api_in_module_helper(obj, module_prefix):
...@@ -668,7 +668,7 @@ class ForNodeVisitor(object): ...@@ -668,7 +668,7 @@ class ForNodeVisitor(object):
In this process, the semantics of for does not change. In this process, the semantics of for does not change.
Now only can parse 3 type statements (Here var is VarBase(Tensor)): Now only can parse 3 type statements (Here var is VarBase(Tensor) or python variable):
1). for x in range(var[*]|var.numpy()[*]) 1). for x in range(var[*]|var.numpy()[*])
2). for x in var|var.numpy() 2). for x in var|var.numpy()
3). for i, x enumerate(var|var.numpy()) 3). for i, x enumerate(var|var.numpy())
...@@ -700,12 +700,11 @@ class ForNodeVisitor(object): ...@@ -700,12 +700,11 @@ class ForNodeVisitor(object):
# - for i, x enumerate(var|var.numpy()) # - for i, x enumerate(var|var.numpy())
self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX) self.iter_idx_name = unique_name.generate(FOR_ITER_INDEX_PREFIX)
# - created shape var to build loop condition: __for_loop_var_shape_0 # - created shape var to build loop condition: __for_loop_var_len_0
# - for x in var|var.numpy() # - for x in var|var.numpy()
# - for i, x enumerate(var|var.numpy()) # - for i, x enumerate(var|var.numpy())
# - for x in var # - for x in var
self.iter_var_shape_name = unique_name.generate( self.iter_var_len_name = unique_name.generate(FOR_ITER_VAR_LEN_PREFIX)
FOR_ITER_VAR_SHAPE_PREFIX)
# - var.numpy()/var # - var.numpy()/var
# - for x in var|var.numpy() # - for x in var|var.numpy()
...@@ -728,7 +727,7 @@ class ForNodeVisitor(object): ...@@ -728,7 +727,7 @@ class ForNodeVisitor(object):
elif self.is_for_enumerate_iter(): elif self.is_for_enumerate_iter():
return self._parse_for_enumerate_stmts() return self._parse_for_enumerate_stmts()
else: else:
raise None return None
def is_for_range_iter(self): def is_for_range_iter(self):
return isinstance(self.node.iter, gast.Call) and isinstance( return isinstance(self.node.iter, gast.Call) and isinstance(
...@@ -736,7 +735,7 @@ class ForNodeVisitor(object): ...@@ -736,7 +735,7 @@ class ForNodeVisitor(object):
gast.Name) and self.node.iter.func.id == "range" gast.Name) and self.node.iter.func.id == "range"
def is_for_iter(self): def is_for_iter(self):
if isinstance(self.node.iter, gast.Name): if isinstance(self.node.iter, (gast.Name, gast.Attribute)):
return True return True
elif isinstance(self.node.iter, gast.Call) and isinstance( elif isinstance(self.node.iter, gast.Call) and isinstance(
self.node.iter.func, self.node.iter.func,
...@@ -776,7 +775,7 @@ class ForNodeVisitor(object): ...@@ -776,7 +775,7 @@ class ForNodeVisitor(object):
def _parse_for_stmts(self): def _parse_for_stmts(self):
init_stmts = [] init_stmts = []
init_stmts.append(self._build_index_init_node()) init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_shape_assign_node()) init_stmts.append(self._build_var_len_assign_node())
compare_node = self._build_compare_node() compare_node = self._build_compare_node()
step_node = self._build_step_node() step_node = self._build_step_node()
...@@ -794,7 +793,7 @@ class ForNodeVisitor(object): ...@@ -794,7 +793,7 @@ class ForNodeVisitor(object):
def _parse_for_enumerate_stmts(self): def _parse_for_enumerate_stmts(self):
init_stmts = [] init_stmts = []
init_stmts.append(self._build_index_init_node()) init_stmts.append(self._build_index_init_node())
init_stmts.append(self._build_var_shape_assign_node()) init_stmts.append(self._build_var_len_assign_node())
init_stmts.append(self._build_enum_init_node()) init_stmts.append(self._build_enum_init_node())
compare_node = self._build_compare_node() compare_node = self._build_compare_node()
...@@ -814,51 +813,49 @@ class ForNodeVisitor(object): ...@@ -814,51 +813,49 @@ class ForNodeVisitor(object):
def _build_index_init_node(self): def _build_index_init_node(self):
if self.is_for_range_iter(): if self.is_for_range_iter():
if self.args_length == 1: if self.args_length == 1:
index_init_node = get_constant_variable_node(self.iter_var_name, index_init_value_str = '0'
0)
else: else:
index_init_node = gast.Assign( index_init_value_str = ast_to_source_code(self.iter_args[
targets=[ 0]).strip()
gast.Name(
id=self.iter_var_name, index_init_var_name = self.iter_var_name
ctx=gast.Store(),
annotation=None,
type_comment=None)
],
value=self.iter_args[0])
else: else:
index_init_node = get_constant_variable_node(self.iter_idx_name, 0) index_init_value_str = '0'
index_init_var_name = self.iter_idx_name
index_init_node_source_str = "{target} = {value}".format(
target=index_init_var_name, value=index_init_value_str)
index_init_node = gast.parse(index_init_node_source_str).body[0]
return index_init_node return index_init_node
def _build_var_shape_assign_node(self): def _build_var_len_assign_node(self):
# get variable shape as iter length # get the length of iterable variable
if isinstance(self.iter_node, gast.Call): if isinstance(self.iter_node, gast.Call) and isinstance(
iter_var = self.iter_node.func self.iter_node.func,
gast.Attribute) and self.iter_node.func.attr == 'numpy':
iter_var_name = ast_to_source_code(self.iter_node.func.value).strip(
)
else: else:
iter_var = self.iter_node iter_var_name = ast_to_source_code(self.iter_node).strip()
return gast.Assign(
targets=[ convert_len_node_source_str = '{} = fluid.dygraph.dygraph_to_static.convert_operators.convert_len({})'.format(
gast.Name( self.iter_var_len_name, iter_var_name)
id=self.iter_var_shape_name,
ctx=gast.Load(), convert_len_node = gast.parse(convert_len_node_source_str).body[0]
annotation=None,
type_comment=None) return convert_len_node
],
value=create_api_shape_node(iter_var))
def _build_enum_init_node(self): def _build_enum_init_node(self):
enum_init_node = get_constant_variable_node(
name=self.enum_idx_name, value=0)
if self.is_for_enumerate_iter() and self.args_length != 1: if self.is_for_enumerate_iter() and self.args_length != 1:
enum_init_node = gast.Assign( init_value_str = ast_to_source_code(self.iter_args[1]).strip()
targets=[ else:
gast.Name( init_value_str = '0'
id=self.enum_idx_name,
ctx=gast.Store(), enum_init_node_source_str = "{} = {}".format(self.enum_idx_name,
annotation=None, init_value_str)
type_comment=None) enum_init_node = gast.parse(enum_init_node_source_str).body[0]
],
value=self.iter_args[1])
return enum_init_node return enum_init_node
def _build_compare_node(self): def _build_compare_node(self):
...@@ -866,15 +863,11 @@ class ForNodeVisitor(object): ...@@ -866,15 +863,11 @@ class ForNodeVisitor(object):
compare_node = self.iter_args[ compare_node = self.iter_args[
0] if self.args_length == 1 else self.iter_args[1] 0] if self.args_length == 1 else self.iter_args[1]
else: else:
compare_node = gast.Subscript( compare_node = gast.Name(
value=gast.Name( id=self.iter_var_len_name,
id=self.iter_var_shape_name,
ctx=gast.Load(), ctx=gast.Load(),
annotation=None, annotation=None,
type_comment=None), type_comment=None)
slice=gast.Index(value=gast.Constant(
value=0, kind=None)),
ctx=gast.Load())
return compare_node return compare_node
def _build_step_node(self): def _build_step_node(self):
......
...@@ -118,11 +118,10 @@ def to_static_variable(x): ...@@ -118,11 +118,10 @@ def to_static_variable(x):
return fill_constant(shape=[1], dtype='float64', value=x) return fill_constant(shape=[1], dtype='float64', value=x)
if six.PY2: if six.PY2:
if isinstance(x, int): if isinstance(x, (int, long)):
return fill_constant(shape=[1], dtype='int32', value=x)
if isinstance(x, long):
return fill_constant(shape=[1], dtype='int64', value=x) return fill_constant(shape=[1], dtype='int64', value=x)
else: else:
if isinstance(x, int): if isinstance(x, int):
return fill_constant(shape=[1], dtype='int64', value=x) return fill_constant(shape=[1], dtype='int64', value=x)
return x return x
...@@ -84,6 +84,16 @@ def for_loop_dyfunc(max_len): ...@@ -84,6 +84,16 @@ def for_loop_dyfunc(max_len):
return ret return ret
def for_loop_dyfunc2(max_len):
# Test case: a variable is used and created in loop, but used before created
for i in range(max_len):
if i > 1:
s = a
a = 1
ret = fluid.layers.fill_constant(shape=[1], dtype="int32", value=s)
return ret
def while_loop_bool_op(x): def while_loop_bool_op(x):
i = fluid.dygraph.to_variable(x) i = fluid.dygraph.to_variable(x)
...@@ -131,7 +141,6 @@ def for_loop_class_var(max_len): ...@@ -131,7 +141,6 @@ def for_loop_class_var(max_len):
foo = Foo() foo = Foo()
# Use `to_variable` so that static analysis can analyze the type of X is Tensor # Use `to_variable` so that static analysis can analyze the type of X is Tensor
# TODO(liym27): Delete it if the type of parameter x can be resolved
max_len = fluid.layers.fill_constant( max_len = fluid.layers.fill_constant(
shape=[1], value=max_len, dtype="int32") shape=[1], value=max_len, dtype="int32")
...@@ -298,6 +307,11 @@ class TestTransformForLoop(unittest.TestCase): ...@@ -298,6 +307,11 @@ class TestTransformForLoop(unittest.TestCase):
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static())) self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
class TestTransformForLoop2(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_dyfunc2
class TestClassVarInForLoop(TestTransformForLoop): class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self): def _init_dyfunc(self):
self.dyfunc = for_loop_class_var self.dyfunc = for_loop_class_var
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册