提交 b0fc8227 编写于 作者: L liym27 提交者: Aurelius84

Add control flow api: case (#21114)

* add control flow API: case. test=develop

* delete 'raise TypeError' in _error_message() and return a string. test=develop

* polish API document. test=develop
上级 6b1e1f0d
...@@ -33,7 +33,7 @@ __all__ = [ ...@@ -33,7 +33,7 @@ __all__ = [
'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than', 'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than',
'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal', 'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal',
'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN', 'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN',
'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'switch_case' 'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'case', 'switch_case'
] ]
...@@ -1798,6 +1798,130 @@ def cond(pred, true_fn=None, false_fn=None, name=None): ...@@ -1798,6 +1798,130 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
return merged_output return merged_output
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value)
return error_message
def case(pred_fn_pairs, default=None, name=None):
'''
This operator works like an if-elif-elif-else chain.
Args:
pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor with shape [1], ``fn`` is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable|list(Variable): Tensors returned by the callable from the first pair whose pred is True,
or Tensors returned by ``default`` if no pred in ``pred_fn_pairs`` is True and ``default`` is not None,
or Tensors returned by the last callable in ``pred_fn_pairs`` if no pred in ``pred_fn_pairs`` is True and ``default`` is None.
Raises:
TypeError: If the type of ``pred_fn_pairs`` is not list or tuple.
TypeError: If the type of elements in ``pred_fn_pairs`` is not tuple.
TypeError: If the size of tuples in ``pred_fn_pairs`` is not 2.
TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not Variable.
TypeError: If the second element of 2-tuple in ``pred_fn_pairs`` is not callable.
TypeError: If ``default`` is not None but it is not callable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
def fn_1():
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
def fn_2():
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[3], dtype='int32', value=3)
main_program = fluid.default_startup_program()
startup_program = fluid.default_main_program()
with program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
# Call fn_1 because pred_1 is True
out_1 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
# because fn_3 is the last callable in pred_fn_pairs.
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
exe = fluid.Executor(fluid.CPUPlace())
res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
print(res_1) # [[1. 1.]]
print(res_2) # [3 3 3]
'''
helper = LayerHelper('case', **locals())
def _case_check_args(pred_fn_pairs, default):
'''
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
'''
if not isinstance(pred_fn_pairs, (list, tuple)):
raise TypeError(
_error_message("The type", "pred_fn_pairs", "case",
"list or tuple", type(pred_fn_pairs)))
for pred_fn in pred_fn_pairs:
if not isinstance(pred_fn, tuple):
raise TypeError(
_error_message("The elements' type", "pred_fn_pairs",
"case", "tuple", type(pred_fn)))
if len(pred_fn) != 2:
raise TypeError(
_error_message("The tuple's size", "pred_fn_pairs", "case",
"2", str(len(pred_fn)) + "-tuple"))
pred, fn = pred_fn
if not isinstance(pred, Variable):
raise TypeError(
_error_message("The pred's type", "pred_fn_pairs", "case",
"boolean Variable", type(pred)))
if not callable(fn):
raise TypeError(
"The fn for {} of pred_fn_pairs in Op(case) must"
" be callable.".format(pred.name))
if default is None:
default_index = len(pred_fn_pairs) - 1 # pick the last one
default = pred_fn_pairs[default_index][1]
pred_fn_pairs = pred_fn_pairs[:default_index]
elif not callable(default):
raise TypeError("The default in Op(case) must be callable.")
return pred_fn_pairs, default
pred_fn_pairs, default = _case_check_args(pred_fn_pairs, default)
false_fn = default
for pred, true_fn in reversed(pred_fn_pairs):
false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)
final_fn = false_fn
return final_fn()
class Switch(object): class Switch(object):
""" """
...@@ -2786,17 +2910,6 @@ class DynamicRNN(object): ...@@ -2786,17 +2910,6 @@ class DynamicRNN(object):
method)) method))
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in Op({op_name}) must be " \
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value)
return error_message
def switch_case(branch_index, branch_fns, default=None, name=None): def switch_case(branch_index, branch_fns, default=None, name=None):
''' '''
This operator is like a C++ switch/case statement. This operator is like a C++ switch/case statement.
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from functools import partial
class TestAPICase(unittest.TestCase):
def test_return_single_var(self):
def fn_1():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)
def fn_2():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
# call fn_1
out_0 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3)
# call fn_2
out_1 = layers.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3)
# call default fn_3
out_2 = layers.case(
pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3)
# no default, call fn_2
out_3 = layers.case(pred_fn_pairs=[(pred_1, fn_2)])
# no default, call fn_2. but pred_2 is false
out_4 = layers.case(pred_fn_pairs=[(pred_2, fn_2)])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program,
fetch_list=[out_0, out_1, out_2, out_3, out_4])
self.assertTrue(np.allclose(res[0], 1))
self.assertTrue(np.allclose(res[1], 2))
self.assertTrue(np.allclose(res[2], 3))
self.assertTrue(np.allclose(res[3], 2))
self.assertTrue(np.allclose(res[4], 2))
def test_return_var_tuple(self):
def fn_1():
return layers.fill_constant(
shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
shape=[2, 3], dtype='float32', value=2)
def fn_2():
return layers.fill_constant(
shape=[3, 4], dtype='int32', value=3), layers.fill_constant(
shape=[4, 5], dtype='float32', value=4)
def fn_3():
return layers.fill_constant(
shape=[5], dtype='int32', value=5), layers.fill_constant(
shape=[5, 6], dtype='float32', value=6)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=1)
y = layers.fill_constant(shape=[1], dtype='float32', value=1)
z = layers.fill_constant(shape=[1], dtype='float32', value=3)
pred_1 = layers.equal(x, y) # true
pred_2 = layers.equal(x, z) # false
out = layers.case(((pred_1, fn_1), (pred_2, fn_2)), fn_3)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
self.assertTrue(
np.allclose(np.asarray(ret[0]), np.full((1, 2), 1, np.int32)))
self.assertTrue(
np.allclose(
np.asarray(ret[1]), np.full((2, 3), 2, np.float32)))
class TestAPICase_Nested(unittest.TestCase):
def test_nested_case(self):
def fn_1(x=1):
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(pred_fn_pairs=[(var_5 < var_6, partial(
layers.fill_constant, shape=[1], dtype='int32', value=x)),
(var_5 == var_6, partial(
layers.fill_constant,
shape=[2],
dtype='int32',
value=x))])
return out
def fn_2(x=2):
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(pred_fn_pairs=[(var_5 < var_6, partial(
fn_1, x=x)), (var_5 == var_6, partial(
layers.fill_constant, shape=[2], dtype='int32', value=x))])
return out
def fn_3():
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(pred_fn_pairs=[(var_5 < var_6, partial(
fn_2, x=3)), (var_5 == var_6, partial(
layers.fill_constant, shape=[2], dtype='int32', value=7))])
return out
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
out_1 = layers.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
out_2 = layers.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3)
out_3 = layers.case(
pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
self.assertTrue(np.allclose(res[0], 1))
self.assertTrue(np.allclose(res[1], 2))
self.assertTrue(np.allclose(res[2], 3))
class TestAPICase_Error(unittest.TestCase):
def test_error(self):
def fn_1():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
x = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
pred_1 = layers.less_than(z, x) # true
# The type of 'pred_fn_pairs' in case must be list or tuple
def type_error_pred_fn_pairs():
layers.case(pred_fn_pairs=1, default=fn_1)
self.assertRaises(TypeError, type_error_pred_fn_pairs)
# The elements' type of 'pred_fn_pairs' in Op(case) must be tuple
def type_error_pred_fn_1():
layers.case(pred_fn_pairs=[1], default=fn_1)
self.assertRaises(TypeError, type_error_pred_fn_1)
# The tuple's size of 'pred_fn_pairs' in Op(case) must be 2
def type_error_pred_fn_2():
layers.case(pred_fn_pairs=[(1, 2, 3)], default=fn_1)
self.assertRaises(TypeError, type_error_pred_fn_2)
# The pred's type of 'pred_fn_pairs' in Op(case) must be bool Variable
def type_error_pred():
layers.case(pred_fn_pairs=[(1, fn_1)], default=fn_1)
self.assertRaises(TypeError, type_error_pred)
# The function of pred_fn_pairs in case must be callable
def type_error_fn():
layers.case(pred_fn_pairs=[(pred_1, 2)], default=fn_1)
self.assertRaises(TypeError, type_error_fn)
# The default in Op(case) must be callable
def type_error_default():
layers.case(pred_fn_pairs=[(pred_1, fn_1)], default=fn_1())
self.assertRaises(TypeError, type_error_default)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册