From 4af491c2bba288ba6deec95800e3e5aa37ea9de2 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Tue, 17 Mar 2020 10:39:24 +0800 Subject: [PATCH] Tensor.shape support control flow if/for/while and bugfix (#22866) * Support Tensor.shape in control flow if/for/while and separate TensorShapeTransformer from BasicApiTransformer. test=develop --- .../dygraph_to_static/ast_transformer.py | 136 ++---------- .../dygraph_to_static/loop_transformer.py | 17 +- .../dygraph_to_static/static_analysis.py | 4 + .../tensor_shape_transformer.py | 210 ++++++++++++++++++ .../fluid/dygraph/dygraph_to_static/utils.py | 21 +- .../dygraph_to_static/test_tensor_shape.py | 178 +++++++++++++-- 6 files changed, 417 insertions(+), 149 deletions(-) create mode 100644 python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py index 1ff83aefa5c..ea25c3d715f 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py @@ -27,13 +27,14 @@ import gast from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func -from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable +from paddle.fluid.dygraph.dygraph_to_static.utils import is_dygraph_api, is_to_variable from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func -from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node +from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api +from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import TensorShapeTransformer from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import LoopTransformer from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor __all__ = ['DygraphToStaticAst', 'convert_to_static'] @@ -62,16 +63,21 @@ class DygraphToStaticAst(gast.NodeTransformer): # Generic transformation self.visit(node_wrapper.node) - # Transform basic api of dygraph to static graph - basic_api_trans = BasicApiTransformer(node_wrapper, - self.static_analysis_visitor) - basic_api_trans.ast_visit() + # Transform basic api of dygraph to static graph and get feed_name_to_arg_name + basic_api_trans = BasicApiTransformer(node_wrapper) + basic_api_trans.transform() self.feed_name_to_arg_name = basic_api_trans.get_feed_name_to_arg_id() + # Transform Tensor.shape into fluid.layers.shape(Tensor) + TensorShapeTransformer(node_wrapper).transform() + + # Transform list used in control flow ListTransformer(node_wrapper).transform() + # Transform all if/else statement of Dygraph into Static Graph. IfElseTransformer(node_wrapper).transform() + # Transform for loop and while loop LoopTransformer(node_wrapper).transform() def visit_FunctionDef(self, node): @@ -110,7 +116,7 @@ class BasicApiTransformer(gast.NodeTransformer): Class to transform basic API from dygraph to static graph. """ - def __init__(self, wrapper_root, static_analysis_visitor): + def __init__(self, wrapper_root): assert isinstance( wrapper_root, AstNodeWrapper ), "Input non-AstNodeWrapper node for the initialization of BasicApiTransformer." @@ -123,20 +129,7 @@ class BasicApiTransformer(gast.NodeTransformer): self.feed_name_to_arg_id = {} self.name_to_tensor_shape = {} - # Used for transformation of Tensor.shape - self.static_analysis_visitor = static_analysis_visitor - self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( - ) - self.scope_var_type_dict = {} - self._run_static_visitor() - - def _run_static_visitor(self): - var_env = copy.deepcopy(self.static_analysis_visitor.get_var_env()) - # TODO: Consider that Tensor.shape is used in sub function and sub_scopes is empty - var_env.cur_scope = var_env.cur_scope.sub_scopes[0] - self.scope_var_type_dict = var_env.get_scope_var_type() - - def ast_visit(self): + def transform(self): self.visit(self.root) return self.wrapper_root @@ -153,9 +146,6 @@ class BasicApiTransformer(gast.NodeTransformer): if self._update_class_node_dict(node): return None - if self._update_name_to_tensor_shape(node): - return node - for child_node in gast.walk(node.value): if isinstance(child_node, gast.Call): self._visit_Call(child_node) @@ -171,25 +161,6 @@ class BasicApiTransformer(gast.NodeTransformer): self._visit_Call(child_node) return node - def visit_Attribute(self, node): - if self._used_by_paddle_api(node): - if self.is_tensor_shape(node): - return create_api_shape_node(node) - return node - - def visit_Name(self, node): - if node.id in self.name_to_tensor_shape: - if self._used_by_paddle_api(node): - tensor_shape_node = self.name_to_tensor_shape[node.id] - if isinstance(tensor_shape_node, gast.Attribute): - return create_api_shape_node(tensor_shape_node) - elif isinstance(tensor_shape_node, gast.Subscript): - result_node = copy.deepcopy(tensor_shape_node) - result_node.value = create_api_shape_node( - tensor_shape_node.value) - return result_node - return node - def _visit_Call(self, node): assert isinstance(node, gast.Call) # Replace API `to_variable` with `fluid.layers.assign` @@ -198,10 +169,6 @@ class BasicApiTransformer(gast.NodeTransformer): node = to_assign_node(node) return node - if is_paddle_api(node): - # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary - self.generic_visit(node) - func_name = astor.to_source(gast.gast_to_ast(node.func)) if self._is_dygraph_forward(func_name): @@ -211,53 +178,6 @@ class BasicApiTransformer(gast.NodeTransformer): else: return node - def is_tensor_shape(self, node): - """ - Return True if node is like `x.shape` and x is Tensor, return False otherwise. - """ - assert isinstance(node, gast.Attribute) - if node.attr != 'shape': - return False - - try: - value_id = node.value.id - except AttributeError: - return False - - if value_id in self.name_to_tensor_shape: - return True - - # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function - # Need a better way to confirm whether `value_id` is a Tensor. - try: - var_type_set = self.scope_var_type_dict[value_id] - except KeyError: - return False - - if NodeVarType.NUMPY_NDARRAY in var_type_set: - return False - if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set: - return False - - return True - - def _used_by_paddle_api(self, node): - assert isinstance(node, (gast.Attribute, gast.Name)) - wrapper_node = self.node_to_wrapper_map.get(node) - if not wrapper_node: - # Transformed node is not in node_to_wrapper_map - return False - while wrapper_node.parent: - parent_node = wrapper_node.parent.node - if isinstance(parent_node, gast.Call): - if is_paddle_api(parent_node): - return True - else: - return False - wrapper_node = wrapper_node.parent - - return False - def _is_dygraph_forward(self, func_id): return func_id in self.class_node_dict @@ -304,32 +224,6 @@ class BasicApiTransformer(gast.NodeTransformer): def get_feed_name_to_arg_id(self): return self.feed_name_to_arg_id - def _update_name_to_tensor_shape(self, node): - assert isinstance(node, gast.Assign) - # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1] - target_node = node.targets[0] - try: - target_id = target_node.id - except AttributeError: - return False - value_node = node.value - - if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_tensor_shape: - self.name_to_tensor_shape[ - target_id] = self.name_to_tensor_shape[value_node.id] - return True - if isinstance(value_node, gast.Attribute): - if self.is_tensor_shape(value_node): # eg: x.shape - self.name_to_tensor_shape[target_id] = value_node - return True - if isinstance(value_node, gast.Subscript): - if isinstance(value_node.value, gast.Attribute): - if self.is_tensor_shape(value_node.value): # eg: x.shape[0] - self.name_to_tensor_shape[target_id] = value_node - return True - return False - def convert_to_static(dyfunc): """ diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py index d92988e7866..fd471f0431c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/loop_transformer.py @@ -20,7 +20,7 @@ import gast from collections import defaultdict from paddle.fluid import unique_name from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node @@ -70,8 +70,6 @@ class NameVisitor(gast.NodeVisitor): def __init__(self, root_node): # Set of gast.Name or gast.Attribute for variables self.current_seen_vars = set() - # list of nodes of current visit node - self.ancestor_nodes = [] # List of gast.While/gast.For nodes self.current_loop = [] @@ -80,6 +78,10 @@ class NameVisitor(gast.NodeVisitor): self.before_loop_body_vars = defaultdict(set) self.in_loop_vars = defaultdict(set) + self.static_analysis_visitor = StaticAnalysisVisitor(root_node) + self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) + self.visit(root_node) def is_control_flow_loop(self, node): @@ -123,11 +125,9 @@ class NameVisitor(gast.NodeVisitor): self.generic_visit(node) def visit(self, node): - self.ancestor_nodes.append(node) method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) ret = visitor(node) - self.ancestor_nodes.pop() return ret def visit_Attribute(self, node): @@ -166,10 +166,9 @@ class NameVisitor(gast.NodeVisitor): return ret def _is_call_func_name_node(self, node): - if self.ancestor_nodes: - parent_node = self.ancestor_nodes[-1] - if isinstance(parent_node, gast.Call) and parent_node.func == node: - return True + parent_node = self.node_to_wrapper_map[node].parent.node + if isinstance(parent_node, gast.Call) and parent_node.func == node: + return True return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py index a0d8e3a5881..115b3aa925c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/static_analysis.py @@ -313,6 +313,10 @@ class StaticAnalysisVisitor(object): return self.var_env.get_var_type(node.id) if isinstance(node, gast.Return): + # If return nothing: + if node.value is None: + return {NodeVarType.NONE} + return_type = self.node_to_wrapper_map[node.value].node_var_type assert self.var_env.cur_scope.scope_type == AstVarScope.SCOPE_TYPE_FUNCTION, "Return at non-function scope" func_name = self.var_env.cur_scope.scope_name diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py new file mode 100644 index 00000000000..52392bc1e0a --- /dev/null +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -0,0 +1,210 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import gast +import astor +import copy +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor +from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform +from paddle.fluid import unique_name +from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func +from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable +from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func +from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api, create_api_shape_node +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor + + +class TensorShapeTransformer(gast.NodeTransformer): + """ + This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast. + """ + + def __init__(self, wrapper_root): + assert isinstance( + wrapper_root, AstNodeWrapper + ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." + self.wrapper_root = wrapper_root + self.root = wrapper_root.node + self.name_to_tensor_shape = {} + + self.static_analysis_visitor = StaticAnalysisVisitor(self.root) + self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( + ) + var_env = self.static_analysis_visitor.get_var_env() + var_env.cur_scope = var_env.cur_scope.sub_scopes[0] + self.scope_var_type_dict = var_env.get_scope_var_type() + + def transform(self): + self.visit(self.root) + + def visit_Assign(self, node): + if self._update_name_to_tensor_shape(node): + return node + self.generic_visit(node) + return node + + def visit_Attribute(self, node): + if self._used_by_paddle_api(node): + if self.is_tensor_shape(node): + return create_api_shape_node(node) + return node + + def visit_Name(self, node): + if node.id in self.name_to_tensor_shape: + if self._used_by_paddle_api(node): + tensor_shape_node = self.name_to_tensor_shape[node.id] + return create_api_shape_node(tensor_shape_node) + return node + + def visit_Call(self, node): + assert isinstance(node, gast.Call) + if is_paddle_api(node): + # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary. + self.generic_visit(node) + + return node + + def visit_If(self, node): + # Call generic_visit first to transform Tensor.shape that is used in Paddle Api. + self.generic_visit(node) + cond = node.test + self._transform_tensor_shape_if_necessary(cond) + return node + + def visit_While(self, node): + self.generic_visit(node) + cond = node.test + self._transform_tensor_shape_if_necessary(cond) + return node + + def visit_For(self, node): + self.generic_visit(node) + iter = node.iter + self._transform_tensor_shape_if_necessary(iter) + + # If tensor.shape is a gast.Name and it is used in range function, transform it + self._transform_tensor_shape_in_range(node) + return node + + def _transform_tensor_shape_in_range(self, node): + assert isinstance(node, gast.For) + if not isinstance(node.iter, gast.Call): + return False + if not isinstance(node.iter.func, gast.Name): + return False + if node.iter.func.id != "range": + return False + args = node.iter.args + for idx, arg in enumerate(args): + if isinstance(arg, + gast.Name) and arg.id in self.name_to_tensor_shape: + args[idx] = create_api_shape_node(self.name_to_tensor_shape[ + arg.id]) + + return True + + def _transform_tensor_shape_if_necessary(self, cond): + for child_node in gast.walk(cond): + tensor_shape_node = None + if isinstance(child_node, (gast.Attribute)): + if self.is_tensor_shape(child_node): + tensor_shape_node = child_node + elif isinstance(child_node, (gast.Name)): + if child_node.id in self.name_to_tensor_shape: + tensor_shape_node = self.name_to_tensor_shape[child_node.id] + + if tensor_shape_node: + wrapper_node = self.node_to_wrapper_map.get(child_node) + parent_node = wrapper_node.parent.node + for field, value in gast.iter_fields(parent_node): + if child_node is value: + setattr(parent_node, field, + create_api_shape_node(tensor_shape_node)) + break + + def _used_by_paddle_api(self, node): + assert isinstance(node, (gast.Attribute, gast.Name)) + wrapper_node = self.node_to_wrapper_map.get(node) + if not wrapper_node: + # Transformed node is not in node_to_wrapper_map + return False + while wrapper_node.parent: + parent_node = wrapper_node.parent.node + if isinstance(parent_node, gast.Call): + if is_paddle_api(parent_node): + return True + else: + return False + wrapper_node = wrapper_node.parent + + return False + + def is_tensor_shape(self, node): + """ + Return True if node is like `x.shape` and x is Tensor, return False otherwise. + """ + assert isinstance(node, gast.Attribute) + if node.attr != 'shape': + return False + + try: + value_id = node.value.id + except AttributeError: + return False + + if value_id in self.name_to_tensor_shape: + return True + + # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function + # Need a better way to confirm whether `value_id` is a Tensor. + try: + var_type_set = self.scope_var_type_dict[value_id] + except KeyError: + return False + + if NodeVarType.NUMPY_NDARRAY in var_type_set: + return False + if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set: + return False + + return True + + def _update_name_to_tensor_shape(self, node): + assert isinstance(node, gast.Assign) + # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1] + target_node = node.targets[0] + try: + target_id = target_node.id + except AttributeError: + return False + value_node = node.value + + if isinstance(value_node, gast.Name): + if value_node.id in self.name_to_tensor_shape: + self.name_to_tensor_shape[ + target_id] = self.name_to_tensor_shape[value_node.id] + return True + if isinstance(value_node, gast.Attribute): + if self.is_tensor_shape(value_node): # eg: x.shape + self.name_to_tensor_shape[target_id] = value_node + return True + if isinstance(value_node, gast.Subscript): + if isinstance(value_node.value, gast.Attribute): + if self.is_tensor_shape(value_node.value): # eg: x.shape[0] + self.name_to_tensor_shape[target_id] = value_node + return True + return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py index 784950f861b..ed14bafb17d 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/utils.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/utils.py @@ -230,12 +230,19 @@ def update_args_of_func(node, dygraph_node, method_name): def create_api_shape_node(tensor_shape_node): - assert isinstance(tensor_shape_node, gast.Attribute) - api_shape_node = gast.Call( - func=gast.parse('fluid.layers.shape').body[0].value, - args=[tensor_shape_node.value], - keywords=[]) - return api_shape_node + assert isinstance(tensor_shape_node, (gast.Attribute, gast.Subscript)) + + if isinstance(tensor_shape_node, gast.Attribute): + api_shape_node = gast.Call( + func=gast.parse('fluid.layers.shape').body[0].value, + args=[tensor_shape_node.value], + keywords=[]) + return api_shape_node + + if isinstance(tensor_shape_node, gast.Subscript): + result_node = copy.deepcopy(tensor_shape_node) + result_node.value = create_api_shape_node(result_node.value) + return result_node def get_constant_variable_node(name, value, shape=[1], dtype='int64'): @@ -280,6 +287,8 @@ def create_funcDef_node(nodes, name, input_args, return_name_ids): # add return statement if return_name_ids: nodes.append(gast.Return(value=generate_name_node(return_name_ids))) + else: + nodes.append(gast.Return(value=None)) func_def_node = gast.FunctionDef( name=name, args=input_args, diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 30f46e216f5..cc9eff08eb1 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -58,17 +58,118 @@ def dyfunc_tensor_shape_5(x): return res -test_funcs = [ - dyfunc_tensor_shape_1, dyfunc_tensor_shape_2, dyfunc_tensor_shape_3, - dyfunc_tensor_shape_4, dyfunc_tensor_shape_5 -] +def dyfunc_with_if_1(x): + x = fluid.dygraph.to_variable(x) + res = fluid.layers.reshape(x, [-1, 1]) + x_shape_0 = x.shape[0] + if x_shape_0 < 1: + # `res.shape[0] > 1` is transformed into `if fluid.layers.shape(res)[0] > 1` + if res.shape[0] > 1: + res = fluid.layers.fill_constant( + value=2, shape=x.shape, dtype="int32") + else: + res = fluid.layers.fill_constant( + value=3, shape=x.shape, dtype="int32") + return res + + +def dyfunc_with_if_2(x): + x = fluid.dygraph.to_variable(x) + # `len(x.shape)` will not be transformed. + if len(x.shape) < 1: + res = x + else: + res = fluid.layers.fill_constant(value=8, shape=x.shape, dtype="int32") + + return res + + +def dyfunc_with_for_1(x): + x = fluid.dygraph.to_variable(x) + res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") + # `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]` + for i in range(x.shape[0]): + res += 1 + return res + + +def dyfunc_with_for_2(x): + x = fluid.dygraph.to_variable(x) + x_shape_0 = x.shape[0] + res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") + + # `x_shape_0` is transformed into `fluid.layers.shape(x)[0]` + for i in range(x_shape_0): + res += 1 + return res + + +def dyfunc_with_for_3(x): + # TODO(liym27): + # It will fail to run because `for i in range(len(x.shape))` will be transformed into Paddle while_loop. + # Here the python list x.shape will be added to loop_vars. However, loop_vars doesn't support python list. + # And the condition of `for i in range(len(x.shape))` only uses the length of x.shape, so it doesn't have to be transformed into Paddle while_loop. + # After the AST tranformation of for loop is improved, add TestTensorShapeInFor3. + x = fluid.dygraph.to_variable(x) + res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") + # `len(x.shape)` is not transformed. + for i in range(len(x.shape)): + res += 1 + + return res + + +def dyfunc_with_while_1(x): + x = fluid.dygraph.to_variable(x) + res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") + # `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]` + i = 1 + while i < x.shape[0]: + res += 1 + i = i + 2 + return res + + +def dyfunc_with_while_2(x): + x = fluid.dygraph.to_variable(x) + x_shape_0 = x.shape[0] + res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") + i = 1 + # `x_shape_0` is transformed into `fluid.layers.shape(x)[0]` + # TODO(liym27): If `x_shape_0` is at right like `while i < x_shape_0`, it will not be transformed. + # Fix this bug next PR. + while x_shape_0 > i: + res += 1 + i = i + 2 + return res -class TestTensorShape(unittest.TestCase): +def dyfunc_with_while_3(x): + # TODO(liym27): + # It will fail to run because the same problem as `dyfunc_with_for_3`. + # After the AST tranformation of for loop is improved, add TestTensorShapeInWhile3. + x = fluid.dygraph.to_variable(x) + x_shape = x.shape + res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") + i = 1 + + # `len(x.shape)` is not transformed. + while len(x_shape) > i: + res += 1 + i += 1 + return res + + +# 1. Basic tests without control flow +class TestTensorShapeBasic(unittest.TestCase): def setUp(self): self.input = numpy.ones(5).astype("int32") self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( ) else fluid.CPUPlace() + self.init_test_func() + + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_1 def get_dygraph_output(self): with fluid.dygraph.guard(): @@ -86,14 +187,65 @@ class TestTensorShape(unittest.TestCase): return static_res[0] def test_transformed_static_result(self): - for func in test_funcs: - self.dygraph_func = func - static_res = self.get_static_output() - dygraph_res = self.get_dygraph_output() - self.assertTrue( - numpy.allclose(dygraph_res, static_res), - msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, - static_res)) + static_res = self.get_static_output() + dygraph_res = self.get_dygraph_output() + self.assertTrue( + numpy.allclose(dygraph_res, static_res), + msg='dygraph res is {}\nstatic_res is {}'.format(dygraph_res, + static_res)) + + +class TestTensorShapeBasic2(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_2 + + +class TestTensorShapeBasic3(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_3 + + +class TestTensorShapeBasic4(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_4 + + +class TestTensorShapeBasic5(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_tensor_shape_5 + + +# 2. Tests with control flow if +class TestTensorShapeInIf1(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_if_1 + + +class TestTensorShapeInIf2(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_if_2 + + +# 3. Tests with control flow for loop +class TestTensorShapeInFor1(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_for_1 + + +class TestTensorShapeInFor2(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_for_2 + + +# 4. Tests with control flow while loop +class TestTensorShapeInWhile1(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_while_1 + + +class TestTensorShapeInWhile2(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_while_2 if __name__ == '__main__': -- GitLab