diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 5ee827073ab41a1c11f3085e8e9d49a82a9f39cb..c05173a28e25832ef2741c5268ce8a39d07c753e 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -43,6 +43,7 @@ def convert_while_loop(cond, body, loop_vars): def _run_paddle_while_loop(cond, body, loop_vars): + # NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Variable. loop_vars = [to_static_variable(var) for var in loop_vars] loop_vars = control_flow.while_loop(cond, body, loop_vars) return loop_vars @@ -146,7 +147,7 @@ def _run_py_logical_not(x): return not x -def convert_ifelse(pred, true_fn, false_fn): +def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars): """ A function representation of a Python ``if/else`` statement. @@ -154,25 +155,45 @@ def convert_ifelse(pred, true_fn, false_fn): pred(bool|Variable): A boolean variable which determines whether to return the result of ``true_fn`` or ``false_fn`` . true_fn(callable): A callable to be performed if ``pred`` is true. false_fn(callable): A callable to be performed if ``pred`` is false. + true_args(tuple): Parameters of ``true_fn``. + false_args(tuple): Parameters of ``false_fn``. + return_vars(tuple): Return variables of ``true_fn`` and ``false_fn``. Returns: - ``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` . + ``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` . """ if isinstance(pred, Variable): - return _run_paddle_cond(pred, true_fn, false_fn) + return _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, + return_vars) else: - return _run_py_ifelse(pred, true_fn, false_fn) + return _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args) -def _run_paddle_cond(pred, true_fn, false_fn): +def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args, + return_vars): + + return_var_ids = [id(var) for var in return_vars] + # NOTE 1: return vars of Paddle op `control_flow.cond` must be Paddle Variable + # NOTE 2: Here uses id(var) not var, because `if var in return_var` use operator `==`, + # which will call `fluid.layers.equal` and causes error when var in return_vars is not initialized. + true_args = [ + to_static_variable(var) if id(var) in return_var_ids else var + for var in true_args + ] + false_args = [ + to_static_variable(var) if id(var) in return_var_ids else var + for var in false_args + ] + pred = cast_bool_if_necessary(pred) - return control_flow.cond(pred, true_fn, false_fn) + return control_flow.cond(pred, lambda: true_fn(*true_args), + lambda: false_fn(*false_args)) -def _run_py_ifelse(pred, true_fn, false_fn): +def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args): - return true_fn() if pred else false_fn() + return true_fn(*true_args) if pred else false_fn(*false_args) def convert_len(var): @@ -202,6 +223,16 @@ def convert_len(var): return len(var) +def convert_var_shape(x): + """ + A function representation of the shape of variable. + """ + if isinstance(x, Variable): + return nn.shape(x) + else: + return x.shape + + def cast_bool_if_necessary(var): assert isinstance(var, Variable) if convert_dtype(var.dtype) not in ['bool']: diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py index a6384d4a37d147eb7a7c4b230b7c7ab91a728abb..28073f157ddb858da4fdf0e49026f5286d00411b 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py @@ -24,23 +24,14 @@ from collections import defaultdict import gast from paddle.fluid import unique_name -from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none -from paddle.fluid.dygraph.dygraph_to_static.utils import is_candidate_node -from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code -from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node +from paddle.fluid.dygraph.dygraph_to_static.utils import create_funcDef_node, ast_to_source_code from paddle.fluid.dygraph.dygraph_to_static.utils import create_assign_node -from paddle.fluid.dygraph.dygraph_to_static.utils import IsControlFlowVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_static_variable_gast_node TRUE_FUNC_PREFIX = 'true_fn' FALSE_FUNC_PREFIX = 'false_fn' -LOGIC_AND_PREFIX = 'logic_and' -LOGIC_OR_PREFIX = 'logic_or' -LOGIC_NOT_PREFIX = 'logic_not' -PLAIN_TENSOR_PREFIX = 'bool_tensor' class IfElseTransformer(gast.NodeTransformer): @@ -66,8 +57,9 @@ class IfElseTransformer(gast.NodeTransformer): self.generic_visit(node) new_vars_stmts, true_func_node, false_func_node, return_name_ids = transform_if_else( node, self.root) - new_node = create_cond_node(return_name_ids, node.test, true_func_node, - false_func_node) + + new_node = create_convert_ifelse_node(return_name_ids, node.test, + true_func_node, false_func_node) return new_vars_stmts + [true_func_node, false_func_node] + [new_node] @@ -86,8 +78,8 @@ class IfElseTransformer(gast.NodeTransformer): """ self.generic_visit(node) - new_node = create_cond_node(None, node.test, node.body, node.orelse, - True) + new_node = create_convert_ifelse_node(None, node.test, node.body, + node.orelse, True) # Note: A blank line will be added separately if transform gast.Expr # into source code. Using gast.Expr.value instead to avoid syntax error # in python. @@ -108,6 +100,7 @@ class NameVisitor(gast.NodeVisitor): # Available only when end_node is set. self._is_finished = False self._candidate_ctxs = (gast.Store, gast.Load, gast.Param) + self._def_func_names = set() def visit(self, node): """Visit a node.""" @@ -173,6 +166,8 @@ class NameVisitor(gast.NodeVisitor): def visit_Name(self, node): blacklist = {'True', 'False', 'None'} if node.id in blacklist: return + if node.id in self._def_func_names: + return if not self._is_call_func_name_node(node): if isinstance(node.ctx, self._candidate_ctxs): self.name_ids[node.id].append(node.ctx) @@ -183,6 +178,7 @@ class NameVisitor(gast.NodeVisitor): self.generic_visit(node) def visit_FunctionDef(self, node): + self._def_func_names.add(node.name) if not self.end_node: self.generic_visit(node) else: @@ -274,6 +270,7 @@ def parse_cond_args(var_ids_dict, return_ids=None, ctx=gast.Load): kw_defaults=None, kwarg=None, defaults=[]) + return arguments @@ -453,56 +450,59 @@ def transform_if_else(node, root): name=unique_name.generate(FALSE_FUNC_PREFIX), input_args=parse_cond_args(orelse_name_ids, modified_name_ids), return_name_ids=return_name_ids) - return create_new_vars_in_parent_stmts, true_func_node, false_func_node, return_name_ids -def create_cond_node(return_name_ids, - pred, - true_func, - false_func, - is_if_expr=False): +def create_convert_ifelse_node(return_name_ids, + pred, + true_func, + false_func, + is_if_expr=False): """ - Create `fluid.layers.cond(pred, true_fn, false_fn)` to replace - original `python if/else` statement. + Create `fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( + pred, true_fn, false_fn, true_args, false_args, return_vars)` + to replace original `python if/else` statement. """ - def create_lambda_node(func_or_expr_node, is_if_expr=False): - body = func_or_expr_node - if not is_if_expr: - body = gast.Call( - func=gast.Name( - id=func_or_expr_node.name, - ctx=gast.Load(), - annotation=None, - type_comment=None), - args=[func_or_expr_node.args], - keywords=[]) - - lambda_node = gast.Lambda( - args=gast.arguments( - args=[], - posonlyargs=[], - vararg=None, - kwonlyargs=[], - kw_defaults=None, - kwarg=None, - defaults=[]), - body=body) - return lambda_node - - cond_api = gast.parse( - 'fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse' - ).body[0].value - true_func_lambda = create_lambda_node(true_func, is_if_expr) - false_func_lambda = create_lambda_node(false_func, is_if_expr) - cond_layer = gast.Call( - func=cond_api, - args=[pred, true_func_lambda, false_func_lambda], - keywords=[]) + def create_name_nodes(name_ids): + if not name_ids: + return gast.Tuple(elts=[], ctx=gast.Load()) + + gast_names = [ + gast.Name( + id=name_id, ctx=gast.Load(), annotation=None, type_comment=None) + for name_id in name_ids + ] + name_node = gast.Tuple(elts=gast_names, ctx=gast.Load()) + return name_node + + if is_if_expr: + true_args = gast.Tuple(elts=[], ctx=gast.Load()) + false_args = gast.Tuple(elts=[], ctx=gast.Load()) + true_func_source = "lambda : {}".format(ast_to_source_code(true_func)) + false_func_source = "lambda : {}".format(ast_to_source_code(false_func)) + else: + true_args = gast.Tuple(elts=true_func.args.args, ctx=gast.Load()) + false_args = gast.Tuple(elts=false_func.args.args, ctx=gast.Load()) + true_func_source = true_func.name + false_func_source = false_func.name + + return_vars = create_name_nodes(return_name_ids) + + convert_ifelse_layer = gast.parse( + 'fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse(' + '{pred}, {true_fn}, {false_fn}, {true_args}, {false_args}, {return_vars})'. + format( + pred=ast_to_source_code(pred), + true_fn=true_func_source, + false_fn=false_func_source, + true_args=ast_to_source_code(true_args), + false_args=ast_to_source_code(false_args), + return_vars=ast_to_source_code(return_vars))).body[0].value + if return_name_ids: - _, cond_node = create_assign_node(return_name_ids, cond_layer) + _, cond_node = create_assign_node(return_name_ids, convert_ifelse_layer) else: # No variables can be returned if no assign statement in if.body. - cond_node = gast.Expr(value=cond_layer) + cond_node = gast.Expr(value=convert_ifelse_layer) return cond_node diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index a21a5af4552da9179e0f54d1c3c8e87bdd52c3bb..cad70f64c466f3ed61d65442230b246eacfd2a74 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -14,16 +14,36 @@ from __future__ import print_function +import copy import gast + from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api -from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node -from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType +from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer +from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor +def create_convert_shape_node(var_shape_node): + assert isinstance(var_shape_node, (gast.Attribute, gast.Subscript)) + + convert_var_shape_func = "fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape" + + if isinstance(var_shape_node, gast.Attribute): + api_shape_node = gast.Call( + func=gast.parse(convert_var_shape_func).body[0].value, + args=[var_shape_node.value], + keywords=[]) + return api_shape_node + + if isinstance(var_shape_node, gast.Subscript): + result_node = copy.deepcopy(var_shape_node) + result_node.value = create_convert_shape_node(result_node.value) + return result_node + + class TensorShapeTransformer(gast.NodeTransformer): """ - This class transforms Tensor.shape used in Paddle Apis and control flow conditions into Static Graph Ast. + This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast. """ def __init__(self, wrapper_root): @@ -32,7 +52,7 @@ class TensorShapeTransformer(gast.NodeTransformer): ), "Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer." self.wrapper_root = wrapper_root self.root = wrapper_root.node - self.name_to_tensor_shape = {} + self.name_to_var_shape = {} self.static_analysis_visitor = StaticAnalysisVisitor(self.root) self.node_to_wrapper_map = self.static_analysis_visitor.get_node_to_wrapper_map( @@ -42,58 +62,60 @@ class TensorShapeTransformer(gast.NodeTransformer): self.scope_var_type_dict = var_env.get_scope_var_type() def transform(self): + SplitAssignTransformer(self.root).transform() self.visit(self.root) def visit_Assign(self, node): - if self._update_name_to_tensor_shape(node): + if self._update_name_to_var_shape(node): return node self.generic_visit(node) return node def visit_Attribute(self, node): if self._used_by_paddle_api(node): - if self.is_tensor_shape(node): - return create_api_shape_node(node) + if self.is_var_shape(node): + return create_convert_shape_node(node) return node def visit_Name(self, node): - if node.id in self.name_to_tensor_shape: + if node.id in self.name_to_var_shape: if self._used_by_paddle_api(node): - tensor_shape_node = self.name_to_tensor_shape[node.id] - return create_api_shape_node(tensor_shape_node) + var_shape_node = self.name_to_var_shape[node.id] + return create_convert_shape_node(var_shape_node) return node def visit_Call(self, node): assert isinstance(node, gast.Call) if is_paddle_api(node): - # Visit gast.Attribute and gast.Name to replace tensor.shape if necessary. + # Visit gast.Attribute and gast.Name to replace var.shape if necessary. self.generic_visit(node) return node def visit_If(self, node): - # Call generic_visit first to transform Tensor.shape that is used in Paddle Api. + # Call generic_visit first to transform var.shape that is used in Paddle Api. self.generic_visit(node) cond = node.test - self._transform_tensor_shape_if_necessary(cond) + self._transform_var_shape_if_necessary(cond) + return node def visit_While(self, node): self.generic_visit(node) cond = node.test - self._transform_tensor_shape_if_necessary(cond) + self._transform_var_shape_if_necessary(cond) return node def visit_For(self, node): self.generic_visit(node) iter = node.iter - self._transform_tensor_shape_if_necessary(iter) + self._transform_var_shape_if_necessary(iter) - # If tensor.shape is a gast.Name and it is used in range function, transform it - self._transform_tensor_shape_in_range(node) + # If var.shape is a gast.Name and it is used in range function, transform it + self._transform_var_shape_in_range(node) return node - def _transform_tensor_shape_in_range(self, node): + def _transform_var_shape_in_range(self, node): assert isinstance(node, gast.For) if not isinstance(node.iter, gast.Call): return False @@ -103,31 +125,33 @@ class TensorShapeTransformer(gast.NodeTransformer): return False args = node.iter.args for idx, arg in enumerate(args): - if isinstance(arg, - gast.Name) and arg.id in self.name_to_tensor_shape: - args[idx] = create_api_shape_node(self.name_to_tensor_shape[ + if isinstance(arg, gast.Name) and arg.id in self.name_to_var_shape: + args[idx] = create_convert_shape_node(self.name_to_var_shape[ arg.id]) return True - def _transform_tensor_shape_if_necessary(self, cond): + def _transform_var_shape_if_necessary(self, cond): + need_transformed = False for child_node in gast.walk(cond): - tensor_shape_node = None + var_shape_node = None if isinstance(child_node, (gast.Attribute)): - if self.is_tensor_shape(child_node): - tensor_shape_node = child_node + if self.is_var_shape(child_node): + var_shape_node = child_node elif isinstance(child_node, (gast.Name)): - if child_node.id in self.name_to_tensor_shape: - tensor_shape_node = self.name_to_tensor_shape[child_node.id] + if child_node.id in self.name_to_var_shape: + var_shape_node = self.name_to_var_shape[child_node.id] - if tensor_shape_node: + if var_shape_node: + need_transformed = True wrapper_node = self.node_to_wrapper_map.get(child_node) parent_node = wrapper_node.parent.node for field, value in gast.iter_fields(parent_node): if child_node is value: setattr(parent_node, field, - create_api_shape_node(tensor_shape_node)) + create_convert_shape_node(var_shape_node)) break + return need_transformed def _used_by_paddle_api(self, node): assert isinstance(node, (gast.Attribute, gast.Name)) @@ -146,11 +170,12 @@ class TensorShapeTransformer(gast.NodeTransformer): return False - def is_tensor_shape(self, node): + def is_var_shape(self, node): """ - Return True if node is like `x.shape` and x is Tensor, return False otherwise. + Return True if node is like `x.shape`, return False otherwise. """ assert isinstance(node, gast.Attribute) + if node.attr != 'shape': return False @@ -159,26 +184,13 @@ class TensorShapeTransformer(gast.NodeTransformer): except AttributeError: return False - if value_id in self.name_to_tensor_shape: + if value_id in self.name_to_var_shape: return True - # TODO: `value_id` may be not in scope_var_type_dict if `value_id` is the arg of decorated function - # Need a better way to confirm whether `value_id` is a Tensor. - try: - var_type_set = self.scope_var_type_dict[value_id] - except KeyError: - return False - - if NodeVarType.NUMPY_NDARRAY in var_type_set: - return False - if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set: - return False - return True - def _update_name_to_tensor_shape(self, node): + def _update_name_to_var_shape(self, node): assert isinstance(node, gast.Assign) - # TODO: Consider node has more than one target. eg: x, y = a, Tensor.shape[1] target_node = node.targets[0] try: target_id = target_node.id @@ -187,17 +199,17 @@ class TensorShapeTransformer(gast.NodeTransformer): value_node = node.value if isinstance(value_node, gast.Name): - if value_node.id in self.name_to_tensor_shape: - self.name_to_tensor_shape[ - target_id] = self.name_to_tensor_shape[value_node.id] + if value_node.id in self.name_to_var_shape: + self.name_to_var_shape[target_id] = self.name_to_var_shape[ + value_node.id] return True if isinstance(value_node, gast.Attribute): - if self.is_tensor_shape(value_node): # eg: x.shape - self.name_to_tensor_shape[target_id] = value_node + if self.is_var_shape(value_node): # eg: x.shape + self.name_to_var_shape[target_id] = value_node return True if isinstance(value_node, gast.Subscript): if isinstance(value_node.value, gast.Attribute): - if self.is_tensor_shape(value_node.value): # eg: x.shape[0] - self.name_to_tensor_shape[target_id] = value_node + if self.is_var_shape(value_node.value): # eg: x.shape[0] + self.name_to_var_shape[target_id] = value_node return True return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py index d3025744b35e990e4f0c900858b458bc1c6568ef..bc2851b630c1a4ff8af8802a0b984ab517b77387 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/variable_trans_func.py @@ -117,11 +117,7 @@ def to_static_variable(x): if isinstance(x, float): return fill_constant(shape=[1], dtype='float64', value=x) - if six.PY2: - if isinstance(x, (int, long)): - return fill_constant(shape=[1], dtype='int64', value=x) - else: - if isinstance(x, int): - return fill_constant(shape=[1], dtype='int64', value=x) + if isinstance(x, six.integer_types): + return fill_constant(shape=[1], dtype='int64', value=x) return x diff --git a/python/paddle/fluid/layers/utils.py b/python/paddle/fluid/layers/utils.py index e2f67e20b784823b293497f655c7d00378d54277..df00a0f561ffcf1db214d52eef3ef4509bb6b074 100644 --- a/python/paddle/fluid/layers/utils.py +++ b/python/paddle/fluid/layers/utils.py @@ -112,9 +112,9 @@ def _yield_flat_nest(nest): def flatten(nest): """ - :alias_main: paddle.flatten - :alias: paddle.flatten,paddle.tensor.flatten,paddle.tensor.manipulation.flatten - :old_api: paddle.fluid.layers.flatten + :alias_main: paddle.flatten + :alias: paddle.flatten,paddle.tensor.flatten,paddle.tensor.manipulation.flatten + :old_api: paddle.fluid.layers.flatten Traverse all entries in the nested structure and put them into an list. """ @@ -341,7 +341,7 @@ def _convert_to_tensor_list(old_list, dtype="int32"): ele.stop_gradient = True new_list_tensor.append(ele) else: - assert (isinstance(ele, int)) + assert isinstance(ele, six.integer_types) temp_out = fill_constant([1], dtype, ele, force_cpu=True) new_list_tensor.append(temp_out) return new_list_tensor diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 2ca62d9332c83ba7eb99d714fdad0b195928ab65..08d832b64a213fc217ae0baaba734ebfb58e9ceb 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -42,7 +42,10 @@ def dyfunc_with_if_else(x_v, label=None): def dyfunc_with_if_else2(x, col=100): row = 0 if abs(col) > x.shape[-1]: - col = -1 + # TODO: Don't support return non-Tensor in Tensor-dependent `if` stament currently. + # `x` is Tensor, `col` is not Tensor, and `col` is the return value of `true_fn` after transformed. + # col = -1 + col = fluid.layers.fill_constant(shape=[1], value=-1, dtype="int64") if fluid.layers.reduce_mean(x).numpy()[0] > x.numpy()[row][col]: y = fluid.layers.relu(x) else: @@ -101,7 +104,12 @@ def nested_if_else(x_v): feat_size = x_v.shape[-1] bias = fluid.layers.fill_constant([feat_size], dtype='float32', value=1) if x_v.shape[0] != batch_size: - batch_size = x_v.shape[0] + # TODO: Don't support return non-Tensor in Tensor-dependent `if` stament currently. + # `x_v.shape[0]` is not Tensor, and `batch_size` is the return value of `true_fn` after transformed. + # col = -1 + # batch_size = x_v.shape[0] + batch_size = fluid.layers.shape(x_v)[0] + # if tensor.shape is [1], now support to compare with numpy. if fluid.layers.mean(x_v).numpy() < 0: y = x_v + bias diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index 858fa7591127731bb73cd4aab397f90fe52ff26c..3cf8f5b71d7760e9cfea11049f02e07ee31a8087 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -72,10 +72,8 @@ class StaticCode1(): return x_v x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - fluid.layers.mean(x_v)[0] > 5, - lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(x_v), - lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(x_v) - ) + fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ), + (x_v, ), (x_v, )) def true_fn_1(label, x_v): loss = fluid.layers.cross_entropy(x_v, label) @@ -86,9 +84,7 @@ class StaticCode1(): return fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - label is not None, - lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_1)(label, x_v), - lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_1)()) + label is not None, true_fn_1, false_fn_1, (label, x_v), (), ()) return x_v @@ -104,10 +100,8 @@ class StaticCode2(): return x_v x_v = fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - fluid.layers.mean(x_v)[0] > 5, - lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_2)(x_v), - lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_2)(x_v) - ) + fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ), + (x_v, ), (x_v, )) def true_fn_3(label, x_v): loss = fluid.layers.cross_entropy(x_v, label) @@ -118,9 +112,7 @@ class StaticCode2(): return fluid.dygraph.dygraph_to_static.convert_operators.convert_ifelse( - label is not None, - lambda: fluid.dygraph.dygraph_to_static.convert_call(true_fn_3)(label, x_v), - lambda: fluid.dygraph.dygraph_to_static.convert_call(false_fn_3)()) + label is not None, true_fn_3, false_fn_3, (label, x_v), (), ()) return x_v @@ -138,7 +130,6 @@ class TestDygraphToStaticCode(unittest.TestCase): self.maxDiff = None def test_decorator(self): - x_v = None program_translator = ProgramTranslator() code = program_translator.get_code(dyfunc_with_if_else) answer = get_source_code(StaticCode1.dyfunc_with_if_else) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index 6acc4bcd1c0471b3126109459511d49176df7c80..46d2b220414c4ec975824271272e0a92d3c057c4 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -36,7 +36,7 @@ def dyfunc_tensor_shape_2(x): def dyfunc_tensor_shape_3(x): - # Don't transform y.shape because y is numpy.ndarray + # Transform y.shape but run y.shape actually because y is not Tensor x = fluid.dygraph.to_variable(x) y = numpy.ones(5) res = fluid.layers.reshape(x, shape=y.shape) @@ -51,7 +51,8 @@ def dyfunc_tensor_shape_4(x): def dyfunc_tensor_shape_5(x): # `res = fluid.layers.reshape(x, shape=(-1, s))` to - # `res = fluid.layers.reshape(x, shape=(-1, fluid.layers.shape(x)[0]))` + # `res = fluid.layers.reshape(x, shape=(-1, + # fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]))` x = fluid.dygraph.to_variable(x) s = x.shape[0] res = fluid.layers.reshape(x, shape=(-1, s)) @@ -63,7 +64,8 @@ def dyfunc_with_if_1(x): res = fluid.layers.reshape(x, [-1, 1]) x_shape_0 = x.shape[0] if x_shape_0 < 1: - # `res.shape[0] > 1` is transformed into `if fluid.layers.shape(res)[0] > 1` + # `res.shape[0]` is transformed into + # `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(res)[0]` if res.shape[0] > 1: res = fluid.layers.fill_constant( value=2, shape=x.shape, dtype="int32") @@ -75,7 +77,7 @@ def dyfunc_with_if_1(x): def dyfunc_with_if_2(x): x = fluid.dygraph.to_variable(x) - # `len(x.shape)` will not be transformed. + # `len(x.shape)` will not be transformed because x.shape is not used by Paddle api. if len(x.shape) < 1: res = x else: @@ -87,7 +89,7 @@ def dyfunc_with_if_2(x): def dyfunc_with_for_1(x): x = fluid.dygraph.to_variable(x) res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") - # `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]` + # `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` for i in range(x.shape[0]): res += 1 return res @@ -98,7 +100,7 @@ def dyfunc_with_for_2(x): x_shape_0 = x.shape[0] res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") - # `x_shape_0` is transformed into `fluid.layers.shape(x)[0]` + # `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` for i in range(x_shape_0): res += 1 return res @@ -122,7 +124,7 @@ def dyfunc_with_for_3(x): def dyfunc_with_while_1(x): x = fluid.dygraph.to_variable(x) res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") - # `x.shape[0]` is transformed into `fluid.layers.shape(x)[0]` + # `x.shape[0]` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` i = 1 while i < x.shape[0]: res += 1 @@ -135,19 +137,14 @@ def dyfunc_with_while_2(x): x_shape_0 = x.shape[0] res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") i = 1 - # `x_shape_0` is transformed into `fluid.layers.shape(x)[0]` - # TODO(liym27): If `x_shape_0` is at right like `while i < x_shape_0`, it will not be transformed. - # Fix this bug next PR. - while x_shape_0 > i: + # `x_shape_0` is transformed into `fluid.dygraph.dygraph_to_static.convert_operators.convert_var_shape(x)[0]` + while i < x_shape_0: res += 1 i = i + 2 return res def dyfunc_with_while_3(x): - # TODO(liym27): - # It will fail to run because the same problem as `dyfunc_with_for_3`. - # After the AST tranformation of for loop is improved, add TestTensorShapeInWhile3. x = fluid.dygraph.to_variable(x) x_shape = x.shape res = fluid.layers.fill_constant(value=0, shape=[1], dtype="int32") @@ -160,6 +157,19 @@ def dyfunc_with_while_3(x): return res +def dyfunc_with_while_4(x): + x = fluid.dygraph.to_variable(x) + y = numpy.ones(5) + y_shape_0 = y.shape[0] + i = 1 + + # Transform y_shape_0 but run y.shape[0] actually because y is not Tensor + while y_shape_0 > i: + x += 1 + i += 1 + return x + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -183,7 +193,7 @@ class TestTensorShapeBasic(unittest.TestCase): return self._run(to_static=False) def get_static_output(self): - return self._run(to_static=False) + return self._run(to_static=True) def test_transformed_static_result(self): static_res = self.get_static_output() @@ -247,5 +257,15 @@ class TestTensorShapeInWhile2(TestTensorShapeBasic): self.dygraph_func = dyfunc_with_while_2 +class TestTensorShapeInWhile3(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_while_3 + + +class TestTensorShapeInWhile4(TestTensorShapeBasic): + def init_test_func(self): + self.dygraph_func = dyfunc_with_while_4 + + if __name__ == '__main__': unittest.main()