diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index f8346a49d2a9b4d6c73accf63f002581ba70fe54..25c552c9a889702bb48d8f10fa79d56d5a0c43d5 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -33,7 +33,7 @@ __all__ = [ 'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than', 'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal', 'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN', - 'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'switch_case' + 'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'case', 'switch_case' ] @@ -1798,6 +1798,130 @@ def cond(pred, true_fn=None, false_fn=None, name=None): return merged_output +def _error_message(what, arg_name, op_name, right_value, error_value): + error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \ + "{right_value}, but received: {error_value}.".format( + what=what, + arg_name=arg_name, + op_name=op_name, + right_value=right_value, + error_value=error_value) + + return error_message + + +def case(pred_fn_pairs, default=None, name=None): + ''' + This operator works like an if-elif-elif-else chain. + + Args: + pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor with shape [1], ``fn`` is a callable. All callables return the same structure of Tensors. + default(callable, optional): Callable that returns a structure of Tensors. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable|list(Variable): Tensors returned by the callable from the first pair whose pred is True, + or Tensors returned by ``default`` if no pred in ``pred_fn_pairs`` is True and ``default`` is not None, + or Tensors returned by the last callable in ``pred_fn_pairs`` if no pred in ``pred_fn_pairs`` is True and ``default`` is None. + + Raises: + TypeError: If the type of ``pred_fn_pairs`` is not list or tuple. + TypeError: If the type of elements in ``pred_fn_pairs`` is not tuple. + TypeError: If the size of tuples in ``pred_fn_pairs`` is not 2. + TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not Variable. + TypeError: If the second element of 2-tuple in ``pred_fn_pairs`` is not callable. + TypeError: If ``default`` is not None but it is not callable. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + def fn_1(): + return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[3], dtype='int32', value=3) + + main_program = fluid.default_startup_program() + startup_program = fluid.default_main_program() + with program_guard(main_program, startup_program): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.3) + y = layers.fill_constant(shape=[1], dtype='float32', value=0.1) + z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) + + pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3 + pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1 + pred_3 = layers.equal(x, y) # false: 0.3 == 0.1 + + # Call fn_1 because pred_1 is True + out_1 = layers.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3) + + # Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called. + # because fn_3 is the last callable in pred_fn_pairs. + out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]) + + exe = fluid.Executor(fluid.CPUPlace()) + res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2]) + print(res_1) # [[1. 1.]] + print(res_2) # [3 3 3] + ''' + helper = LayerHelper('case', **locals()) + + def _case_check_args(pred_fn_pairs, default): + ''' + Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default. + ''' + if not isinstance(pred_fn_pairs, (list, tuple)): + raise TypeError( + _error_message("The type", "pred_fn_pairs", "case", + "list or tuple", type(pred_fn_pairs))) + + for pred_fn in pred_fn_pairs: + if not isinstance(pred_fn, tuple): + raise TypeError( + _error_message("The elements' type", "pred_fn_pairs", + "case", "tuple", type(pred_fn))) + if len(pred_fn) != 2: + raise TypeError( + _error_message("The tuple's size", "pred_fn_pairs", "case", + "2", str(len(pred_fn)) + "-tuple")) + pred, fn = pred_fn + + if not isinstance(pred, Variable): + raise TypeError( + _error_message("The pred's type", "pred_fn_pairs", "case", + "boolean Variable", type(pred))) + + if not callable(fn): + raise TypeError( + "The fn for {} of pred_fn_pairs in Op(case) must" + " be callable.".format(pred.name)) + + if default is None: + default_index = len(pred_fn_pairs) - 1 # pick the last one + default = pred_fn_pairs[default_index][1] + pred_fn_pairs = pred_fn_pairs[:default_index] + elif not callable(default): + raise TypeError("The default in Op(case) must be callable.") + + return pred_fn_pairs, default + + pred_fn_pairs, default = _case_check_args(pred_fn_pairs, default) + + false_fn = default + for pred, true_fn in reversed(pred_fn_pairs): + false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn) + + final_fn = false_fn + + return final_fn() + + class Switch(object): """ @@ -2786,17 +2910,6 @@ class DynamicRNN(object): method)) -def _error_message(what, arg_name, op_name, right_value, error_value): - error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \ - "{right_value}, but received: {error_value}.".format( - what=what, - arg_name=arg_name, - op_name=op_name, - right_value=right_value, - error_value=error_value) - return error_message - - def switch_case(branch_index, branch_fns, default=None, name=None): ''' This operator is like a C++ switch/case statement. diff --git a/python/paddle/fluid/tests/unittests/test_case.py b/python/paddle/fluid/tests/unittests/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..fa73a5ec62ffed7094b743622d0dede9962e897a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_case.py @@ -0,0 +1,227 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest + +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.layers as layers +from paddle.fluid.framework import Program, program_guard +from functools import partial + + +class TestAPICase(unittest.TestCase): + def test_return_single_var(self): + def fn_1(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[4, 3], dtype='int32', value=3) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.3) + y = layers.fill_constant(shape=[1], dtype='float32', value=0.1) + z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) + pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1 + pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3 + + # call fn_1 + out_0 = layers.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3) + + # call fn_2 + out_1 = layers.case( + pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3) + + # call default fn_3 + out_2 = layers.case( + pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3) + + # no default, call fn_2 + out_3 = layers.case(pred_fn_pairs=[(pred_1, fn_2)]) + + # no default, call fn_2. but pred_2 is false + out_4 = layers.case(pred_fn_pairs=[(pred_2, fn_2)]) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + res = exe.run(main_program, + fetch_list=[out_0, out_1, out_2, out_3, out_4]) + + self.assertTrue(np.allclose(res[0], 1)) + self.assertTrue(np.allclose(res[1], 2)) + self.assertTrue(np.allclose(res[2], 3)) + self.assertTrue(np.allclose(res[3], 2)) + self.assertTrue(np.allclose(res[4], 2)) + + def test_return_var_tuple(self): + def fn_1(): + return layers.fill_constant( + shape=[1, 2], dtype='int32', value=1), layers.fill_constant( + shape=[2, 3], dtype='float32', value=2) + + def fn_2(): + return layers.fill_constant( + shape=[3, 4], dtype='int32', value=3), layers.fill_constant( + shape=[4, 5], dtype='float32', value=4) + + def fn_3(): + return layers.fill_constant( + shape=[5], dtype='int32', value=5), layers.fill_constant( + shape=[5, 6], dtype='float32', value=6) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = layers.fill_constant(shape=[1], dtype='float32', value=1) + y = layers.fill_constant(shape=[1], dtype='float32', value=1) + z = layers.fill_constant(shape=[1], dtype='float32', value=3) + + pred_1 = layers.equal(x, y) # true + pred_2 = layers.equal(x, z) # false + + out = layers.case(((pred_1, fn_1), (pred_2, fn_2)), fn_3) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + ret = exe.run(main_program, fetch_list=out) + + self.assertTrue( + np.allclose(np.asarray(ret[0]), np.full((1, 2), 1, np.int32))) + self.assertTrue( + np.allclose( + np.asarray(ret[1]), np.full((2, 3), 2, np.float32))) + + +class TestAPICase_Nested(unittest.TestCase): + def test_nested_case(self): + def fn_1(x=1): + var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5) + var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6) + out = layers.case(pred_fn_pairs=[(var_5 < var_6, partial( + layers.fill_constant, shape=[1], dtype='int32', value=x)), + (var_5 == var_6, partial( + layers.fill_constant, + shape=[2], + dtype='int32', + value=x))]) + return out + + def fn_2(x=2): + var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5) + var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6) + out = layers.case(pred_fn_pairs=[(var_5 < var_6, partial( + fn_1, x=x)), (var_5 == var_6, partial( + layers.fill_constant, shape=[2], dtype='int32', value=x))]) + return out + + def fn_3(): + var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5) + var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6) + out = layers.case(pred_fn_pairs=[(var_5 < var_6, partial( + fn_2, x=3)), (var_5 == var_6, partial( + layers.fill_constant, shape=[2], dtype='int32', value=7))]) + return out + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.3) + y = layers.fill_constant(shape=[1], dtype='float32', value=0.1) + z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) + pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1 + pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3 + + out_1 = layers.case( + pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3) + + out_2 = layers.case( + pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3) + + out_3 = layers.case( + pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + res = exe.run(main_program, fetch_list=[out_1, out_2, out_3]) + + self.assertTrue(np.allclose(res[0], 1)) + self.assertTrue(np.allclose(res[1], 2)) + self.assertTrue(np.allclose(res[2], 3)) + + +class TestAPICase_Error(unittest.TestCase): + def test_error(self): + def fn_1(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=1) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + x = layers.fill_constant(shape=[1], dtype='float32', value=0.23) + z = layers.fill_constant(shape=[1], dtype='float32', value=0.2) + pred_1 = layers.less_than(z, x) # true + + # The type of 'pred_fn_pairs' in case must be list or tuple + def type_error_pred_fn_pairs(): + layers.case(pred_fn_pairs=1, default=fn_1) + + self.assertRaises(TypeError, type_error_pred_fn_pairs) + + # The elements' type of 'pred_fn_pairs' in Op(case) must be tuple + def type_error_pred_fn_1(): + layers.case(pred_fn_pairs=[1], default=fn_1) + + self.assertRaises(TypeError, type_error_pred_fn_1) + + # The tuple's size of 'pred_fn_pairs' in Op(case) must be 2 + def type_error_pred_fn_2(): + layers.case(pred_fn_pairs=[(1, 2, 3)], default=fn_1) + + self.assertRaises(TypeError, type_error_pred_fn_2) + + # The pred's type of 'pred_fn_pairs' in Op(case) must be bool Variable + def type_error_pred(): + layers.case(pred_fn_pairs=[(1, fn_1)], default=fn_1) + + self.assertRaises(TypeError, type_error_pred) + + # The function of pred_fn_pairs in case must be callable + def type_error_fn(): + layers.case(pred_fn_pairs=[(pred_1, 2)], default=fn_1) + + self.assertRaises(TypeError, type_error_fn) + + # The default in Op(case) must be callable + def type_error_default(): + layers.case(pred_fn_pairs=[(pred_1, fn_1)], default=fn_1()) + + self.assertRaises(TypeError, type_error_default) + + +if __name__ == '__main__': + unittest.main()