Commit 92475282 authored by liym27, committed by Aurelius84

Control flow API: switch_case (#21103)

* add API switch_case. test=develop

add nested switch_case test

* modify code according to reviews:
1. Attr(branch_index) supports 'uint8' and 'int64' besides 'int32'.
2. Remove useless code.
test=develop

* replace fluid.layers.data with fluid.data and polish API document. test=develop
Parent 65f70525
......@@ -26,13 +26,14 @@ from .utils import assert_same_structure, flatten, map_structure
import numpy
import warnings
import six
from functools import reduce
from functools import reduce, partial
from ..data_feeder import convert_dtype
__all__ = [
'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than',
'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal',
'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN',
'reorder_lod_tensor_by_rank', 'Print', 'is_empty'
'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'switch_case'
]
......@@ -2785,6 +2786,167 @@ class DynamicRNN(object):
method))
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value)
return error_message
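# For example, _error_message("The type", "branch_index", "switch_case",
# "Variable", type(1)) produces a message like:
# "The type of 'branch_index' in Op(switch_case) must be Variable, but received: <class 'int'>."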
def switch_case(branch_index, branch_fns, default=None, name=None):
'''
This operator is like a C++ switch/case statement.
Args:
branch_index(Variable): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
        branch_fns(dict|list|tuple): If it's a list or tuple, its elements can be (int, callable) pairs, or plain callables whose positions in the list/tuple are used as their indices. If it's a dict, its keys are Python integers and its values are callables. All callables must return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable|list(Variable): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``,
or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``.
Raises:
TypeError: If the type of ``branch_index`` is not Variable.
TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``.
TypeError: If the type of ``branch_fns`` is not dict, list or tuple.
        TypeError: If the elements of ``branch_fns`` are not 2-tuples.
        TypeError: If the first element of a 2-tuple in ``branch_fns`` is not an integer.
        ValueError: If the first elements of the 2-tuples in ``branch_fns`` are not unique.
        TypeError: If the second element of a 2-tuple in ``branch_fns`` is not callable.
TypeError: If ``default`` is not None but it is not callable.
Examples:
.. code-block:: python
            import paddle.fluid as fluid
            import paddle.fluid.layers as layers
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
            main_program = fluid.default_main_program()
            startup_program = fluid.default_startup_program()
            with fluid.program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3)
out_2 = layers.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
# Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = layers.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
exe = fluid.Executor(fluid.CPUPlace())
res_1, res_2, res_3 = exe.run(main_program,
fetch_list=[out_1, out_2, out_3])
print(res_1) # [[1. 1.]]
print(res_2) # [[2 2] [2 2]]
print(res_3) # [3 3 3]
'''
helper = LayerHelper('switch_case', **locals())
def _check_args(branch_index, branch_fns, default):
if not isinstance(branch_index, Variable):
raise TypeError(
_error_message("The type", "branch_index", "switch_case",
"Variable", type(branch_index)))
if convert_dtype(branch_index.dtype) not in ["uint8", "int32", "int64"]:
raise TypeError(
_error_message("The data type", "branch_index", "switch_case",
"uint8, int32 or int64",
convert_dtype(branch_index.dtype)))
if convert_dtype(branch_index.dtype) != "int64":
branch_index = cast(branch_index, "int64")
if not isinstance(branch_fns, (list, tuple, dict)):
raise TypeError(
_error_message("The type", "branch_fns", "switch_case",
"dict, tuple or list", type(branch_fns)))
branch_fns = branch_fns.items() if isinstance(branch_fns,
dict) else branch_fns
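        # If 'branch_fns' is a list/tuple of bare callables with no explicit
        # keys, their positions 0, 1, 2, ... are used as the branch indices.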
branch_fns = list(enumerate(branch_fns)) if all(
callable(fn) for fn in branch_fns) else branch_fns
keys_of_fns = []
for index_fn_pair in branch_fns:
if not isinstance(index_fn_pair, tuple):
raise TypeError(
_error_message("The elements' type", "branch_fns",
"switch_case", "tuple", type(branch_fns)))
if len(index_fn_pair) != 2:
raise TypeError(
_error_message("The tuple's size", "branch_fns",
"switch_case", "2",
str(len(index_fn_pair)) + "-tuple"))
key, fn = index_fn_pair
if not isinstance(key, int):
raise TypeError(
_error_message("The key's type", "branch_fns",
"switch_case", "int", type(key)))
if key in keys_of_fns:
raise ValueError(
"The key in 'branch_fns' must be unique, but '{}' appears more than once.".
format(key))
else:
keys_of_fns.append(key)
if not callable(fn):
raise TypeError(
_error_message("The type of function for key {}".format(
key), "branch_fns", "switch_case", "callable", type(
fn)))
if default is None:
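            # No default branch was given: the fn with the largest key serves as
            # the fall-through branch and is dropped from the (key, fn) pairs.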
default = sorted(branch_fns)[-1][1]
branch_fns = sorted(branch_fns)[:-1]
elif not callable(default):
raise TypeError("The default in Op(case) must be callable.")
pred_fn_pairs = []
for index, fn in branch_fns:
new_index = fill_constant(shape=[1], dtype="int64", value=index)
pred = equal(branch_index, new_index)
pred_fn_pairs.append((pred, fn))
return pred_fn_pairs, default
pred_fn_pairs, default = _check_args(branch_index, branch_fns, default)
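    # Fold the (pred, fn) pairs into nested cond ops: starting from 'default',
    # each pair wraps the previous fallback, yielding
    # cond(pred_n, fn_n, ... cond(pred_1, fn_1, default) ...).
    # The preds are mutually exclusive equality checks against 'branch_index',
    # so at most one branch is executed.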
false_fn = default
for pred, true_fn in pred_fn_pairs:
false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)
final_fn = false_fn
return final_fn()
@templatedoc()
def reorder_lod_tensor_by_rank(x, rank_table):
"""
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from functools import partial
class TestAPISwitchCase(unittest.TestCase):
def test_return_single_var(self):
def fn_1():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)
def fn_2():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
# call fn_1
out_0 = layers.switch_case(
branch_index=index_1, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
# call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
out_1 = layers.switch_case(
branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3))
# call default fn_3
out_2 = layers.switch_case(
branch_index=index_5,
branch_fns=((1, fn_1), (2, fn_2)),
default=fn_3)
# no default, call fn_2
out_3 = layers.switch_case(
branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)])
            # no default and no matched index: branch_index is 5, so fn_2 (the fn with the max index 3) is called
out_4 = layers.switch_case(
branch_index=index_5,
branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program,
fetch_list=[out_0, out_1, out_2, out_3, out_4])
self.assertTrue(
np.allclose(res[0], 1),
"result is {} but answer is {}".format(res[0], 1))
self.assertTrue(
np.allclose(res[1], 2),
"result is {} but answer is {}".format(res[0], 2))
self.assertTrue(
np.allclose(res[2], 3),
"result is {} but answer is {}".format(res[0], 3))
self.assertTrue(
np.allclose(res[3], 2),
"result is {} but answer is {}".format(res[0], 2))
self.assertTrue(
np.allclose(res[4], 2),
"result is {} but answer is {}".format(res[0], 2))
def test_return_var_tuple(self):
def fn_1():
return layers.fill_constant(
shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
shape=[2, 3], dtype='float32', value=2)
def fn_2():
return layers.fill_constant(
shape=[3, 4], dtype='int32', value=3), layers.fill_constant(
shape=[4, 5], dtype='float32', value=4)
def fn_3():
return layers.fill_constant(
shape=[5], dtype='int32', value=5), layers.fill_constant(
shape=[5, 6], dtype='float32', value=6)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
self.assertTrue(
np.allclose(np.asarray(ret[0]), np.full((1, 2), 1, np.int32)))
self.assertTrue(
np.allclose(
np.asarray(ret[1]), np.full((2, 3), 2, np.float32)))
class TestAPISwitchCase_Nested(unittest.TestCase):
def test_nested_switch_case(self):
def fn_1(x=1):
out = layers.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=x),
branch_fns={
1: partial(
layers.fill_constant, shape=[1], dtype='int32',
value=1),
x: partial(
layers.fill_constant, shape=[2], dtype='int32', value=x)
})
return out
def fn_2(x=2):
out = layers.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=2),
branch_fns={
1: partial(
layers.fill_constant,
shape=[4, 3],
dtype='int32',
value=1),
2: partial(
fn_1, x=x)
})
return out
def fn_3():
out = layers.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=3),
branch_fns={
1: partial(
layers.fill_constant,
shape=[4, 3],
dtype='int32',
value=1),
3: partial(
fn_2, x=3)
})
return out
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8')
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)
out_1 = layers.switch_case(
branch_index=index_1, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
out_2 = layers.switch_case(
branch_index=index_2, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
out_3 = layers.switch_case(
branch_index=index_3, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program,
feed={"index_1": np.array(
[1], dtype="uint8")},
fetch_list=[out_1, out_2, out_3])
self.assertTrue(
np.allclose(res[0], 1),
"result is {} but answer is {}".format(res[0], 1))
self.assertTrue(
np.allclose(res[1], 2),
"result is {} but answer is {}".format(res[1], 2))
self.assertTrue(
np.allclose(res[2], 3),
"result is {} but answer is {}".format(res[2], 3))
# Test TypeError and ValueError raised by the switch_case API
class TestAPISwitchCase_Error(unittest.TestCase):
def test_error(self):
def fn_1():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)
def fn_2():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
key_float32 = layers.fill_constant(
shape=[1], dtype='float32', value=0.23)
key_int32 = layers.fill_constant(
shape=[1], dtype='int32', value=0.23)
# The type of 'branch_index' in Op(switch_case) must be Variable
def type_error_branch_index():
layers.switch_case(
branch_index=1, branch_fns=[(1, fn_1)], default=fn_3)
self.assertRaises(TypeError, type_error_branch_index)
# The data type of 'branch_index' in Op(switch_case) must be int32, int64 or uint8
def dtype_error_branch_index():
layers.switch_case(
branch_index=key_float32,
branch_fns=[(1, fn_1)],
default=fn_3)
self.assertRaises(TypeError, dtype_error_branch_index)
# The type of 'branch_fns' in Op(switch_case) must be list, tuple or dict
def type_error_branch_fns():
layers.switch_case(
branch_index=key_int32, branch_fns=1, default=fn_3)
self.assertRaises(TypeError, type_error_branch_fns)
# The elements' type of 'branch_fns' in Op(switch_case) must be tuple
def type_error_index_fn_pair_1():
layers.switch_case(
branch_index=key_int32, branch_fns=[1], default=fn_3)
self.assertRaises(TypeError, type_error_index_fn_pair_1)
# The tuple's size of 'branch_fns' in Op(switch_case) must be 2
def type_error_index_fn_pair_2():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(1, 2, 3)],
default=fn_3)
self.assertRaises(TypeError, type_error_index_fn_pair_2)
# The key's type of 'branch_fns' in Op(switch_case) must be int
def type_error_key():
layers.switch_case(
branch_index=key_int32, branch_fns=[(2.3, 2)], default=fn_3)
self.assertRaises(TypeError, type_error_key)
# The key in 'branch_fns' must be unique
def value_error_key():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(2, fn_1), (2, fn_2)],
default=fn_3)
self.assertRaises(ValueError, value_error_key)
# The type of function in 'branch_fns' must be callable
def type_error_fn():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(1, 1), (2, fn_2)],
default=fn_3)
self.assertRaises(TypeError, type_error_fn)
            # The default in Op(switch_case) must be callable
def type_error_default():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(1, fn_1), (2, fn_2)],
default=1)
self.assertRaises(TypeError, type_error_default)
if __name__ == '__main__':
unittest.main()