diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py index 03381b424a7a986aea54506edbda8cf266c2e3a0..d46c0c7c189b7508f2356e5ce24c56d74033e612 100644 --- a/python/paddle/static/nn/control_flow.py +++ b/python/paddle/static/nn/control_flow.py @@ -569,7 +569,7 @@ def case(pred_fn_pairs, default=None, name=None): This operator works like an if-elif-elif-else chain. Args: - pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor with shape [1], ``fn`` is a callable. All callables return the same structure of Tensors. + pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor whose numel should be 1 (shape [] or shape [1]), ``fn`` is a callable. All callables return the same structure of Tensors. default(callable, optional): Callable that returns a structure of Tensors. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -702,7 +702,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None): This operator is like a C++ switch/case statement. Args: - branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. + branch_index(Tensor): A Tensor whose numel should be 1 (shape [] or shape [1]) to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``. branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors. default(callable, optional): Callable that returns a structure of Tensors. name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. @@ -910,9 +910,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None): branch will be executed during runtime. Args: - pred(Tensor): A boolean tensor whose numel should be 1. The boolean - value determines whether to return the result of ``true_fn`` or - ``false_fn`` . + pred(Tensor): A boolean tensor whose numel should be 1 (shape [] + or shape [1]). The boolean value determines whether to return the + result of ``true_fn`` or ``false_fn`` . true_fn(callable, optional): A callable to be performed if ``pred`` is true. The default value is ``None`` . false_fn(callable, optional): A callable to be performed if ``pred`` is