From 924752825212303ac9320c2ab066f33ed0056ab7 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Mon, 18 Nov 2019 14:10:49 +0800 Subject: [PATCH] Control flow API: switch_case (#21103) * add API switch_case. test=develop add Nest * modify code according to reviews: 1.Attr(branch_index) support 'uint8' and 'int64' besides 'int32'. 2.remove useless code. test=develop * replace fluid.layers.data with fluid.data and polish API document. test=develop --- python/paddle/fluid/layers/control_flow.py | 166 +++++++++- .../fluid/tests/unittests/test_switch_case.py | 309 ++++++++++++++++++ 2 files changed, 473 insertions(+), 2 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_switch_case.py diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index 552bfb4c8c6..f8346a49d2a 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -26,13 +26,14 @@ from .utils import assert_same_structure, flatten, map_structure import numpy import warnings import six -from functools import reduce +from functools import reduce, partial +from ..data_feeder import convert_dtype __all__ = [ 'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than', 'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal', 'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN', - 'reorder_lod_tensor_by_rank', 'Print', 'is_empty' + 'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'switch_case' ] @@ -2785,6 +2786,167 @@ class DynamicRNN(object): method)) +def _error_message(what, arg_name, op_name, right_value, error_value): + error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \ + "{right_value}, but received: {error_value}.".format( + what=what, + arg_name=arg_name, + op_name=op_name, + right_value=right_value, + error_value=error_value) + return error_message + + +def switch_case(branch_index, branch_fns, default=None, name=None): + ''' + This operator is like a C++ switch/case statement. + + Args: + branch_index(Variable): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. + branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors. + default(callable, optional): Callable that returns a structure of Tensors. + name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Variable|list(Variable): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``, + or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``, + or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``. + + Raises: + TypeError: If the type of ``branch_index`` is not Variable. + TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``. + TypeError: If the type of ``branch_fns`` is not dict, list or tuple. + TypeError: If the elements of ``branch_fns`` is not 2-tuple. + TypeError: If the first element of 2-tuple in ``branch_fns`` is not integer. + ValueError: If the first element of 2-tuple in ``branch_fns`` is not unique. + TypeError: If the second element of 2-tuple in ``branch_fns`` is not callable. + TypeError: If ``default`` is not None but it is not callable. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + def fn_1(): + return layers.fill_constant(shape=[1, 2], dtype='float32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[2, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[3], dtype='int32', value=3) + + main_program = fluid.default_startup_program() + startup_program = fluid.default_main_program() + with program_guard(main_program, startup_program): + index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) + index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + + out_1 = layers.switch_case( + branch_index=index_1, + branch_fns={1: fn_1, 2: fn_2}, + default=fn_3) + + out_2 = layers.switch_case( + branch_index=index_2, + branch_fns=[(1, fn_1), (2, fn_2)], + default=fn_3) + + # Argument default is None and no index matches. fn_3 will be called because of the max index 7. + out_3 = layers.switch_case( + branch_index=index_2, + branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)]) + + exe = fluid.Executor(fluid.CPUPlace()) + res_1, res_2, res_3 = exe.run(main_program, + fetch_list=[out_1, out_2, out_3]) + print(res_1) # [[1. 1.]] + print(res_2) # [[2 2] [2 2]] + print(res_3) # [3 3 3] + ''' + helper = LayerHelper('switch_case', **locals()) + + def _check_args(branch_index, branch_fns, default): + if not isinstance(branch_index, Variable): + raise TypeError( + _error_message("The type", "branch_index", "switch_case", + "Variable", type(branch_index))) + + if convert_dtype(branch_index.dtype) not in ["uint8", "int32", "int64"]: + raise TypeError( + _error_message("The data type", "branch_index", "switch_case", + "uint8, int32 or int64", + convert_dtype(branch_index.dtype))) + + if convert_dtype(branch_index.dtype) != "int64": + branch_index = cast(branch_index, "int64") + + if not isinstance(branch_fns, (list, tuple, dict)): + raise TypeError( + _error_message("The type", "branch_fns", "switch_case", + "dict, tuple or list", type(branch_fns))) + + branch_fns = branch_fns.items() if isinstance(branch_fns, + dict) else branch_fns + + branch_fns = list(enumerate(branch_fns)) if all( + callable(fn) for fn in branch_fns) else branch_fns + + keys_of_fns = [] + for index_fn_pair in branch_fns: + if not isinstance(index_fn_pair, tuple): + raise TypeError( + _error_message("The elements' type", "branch_fns", + "switch_case", "tuple", type(branch_fns))) + + if len(index_fn_pair) != 2: + raise TypeError( + _error_message("The tuple's size", "branch_fns", + "switch_case", "2", + str(len(index_fn_pair)) + "-tuple")) + + key, fn = index_fn_pair + + if not isinstance(key, int): + raise TypeError( + _error_message("The key's type", "branch_fns", + "switch_case", "int", type(key))) + + if key in keys_of_fns: + raise ValueError( + "The key in 'branch_fns' must be unique, but '{}' appears more than once.". + format(key)) + else: + keys_of_fns.append(key) + + if not callable(fn): + raise TypeError( + _error_message("The type of function for key {}".format( + key), "branch_fns", "switch_case", "callable", type( + fn))) + + if default is None: + default = sorted(branch_fns)[-1][1] + branch_fns = sorted(branch_fns)[:-1] + elif not callable(default): + raise TypeError("The default in Op(case) must be callable.") + + pred_fn_pairs = [] + for index, fn in branch_fns: + new_index = fill_constant(shape=[1], dtype="int64", value=index) + pred = equal(branch_index, new_index) + pred_fn_pairs.append((pred, fn)) + + return pred_fn_pairs, default + + pred_fn_pairs, default = _check_args(branch_index, branch_fns, default) + false_fn = default + for pred, true_fn in pred_fn_pairs: + false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn) + + final_fn = false_fn + return final_fn() + + @templatedoc() def reorder_lod_tensor_by_rank(x, rank_table): """ diff --git a/python/paddle/fluid/tests/unittests/test_switch_case.py b/python/paddle/fluid/tests/unittests/test_switch_case.py new file mode 100644 index 00000000000..598e415e5fb --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_switch_case.py @@ -0,0 +1,309 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import unittest + +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle.fluid.layers as layers +from paddle.fluid.framework import Program, program_guard +from functools import partial + + +class TestAPISwitchCase(unittest.TestCase): + def test_return_single_var(self): + def fn_1(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[4, 3], dtype='int32', value=3) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) + index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5) + + # call fn_1 + out_0 = layers.switch_case( + branch_index=index_1, branch_fns={1: fn_1, + 2: fn_2, + 3: fn_3}) + + # call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3} + out_1 = layers.switch_case( + branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)) + + # call default fn_3 + out_2 = layers.switch_case( + branch_index=index_5, + branch_fns=((1, fn_1), (2, fn_2)), + default=fn_3) + + # no default, call fn_2 + out_3 = layers.switch_case( + branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]) + + # no default, call fn_2 but branch_index is 5 + out_4 = layers.switch_case( + branch_index=index_5, + branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)]) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + res = exe.run(main_program, + fetch_list=[out_0, out_1, out_2, out_3, out_4]) + + self.assertTrue( + np.allclose(res[0], 1), + "result is {} but answer is {}".format(res[0], 1)) + self.assertTrue( + np.allclose(res[1], 2), + "result is {} but answer is {}".format(res[0], 2)) + self.assertTrue( + np.allclose(res[2], 3), + "result is {} but answer is {}".format(res[0], 3)) + self.assertTrue( + np.allclose(res[3], 2), + "result is {} but answer is {}".format(res[0], 2)) + self.assertTrue( + np.allclose(res[4], 2), + "result is {} but answer is {}".format(res[0], 2)) + + def test_return_var_tuple(self): + def fn_1(): + return layers.fill_constant( + shape=[1, 2], dtype='int32', value=1), layers.fill_constant( + shape=[2, 3], dtype='float32', value=2) + + def fn_2(): + return layers.fill_constant( + shape=[3, 4], dtype='int32', value=3), layers.fill_constant( + shape=[4, 5], dtype='float32', value=4) + + def fn_3(): + return layers.fill_constant( + shape=[5], dtype='int32', value=5), layers.fill_constant( + shape=[5, 6], dtype='float32', value=6) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1) + + out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + ret = exe.run(main_program, fetch_list=out) + + self.assertTrue( + np.allclose(np.asarray(ret[0]), np.full((1, 2), 1, np.int32))) + self.assertTrue( + np.allclose( + np.asarray(ret[1]), np.full((2, 3), 2, np.float32))) + + +class TestAPISwitchCase_Nested(unittest.TestCase): + def test_nested_switch_case(self): + def fn_1(x=1): + out = layers.switch_case( + branch_index=layers.fill_constant( + shape=[1], dtype='int32', value=x), + branch_fns={ + 1: partial( + layers.fill_constant, shape=[1], dtype='int32', + value=1), + x: partial( + layers.fill_constant, shape=[2], dtype='int32', value=x) + }) + return out + + def fn_2(x=2): + out = layers.switch_case( + branch_index=layers.fill_constant( + shape=[1], dtype='int32', value=2), + branch_fns={ + 1: partial( + layers.fill_constant, + shape=[4, 3], + dtype='int32', + value=1), + 2: partial( + fn_1, x=x) + }) + return out + + def fn_3(): + out = layers.switch_case( + branch_index=layers.fill_constant( + shape=[1], dtype='int32', value=3), + branch_fns={ + 1: partial( + layers.fill_constant, + shape=[4, 3], + dtype='int32', + value=1), + 3: partial( + fn_2, x=3) + }) + return out + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8') + index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2) + index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3) + + out_1 = layers.switch_case( + branch_index=index_1, branch_fns={1: fn_1, + 2: fn_2, + 3: fn_3}) + out_2 = layers.switch_case( + branch_index=index_2, branch_fns={1: fn_1, + 2: fn_2, + 3: fn_3}) + + out_3 = layers.switch_case( + branch_index=index_3, branch_fns={1: fn_1, + 2: fn_2, + 3: fn_3}) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + + res = exe.run(main_program, + feed={"index_1": np.array( + [1], dtype="uint8")}, + fetch_list=[out_1, out_2, out_3]) + + self.assertTrue( + np.allclose(res[0], 1), + "result is {} but answer is {}".format(res[0], 1)) + self.assertTrue( + np.allclose(res[1], 2), + "result is {} but answer is {}".format(res[1], 2)) + self.assertTrue( + np.allclose(res[2], 3), + "result is {} but answer is {}".format(res[2], 3)) + + +# test TypeError and ValueError of api switch_case +class TestAPISwitchCase_Error(unittest.TestCase): + def test_error(self): + def fn_1(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=1) + + def fn_2(): + return layers.fill_constant(shape=[4, 2], dtype='int32', value=2) + + def fn_3(): + return layers.fill_constant(shape=[4, 3], dtype='int32', value=3) + + main_program = Program() + startup_program = Program() + with program_guard(main_program, startup_program): + key_float32 = layers.fill_constant( + shape=[1], dtype='float32', value=0.23) + key_int32 = layers.fill_constant( + shape=[1], dtype='int32', value=0.23) + + # The type of 'branch_index' in Op(switch_case) must be Variable + def type_error_branch_index(): + layers.switch_case( + branch_index=1, branch_fns=[(1, fn_1)], default=fn_3) + + self.assertRaises(TypeError, type_error_branch_index) + + # The data type of 'branch_index' in Op(switch_case) must be int32, int64 or uint8 + def dtype_error_branch_index(): + layers.switch_case( + branch_index=key_float32, + branch_fns=[(1, fn_1)], + default=fn_3) + + self.assertRaises(TypeError, dtype_error_branch_index) + + # The type of 'branch_fns' in Op(switch_case) must be list, tuple or dict + def type_error_branch_fns(): + layers.switch_case( + branch_index=key_int32, branch_fns=1, default=fn_3) + + self.assertRaises(TypeError, type_error_branch_fns) + + # The elements' type of 'branch_fns' in Op(switch_case) must be tuple + def type_error_index_fn_pair_1(): + layers.switch_case( + branch_index=key_int32, branch_fns=[1], default=fn_3) + + self.assertRaises(TypeError, type_error_index_fn_pair_1) + + # The tuple's size of 'branch_fns' in Op(switch_case) must be 2 + def type_error_index_fn_pair_2(): + layers.switch_case( + branch_index=key_int32, + branch_fns=[(1, 2, 3)], + default=fn_3) + + self.assertRaises(TypeError, type_error_index_fn_pair_2) + + # The key's type of 'branch_fns' in Op(switch_case) must be int + def type_error_key(): + layers.switch_case( + branch_index=key_int32, branch_fns=[(2.3, 2)], default=fn_3) + + self.assertRaises(TypeError, type_error_key) + + # The key in 'branch_fns' must be unique + def value_error_key(): + layers.switch_case( + branch_index=key_int32, + branch_fns=[(2, fn_1), (2, fn_2)], + default=fn_3) + + self.assertRaises(ValueError, value_error_key) + + # The type of function in 'branch_fns' must be callable + def type_error_fn(): + layers.switch_case( + branch_index=key_int32, + branch_fns=[(1, 1), (2, fn_2)], + default=fn_3) + + self.assertRaises(TypeError, type_error_fn) + + # The default in Op(case) must be callable + def type_error_default(): + layers.switch_case( + branch_index=key_int32, + branch_fns=[(1, fn_1), (2, fn_2)], + default=1) + + self.assertRaises(TypeError, type_error_default) + + +if __name__ == '__main__': + unittest.main() -- GitLab