未验证 提交 6b0d9590 编写于 作者: G Ghost Screaming 提交者: GitHub

Clean and migrate fluid APIs of paddle.fluid.layers.control_flow (#48233)

* Merge branch 'reduce_sum' of https://github.com/GhostScreaming/Paddle into mine_fluid_clean_common.

* Fix some bugs.

* Clean APIs in python/paddle/fluid/layers/control_flow.py

* Polish code style.

* Change API.

* Fix some bugs.

* Fix some bugs.
上级 3a387df6
......@@ -52,20 +52,14 @@ import paddle
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'While',
'Switch',
'increment',
'array_write',
'array_read',
'cond',
'IfElse',
'StaticRNN',
'reorder_lod_tensor_by_rank',
'Print',
'Assert',
'is_empty',
'case',
'switch_case',
'while_loop',
]
......@@ -527,6 +521,7 @@ def Assert(cond, data=None, summarize=20, name=None):
return op
# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuard:
"""
BlockGuard class.
......@@ -550,6 +545,7 @@ class BlockGuard:
return True
# (TODO: Mine) There exists dependency. It will be removed later.
class BlockGuardWithCompletion(BlockGuard):
"""
BlockGuardWithCompletion class.
......@@ -1101,6 +1097,7 @@ class StaticRNN:
)
# (TODO: Mine) There exists dependency. It will be removed later.
class WhileGuard(BlockGuard):
def __init__(self, while_op):
if not isinstance(while_op, While):
......@@ -1120,6 +1117,7 @@ class WhileGuard(BlockGuard):
return super().__exit__(exc_type, exc_val, exc_tb)
# (TODO: Mine) There exists dependency. It will be removed later.
def get_inputs_outputs_in_block(
current_block, inner_inputs, inner_outputs, helper
):
......@@ -1182,6 +1180,7 @@ def get_inputs_outputs_in_block(
return inner_inputs, inner_outputs
# (TODO: Mine) There exists dependency. It will be removed later.
class While:
"""
:api_attr: Static Graph
......@@ -1320,6 +1319,7 @@ class While:
support_ret_buildin_type = (bool, float, int)
# (TODO: Mine) There exists dependency. It will be removed later.
def assign_skip_lod_tensor_array(input, output):
"""
Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
......@@ -1363,6 +1363,7 @@ def assign_skip_lod_tensor_array(input, output):
assign(input, output)
# (TODO: Mine) There exists dependency (jit.dy2static.convert_operators). It will be removed later.
def while_loop(cond, body, loop_vars, is_test=False, name=None):
"""
:api_attr: Static Graph
......@@ -1473,6 +1474,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
return loop_vars
# (TODO: Mine) There exists dependency. It will be removed later.
def _deal_with_undefined_var(output_vars, loop_vars):
"""Deal with undefined var cases, We create undefined variable based on the results of body().
In Dy2Static, we use undefined var to represent the var created in control flow. This function
......@@ -1511,102 +1513,6 @@ def _deal_with_undefined_var(output_vars, loop_vars):
return results
def lod_rank_table(x, level=0):
"""
LoD Rank Table Operator. Given an input variable **x** and a level number
of LoD, this layer creates a LodRankTable object. A LoDRankTable object
contains a list of bi-element tuples. Each tuple consists of an index and
a length, both of which are int type. Refering to specified level of LoD,
the index is the sequence index number and the length represents the
sequence length. Please note that the list is ranked in descending order by
the length. The following is an example:
.. code-block:: text
x is a LoDTensor:
x.lod = [[2, 1],
[5, 1, 1]]
x.data = [a, b, c, d, e, f, g]
1. set level to 0:
Create lod rank table:
lod_rank_table_obj = lod_rank_table(x, level=0)
Get:
lod_rank_table_obj.items() = [(0, 2), (1, 1)]
2. set level to 1:
Create lod rank table:
lod_rank_table_obj = lod_rank_table(x, level=1)
Get:
lod_rank_table_obj.items() = [(0, 5), (1, 1), (2, 1)]
Args:
x (Variable): Input variable, a LoDTensor based which to create the lod
rank table.
level (int): Specify the LoD level, on which to create the lod rank
table.
Returns:
Variable: The created LoDRankTable object.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[10],
dtype='float32', lod_level=1)
out = layers.lod_rank_table(x=x, level=0)
"""
check_type(x, 'x', (Variable, list), 'lod_rank_table')
if isinstance(x, (list)):
for i, input_x in enumerate(x):
check_type(
input_x, 'input[' + str(i) + ']', Variable, 'lod_rank_table'
)
helper = LayerHelper("lod_rank_table", **locals())
table = helper.create_variable(
type=core.VarDesc.VarType.LOD_RANK_TABLE,
name=unique_name.generate("lod_rank_table"),
)
helper.append_op(
type='lod_rank_table',
inputs={'X': x},
outputs={'Out': table},
attrs={'level': level},
)
return table
@templatedoc()
def max_sequence_len(rank_table):
"""
${comment}
>>> import paddle.fluid as fluid
>>> x = fluid.layers.data(name='x', shape=[10], dtype='float32',
>>> lod_level=1)
>>> rank_table = layers.lod_rank_table(x=x, level=0)
>>> max_seq_len = layers.max_sequence_len(rank_table)
Args:
rank_table(${rank_table_type}): ${rank_table_comment}.
Returns:
${out_comment}.
"""
helper = LayerHelper("max_seqence_len", **locals())
res = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="max_sequence_len",
inputs={"RankTable": rank_table},
outputs={"Out": res},
)
return res
def increment(x, value=1.0, in_place=True):
"""
The OP is usually used for control flow to increment the data of :attr:`x` by an amount :attr:`value`.
......@@ -2422,154 +2328,6 @@ def expand_undefined_var(nest1, nest2, names):
return nest1_out, nest2_out
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = (
"{what} of '{arg_name}' in {op_name} must be "
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value,
)
)
return error_message
def case(pred_fn_pairs, default=None, name=None):
'''
:api_attr: Static Graph
This operator works like an if-elif-elif-else chain.
Args:
pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor with shape [1], ``fn`` is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor|list(Tensor): Tensors returned by the callable from the first pair whose pred is True,
or Tensors returned by ``default`` if no pred in ``pred_fn_pairs`` is True and ``default`` is not None,
or Tensors returned by the last callable in ``pred_fn_pairs`` if no pred in ``pred_fn_pairs`` is True and ``default`` is None.
Raises:
TypeError: If the type of ``pred_fn_pairs`` is not list or tuple.
TypeError: If the type of elements in ``pred_fn_pairs`` is not tuple.
TypeError: If the size of tuples in ``pred_fn_pairs`` is not 2.
TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not a Tensor.
TypeError: If the second element of 2-tuple in ``pred_fn_pairs`` is not callable.
TypeError: If ``default`` is not None but it is not callable.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
def fn_1():
return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)
def fn_2():
return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[3], dtype='int32', fill_value=3)
main_program = paddle.static.default_startup_program()
startup_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
# Call fn_1 because pred_1 is True
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
# because fn_3 is the last callable in pred_fn_pairs.
out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
print(res_1) # [[1. 1.]]
print(res_2) # [3 3 3]
'''
helper = LayerHelper('case', **locals())
def _case_check_args(pred_fn_pairs, default):
'''
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
'''
check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case')
for pred_fn in pred_fn_pairs:
if not isinstance(pred_fn, tuple):
raise TypeError(
_error_message(
"The elements' type",
"pred_fn_pairs",
"case",
tuple,
type(pred_fn),
)
)
if len(pred_fn) != 2:
raise TypeError(
_error_message(
"The tuple's size",
"pred_fn_pairs",
"case",
"2",
str(len(pred_fn)) + "-tuple",
)
)
pred, fn = pred_fn
if not isinstance(pred, Variable):
raise TypeError(
_error_message(
"The pred's type",
"pred_fn_pairs",
"case",
"boolean Variable",
type(pred),
)
)
if not callable(fn):
raise TypeError(
"The fn for {} of pred_fn_pairs in Op(case) must"
" be callable.".format(pred.name)
)
if default is None:
default_index = len(pred_fn_pairs) - 1 # pick the last one
default = pred_fn_pairs[default_index][1]
pred_fn_pairs = pred_fn_pairs[:default_index]
elif not callable(default):
raise TypeError("The default in Op(case) must be callable.")
return pred_fn_pairs, default
pred_fn_pairs, default = _case_check_args(pred_fn_pairs, default)
false_fn = default
for pred, true_fn in reversed(pred_fn_pairs):
false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)
final_fn = false_fn
return final_fn()
class Switch:
"""
:api_attr: Static Graph
......@@ -2698,498 +2456,3 @@ class Switch:
return False # re-raise exception
return True
class IfElseBlockGuard:
def __init__(self, is_true, ifelse):
if not isinstance(ifelse, IfElse):
raise TypeError("ifelse must be an instance of IfElse class")
if ifelse.status != IfElse.OUT_IF_ELSE_BLOCKS:
raise ValueError("You cannot invoke IfElse.block() inside a block")
self.is_true = is_true
self.ie = ifelse
if is_true:
self.cond_block = ifelse.conditional_true_block
else:
self.cond_block = ifelse.conditional_false_block
if not isinstance(self.cond_block, ConditionalBlock):
raise TypeError("Unexpected situation")
self.cond_block = self.cond_block.block()
def __enter__(self):
self.ie.status = (
IfElse.IN_IF_ELSE_TRUE_BLOCKS
if self.is_true
else IfElse.IN_IF_ELSE_FALSE_BLOCKS
)
self.cond_block.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if not self.cond_block.__exit__(exc_type, exc_val, exc_tb):
# re-raise inside exception
return False
if len(self.ie.output_table[1 if self.is_true else 0]) == 0:
raise ValueError("Must set output inside block")
self.ie.status = IfElse.OUT_IF_ELSE_BLOCKS
class IfElse:
"""
:api_attr: Static Graph
This class is used to implement IfElse branch control function. IfElse contains two blocks, true_block and false_block. IfElse will put data satisfying True or False conditions into different blocks to run.
Cond is a 2-D Tensor with shape [N, 1] and data type bool, representing the execution conditions of the corresponding part of the input data.
Note:
A new OP :ref:`api_fluid_layers_cond` is highly recommended instead of ``IfElse``. if the shape of parameter ``cond`` is [1].
OP :ref:`api_fluid_layers_cond` is easier to use and is called with less code but does the same thing as ``IfElse`` .
IfElse OP is different from other OPs in usage, which may cause some users confusion. Here is a simple example to illustrate this OP.
.. code-block:: python
# The following code completes the function: subtract 10 from the data greater than 0 in x, add 10 to the data less than 0 in x, and sum all the data.
import numpy as np
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[4, 1], dtype='float32', append_batch_size=False)
y = fluid.layers.data(name='y', shape=[4, 1], dtype='float32', append_batch_size=False)
x_d = np.array([[3], [1], [-2], [-3]]).astype(np.float32)
y_d = np.zeros((4, 1)).astype(np.float32)
# Compare the size of x, y pairs of elements, output cond, cond is shape [4, 1], data type bool 2-D tensor.
# Based on the input data x_d, y_d, it can be inferred that the data in cond are [[true], [true], [false], [false]].
cond = fluid.layers.greater_than(x, y)
# Unlike other common OPs, ie below returned by the OP is an IfElse OP object
ie = fluid.layers.IfElse(cond)
with ie.true_block():
# In this block, according to cond condition, the data corresponding to true dimension in X is obtained and subtracted by 10.
out_1 = ie.input(x)
out_1 = out_1 - 10
ie.output(out_1)
with ie.false_block():
# In this block, according to cond condition, get the data of the corresponding condition in X as false dimension, and add 10
out_1 = ie.input(x)
out_1 = out_1 + 10
ie.output(out_1)
# According to cond condition, the data processed in the two blocks are merged. The output here is output, the type is List, and the element type in List is Variable.
output = ie() # [array([[-7.], [-9.], [ 8.], [ 7.]], dtype=float32)]
# Get the first Variable in the output List and add all elements.
out = paddle.sum(output[0])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
res = exe.run(fluid.default_main_program(), feed={"x":x_d, "y":y_d}, fetch_list=[out])
print(res)
# [array([-1.], dtype=float32)]
Args:
cond (Variable): cond is a 2-D Tensor with shape [N, 1] and data type bool, representing the corresponding execution conditions of N input data. The data type is bool.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Returns:
Unlike other common OPs, the OP call returns an IfElse OP object (e.g. ie in the example), which branches the input data by calling the internal functions of the object ``true_block ()``, ``false_block ()``, ``input ()``, ``output ()``, and integrates the data processed by different branches as the overall output by calling the internal ``call ()`` function. The output type is a list, and the type of each element in the list is Variable.
Internal Functions:
The block is constructed by calling the ``with ie. true_block()`` function in the object, and the computational logic under condition true is put into the block. If no corresponding block is constructed, the input data in the corresponding conditional dimension is unchanged.
The block is constructed by calling the ``with ie. false_block()`` function in the object, and the computational logic under condition false is put into the block. If no corresponding block is constructed, the input data in the corresponding conditional dimension is unchanged.
``Out = ie. input (x)`` will take out the data of the corresponding conditional dimension in X and put it into out, supporting the internal processing of multiple inputs in block.
``ie. output (out)`` writes the result to the output of the corresponding condition.
There is a ``call ()`` function inside the object, that is, by calling ``output = ie ()``, all the outputs inside the block of False are fused as the whole output, the output type is a list, and the type of each element in the list is Variable.
"""
OUT_IF_ELSE_BLOCKS = 0
IN_IF_ELSE_TRUE_BLOCKS = 1
IN_IF_ELSE_FALSE_BLOCKS = 2
def __init__(self, cond, name=None):
check_type(cond, "cond", Variable, "fluid.layers.IfElse")
check_type(name, "name", (str, type(None)), "fluid.layers.IfElse")
self.helper = LayerHelper('ifelse', name=name)
self.cond = cond
self.input_table = {}
self.status = IfElse.OUT_IF_ELSE_BLOCKS
self.conditional_true_block = ConditionalBlock(inputs=[self.cond])
self.conditional_false_block = ConditionalBlock(inputs=[self.cond])
self.output_table = ([], []) # (true_outs, false_outs)
def input(self, x):
if self.status == IfElse.OUT_IF_ELSE_BLOCKS:
raise ValueError("input must in true/false blocks")
if id(x) not in self.input_table:
parent_block = self._parent_block()
out_true = parent_block.create_var(
name=unique_name.generate_with_ignorable_key(
'ifelse_input' + self.helper.name
),
dtype=x.dtype,
)
out_false = parent_block.create_var(
name=unique_name.generate_with_ignorable_key(
'ifelse_input' + self.helper.name
),
dtype=x.dtype,
)
parent_block.append_op(
type='split_lod_tensor',
inputs={
'X': x,
'Mask': self.cond,
},
outputs={'OutTrue': out_true, 'OutFalse': out_false},
attrs={'level': 0},
)
self.input_table[id(x)] = (out_true, out_false)
else:
out_true, out_false = self.input_table[id(x)]
if self.status == IfElse.IN_IF_ELSE_TRUE_BLOCKS:
return out_true
else:
return out_false
def _parent_block(self):
current_block = self.helper.main_program.current_block()
return self.helper.main_program.block(current_block.parent_idx)
def true_block(self):
return IfElseBlockGuard(True, self)
def false_block(self):
return IfElseBlockGuard(False, self)
def output(self, *outs):
if self.status == self.OUT_IF_ELSE_BLOCKS:
raise ValueError("output can only be invoked in the sub-block")
out_table = self.output_table[
1 if self.status == self.IN_IF_ELSE_TRUE_BLOCKS else 0
]
parent_block = self._parent_block()
for each_out in outs:
check_type(
each_out, "each output", Variable, "fluid.layers.IfElse.output"
)
# create outside tensor
outside_out = parent_block.create_var(
name=unique_name.generate_with_ignorable_key(
"_".join([self.helper.name, 'output'])
),
dtype=each_out.dtype,
)
out_table.append(outside_out)
# assign local var to outside
assign(input=each_out, output=outside_out)
def __call__(self):
if self.status != self.OUT_IF_ELSE_BLOCKS:
raise ValueError("IfElse::__call__ must be out of sub-block")
false_len, true_len = list(map(len, self.output_table))
if false_len == 0 and true_len == 0:
raise ValueError(
"Must invoke true_block/false_block before " "__call__"
)
elif false_len != true_len and false_len != 0 and true_len != 0:
raise ValueError("The output side must be same")
elif false_len == 0 or true_len == 0:
return self.output_table[0 if false_len != 0 else 1]
# else none of false_len/true_len is zero
# merge together
rlist = []
for false_var, true_var in zip(*self.output_table):
rlist.append(
merge_lod_tensor(
in_true=true_var,
in_false=false_var,
mask=self.cond,
x=self.cond,
level=0,
)
)
return rlist
def switch_case(branch_index, branch_fns, default=None, name=None):
'''
:api_attr: Static Graph
This operator is like a C++ switch/case statement.
Args:
branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``,
or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``.
Raises:
TypeError: If the type of ``branch_index`` is not Tensor.
TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``.
TypeError: If the type of ``branch_fns`` is not dict, list or tuple.
TypeError: If the elements of ``branch_fns`` is not 2-tuple.
TypeError: If the first element of 2-tuple in ``branch_fns`` is not integer.
ValueError: If the first element of 2-tuple in ``branch_fns`` is not unique.
TypeError: If the second element of 2-tuple in ``branch_fns`` is not callable.
TypeError: If ``default`` is not None but it is not callable.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
def fn_1():
return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)
def fn_2():
return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[3], dtype='int32', fill_value=3)
main_program = paddle.static.default_startup_program()
startup_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3)
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
# Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
print(res_1) # [[1. 1.]]
print(res_2) # [[2 2] [2 2]]
print(res_3) # [3 3 3]
'''
helper = LayerHelper('switch_case', **locals())
def _check_args(branch_index, branch_fns, default):
check_variable_and_dtype(
branch_index,
'branch_index',
['uint8', 'int32', 'int64'],
'switch_case',
)
if convert_dtype(branch_index.dtype) != "int64":
branch_index = cast(branch_index, "int64")
check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')
branch_fns = (
branch_fns.items() if isinstance(branch_fns, dict) else branch_fns
)
branch_fns = (
list(enumerate(branch_fns))
if all(callable(fn) for fn in branch_fns)
else branch_fns
)
keys_of_fns = []
for index_fn_pair in branch_fns:
if not isinstance(index_fn_pair, tuple):
raise TypeError(
_error_message(
"The elements' type",
"branch_fns",
"switch_case",
tuple,
type(branch_fns),
)
)
if len(index_fn_pair) != 2:
raise TypeError(
_error_message(
"The tuple's size",
"branch_fns",
"switch_case",
"2",
str(len(index_fn_pair)) + "-tuple",
)
)
key, fn = index_fn_pair
if not isinstance(key, int):
raise TypeError(
_error_message(
"The key's type",
"branch_fns",
"switch_case",
int,
type(key),
)
)
if key in keys_of_fns:
raise ValueError(
"The key in 'branch_fns' must be unique, but '{}' appears more than once.".format(
key
)
)
else:
keys_of_fns.append(key)
if not callable(fn):
raise TypeError(
_error_message(
"The type of function for key {}".format(key),
"branch_fns",
"switch_case",
"callable",
type(fn),
)
)
if default is None:
default = sorted(branch_fns)[-1][1]
branch_fns = sorted(branch_fns)[:-1]
elif not callable(default):
raise TypeError("The default in Op(case) must be callable.")
pred_fn_pairs = []
for index, fn in branch_fns:
new_index = fill_constant(shape=[1], dtype="int64", value=index)
pred = paddle.equal(branch_index, new_index)
pred_fn_pairs.append((pred, fn))
return pred_fn_pairs, default
pred_fn_pairs, default = _check_args(branch_index, branch_fns, default)
false_fn = default
for pred, true_fn in pred_fn_pairs:
false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)
final_fn = false_fn
return final_fn()
@templatedoc()
def reorder_lod_tensor_by_rank(x, rank_table):
"""
${comment}
Args:
x(${x_type}): ${x_comment}.
rank_table(${rank_table_type}): ${rank_table_comment}.
Returns:
out(${out_type}): ${out_comment}.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data_desc = (['input', [9], 0], ['ref', [5], 1])
data = fluid.layers.data(name=data_desc[0][0], shape=data_desc[0][1])
rank_data = fluid.layers.data(name=data_desc[1][0], shape=data_desc[1][1])
table = fluid.layers.control_flow.lod_rank_table(rank_data)
new_data = fluid.layers.reorder_lod_tensor_by_rank(
x=data, rank_table=table)
"""
check_type(x, 'x', (Variable), 'reorder_lod_tensor_by_rank')
check_type(
rank_table, 'rank_table', (Variable), 'reorder_lod_tensor_by_rank'
)
if rank_table.type != core.VarDesc.VarType.LOD_RANK_TABLE:
raise TypeError("The type of rank_table should be LOD_RANK_TABLE.")
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x], 'RankTable': [rank_table]},
outputs={'Out': [out]},
)
return out
def is_empty(x, name=None):
"""
Test whether a Tensor is empty.
Args:
x (Tensor): The Tensor to be tested.
name (str, optional): The default value is ``None`` . Normally users
don't have to set this parameter. For more information,
please refer to :ref:`api_guide_Name` .
Returns:
Tensor: A bool scalar Tensor. True if 'x' is an empty Tensor.
Examples:
.. code-block:: python
import paddle
input = paddle.rand(shape=[4, 32, 32], dtype='float32')
res = paddle.is_empty(x=input)
print("res:", res)
# ('res:', Tensor: eager_tmp_1
# - place: CPUPlace
# - shape: [1]
# - layout: NCHW
# - dtype: bool
# - data: [0])
"""
if in_dygraph_mode():
return _C_ops.is_empty(x)
if _in_legacy_dygraph():
return _legacy_C_ops.is_empty(x)
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'is_empty'
)
check_type(name, "name", (str, type(None)), "is_empty")
helper = LayerHelper("is_empty", **locals())
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
helper.append_op(
type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]}
)
return cond
......@@ -1594,7 +1594,7 @@ def _dynamic_decode_declarative(
max_step_num = tensor.fill_constant(
shape=[1], dtype="int64", value=max_step_num
)
while_op = control_flow.While(cond, is_test=is_test)
while_op = paddle.static.nn.control_flow.While(cond, is_test=is_test)
sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64")
sequence_lengths.stop_gradient = True
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.layers.control_flow import (
ConditionalBlock,
merge_lod_tensor,
split_lod_tensor,
)
from paddle.fluid.optimizer import MomentumOptimizer
paddle.enable_static()
class TestMNISTIfElseOp(unittest.TestCase):
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_raw_api(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
image = layers.data(name='x', shape=[784], dtype='float32')
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = paddle.less_than(x=label, y=limit)
true_image, false_image = split_lod_tensor(input=image, mask=cond)
true_out = paddle.tensor.create_tensor(dtype='float32')
true_cond = ConditionalBlock([cond])
with true_cond.block():
hidden = layers.fc(input=true_image, size=100, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
layers.assign(input=prob, output=true_out)
false_out = paddle.tensor.create_tensor(dtype='float32')
false_cond = ConditionalBlock([cond])
with false_cond.block():
hidden = layers.fc(input=false_image, size=200, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
layers.assign(input=prob, output=false_out)
prob = merge_lod_tensor(
in_true=true_out, in_false=false_out, mask=cond, x=image
)
loss = layers.cross_entropy(input=prob, label=label)
avg_loss = paddle.mean(loss)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, startup_prog)
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
batch_size=10,
)
place = core.CPUPlace()
exe = Executor(place)
exe.run(startup_prog)
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array([x[0] for x in data]).astype("float32")
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = np.expand_dims(y_data, axis=1)
outs = exe.run(
prog, feed={'x': x_data, 'y': y_data}, fetch_list=[avg_loss]
)
print(outs[0])
if outs[0] < 1.0:
return
self.assertFalse(True)
# FIXME: https://github.com/PaddlePaddle/Paddle/issues/12245#issuecomment-406462379
def not_test_ifelse(self):
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
image = layers.data(name='x', shape=[784], dtype='float32')
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = paddle.less_than(x=label, y=limit)
ie = layers.IfElse(cond)
with ie.true_block():
true_image = ie.input(image)
hidden = layers.fc(input=true_image, size=100, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
with ie.false_block():
false_image = ie.input(image)
hidden = layers.fc(input=false_image, size=200, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
prob = ie()
loss = layers.cross_entropy(input=prob[0], label=label)
avg_loss = paddle.mean(loss)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, startup_prog)
train_reader = paddle.batch(
paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
batch_size=200,
)
place = core.CPUPlace()
exe = Executor(place)
exe.run(startup_prog)
PASS_NUM = 100
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array([x[0] for x in data]).astype("float32")
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = y_data.reshape((y_data.shape[0], 1))
outs = exe.run(
prog, feed={'x': x_data, 'y': y_data}, fetch_list=[avg_loss]
)
print(outs[0])
if outs[0] < 1.0:
return
self.assertFalse(True)
class TestIfElse(unittest.TestCase):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 0.5
self.data = np.random.rand(25, 1).astype(np.float32)
def numpy_cal(self):
s1 = self.data[np.where(self.data < self.cond_value)]
res = np.sum(np.exp(s1))
s2 = self.data[np.where(self.data >= self.cond_value)]
res += np.sum(np.tanh(s2))
return res
def compare_ifelse_op_and_numpy(self, place):
self.set_test_case()
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
src = layers.data(name='data', shape=[1], dtype='float32')
cond = layers.fill_constant(
[1], dtype='float32', value=self.cond_value
)
ifcond = paddle.less_than(x=src, y=cond)
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
true_target = paddle.exp(true_target)
ie.output(true_target)
with ie.false_block():
false_target = ie.input(src)
false_target = paddle.tanh(false_target)
ie.output(false_target)
if_out = ie()
out = paddle.sum(if_out[0])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fetch_list = [out]
(o1,) = exe.run(
fluid.default_main_program(),
feed={'data': self.data},
fetch_list=[out],
)
o2 = self.numpy_cal()
np.testing.assert_allclose(
o1,
o2,
rtol=1e-05,
atol=1e-08,
)
def test_cpu(self):
self.compare_ifelse_op_and_numpy(fluid.CPUPlace())
def test_cuda(self):
if not core.is_compiled_with_cuda():
return
self.compare_ifelse_op_and_numpy(fluid.CUDAPlace(0))
class TestIfElseTrueBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = 10.0
self.data = np.random.rand(25, 1).astype(np.float32)
class TestIfElseFalseBranch(TestIfElse):
def set_test_case(self):
# condiction is: self.data < self.cond_value
self.cond_value = -10.0
self.data = np.random.rand(25, 1).astype(np.float32)
class TestIfElseError(unittest.TestCase):
def test_input_type_error(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
src = layers.data(name='data', shape=[1], dtype='float32')
const_value = layers.fill_constant(
[1], dtype='float32', value=123.0
)
ifcond = paddle.less_than(x=src, y=const_value)
with self.assertRaises(TypeError):
ie = layers.IfElse(set())
with self.assertRaises(TypeError):
ie = layers.IfElse(ifcond, set())
with self.assertRaises(TypeError):
ie = layers.IfElse(ifcond)
with ie.true_block():
true_target = ie.input(src)
true_target = paddle.exp(true_target)
ie.output([])
if __name__ == '__main__':
unittest.main()
......@@ -174,7 +174,7 @@ def get_program():
cond = paddle.less_than(x=i, y=loop_len)
auto.shard_tensor(cond, _g_process_mesh, [None])
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
pre_input = fluid.layers.array_read(array=input_array, i=i)
......
......@@ -84,7 +84,9 @@ class TestHybridParallelInferenceHelperClass(unittest.TestCase):
)
print(cond_int.shape)
cond = paddle.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond, is_test=True)
while_op = paddle.static.nn.control_flow.While(
cond, is_test=True
)
with while_op.block():
with paddle.fluid.device_guard(f'{device}:all'):
......
......@@ -1763,7 +1763,7 @@ def fast_decode(
shape=[1], dtype=start_tokens.dtype, value=0
)
cond = paddle.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
while_op = paddle.static.nn.control_flow.While(cond)
# array states will be stored for each step.
ids = layers.array_write(
paddle.reshape(start_tokens, (-1, 1)), step_idx
......
......@@ -161,7 +161,7 @@ def dyfunc_ifExp_with_while(x):
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
ten = fluid.layers.fill_constant(shape=[1], dtype='int64', value=10)
i, ten, y = fluid.layers.while_loop(cond, body, [i, ten, y])
i, ten, y = paddle.static.nn.while_loop(cond, body, [i, ten, y])
return y[0]
......
......@@ -145,7 +145,7 @@ class TestSetValueItemSlice5(TestSetValueApi):
# return i, x
#
# i = paddle.zeros(shape=(1, ), dtype='int32')
# i, x = paddle.fluid.layers.while_loop(cond, body, [i, x])
# i, x = paddle.static.nn.while_loop(cond, body, [i, x])
#
# def _get_answer(self):
# self.data[0] = self.value
......
......@@ -147,7 +147,7 @@ class TestSetValueItemSlice4(TestSetValueApi):
# return i, x
# i = paddle.zeros(shape=(1, ), dtype='int32')
# i, x = paddle.fluid.layers.while_loop(cond, body, [i, x])
# i, x = paddle.static.nn.while_loop(cond, body, [i, x])
# def _get_answer(self):
# self.data[0] = self.value
......
......@@ -64,8 +64,8 @@ class TestWhileOp(unittest.TestCase):
cond2 = paddle.logical_or(x=j, y=array_len2)
cond2 = paddle.ones(shape=[1], dtype='int32')
cond2 = layers.cast(cond2, 'bool')
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......
......@@ -17,9 +17,20 @@ import unittest
import numpy as np
import paddle
sys.path.append("../")
from op_test import OpTest, skip_check_grad_ci
from test_reorder_lod_tensor import convert_to_offset
paddle.enable_static()
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
def compute_seqpool_sum(x, offset, out, pad_value=0.0):
......
......@@ -24,6 +24,8 @@ import paddle.fluid.layers as layers
import paddle.fluid.optimizer as optimizer
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
class TestAPICase(unittest.TestCase):
def test_return_single_var(self):
......@@ -46,25 +48,29 @@ class TestAPICase(unittest.TestCase):
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
# call fn_1
out_0 = layers.case(
out_0 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_1, fn_2)], default=fn_3
)
# call fn_2
out_1 = layers.case(
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
# call default fn_3
out_2 = layers.case(
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=((pred_2, fn_1), (pred_2, fn_2)), default=fn_3
)
# no default, call fn_2
out_3 = layers.case(pred_fn_pairs=[(pred_1, fn_2)])
out_3 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_2)]
)
# no default, call fn_2. but pred_2 is false
out_4 = layers.case(pred_fn_pairs=[(pred_2, fn_2)])
out_4 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_2)]
)
place = (
fluid.CUDAPlace(0)
......@@ -109,7 +115,9 @@ class TestAPICase(unittest.TestCase):
pred_1 = paddle.equal(x, y) # true
pred_2 = paddle.equal(x, z) # false
out = layers.case(((pred_1, fn_1), (pred_2, fn_2)), fn_3)
out = paddle.static.nn.control_flow.case(
((pred_1, fn_1), (pred_2, fn_2)), fn_3
)
place = (
fluid.CUDAPlace(0)
......@@ -132,7 +140,7 @@ class TestAPICase_Nested(unittest.TestCase):
def fn_1(x=1):
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(
out = paddle.static.nn.control_flow.case(
pred_fn_pairs=[
(
var_5 < var_6,
......@@ -159,7 +167,7 @@ class TestAPICase_Nested(unittest.TestCase):
def fn_2(x=2):
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(
out = paddle.static.nn.control_flow.case(
pred_fn_pairs=[
(var_5 < var_6, partial(fn_1, x=x)),
(
......@@ -178,7 +186,7 @@ class TestAPICase_Nested(unittest.TestCase):
def fn_3():
var_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
var_6 = layers.fill_constant(shape=[1], dtype='int32', value=6)
out = layers.case(
out = paddle.static.nn.control_flow.case(
pred_fn_pairs=[
(var_5 < var_6, partial(fn_2, x=3)),
(
......@@ -203,15 +211,15 @@ class TestAPICase_Nested(unittest.TestCase):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
out_1 = layers.case(
out_1 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(
out_2 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_2, fn_1), (pred_1, fn_2)], default=fn_3
)
out_3 = layers.case(
out_3 = paddle.static.nn.control_flow.case(
pred_fn_pairs=[(x == y, fn_1), (x == z, fn_2)], default=fn_3
)
......@@ -243,37 +251,49 @@ class TestAPICase_Error(unittest.TestCase):
# The type of 'pred_fn_pairs' in case must be list or tuple
def type_error_pred_fn_pairs():
layers.case(pred_fn_pairs=1, default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=1, default=fn_1
)
self.assertRaises(TypeError, type_error_pred_fn_pairs)
# The elements' type of 'pred_fn_pairs' in Op(case) must be tuple
def type_error_pred_fn_1():
layers.case(pred_fn_pairs=[1], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[1], default=fn_1
)
self.assertRaises(TypeError, type_error_pred_fn_1)
# The tuple's size of 'pred_fn_pairs' in Op(case) must be 2
def type_error_pred_fn_2():
layers.case(pred_fn_pairs=[(1, 2, 3)], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(1, 2, 3)], default=fn_1
)
self.assertRaises(TypeError, type_error_pred_fn_2)
# The pred's type of 'pred_fn_pairs' in Op(case) must be bool Variable
def type_error_pred():
layers.case(pred_fn_pairs=[(1, fn_1)], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(1, fn_1)], default=fn_1
)
self.assertRaises(TypeError, type_error_pred)
# The function of pred_fn_pairs in case must be callable
def type_error_fn():
layers.case(pred_fn_pairs=[(pred_1, 2)], default=fn_1)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, 2)], default=fn_1
)
self.assertRaises(TypeError, type_error_fn)
# The default in Op(case) must be callable
def type_error_default():
layers.case(pred_fn_pairs=[(pred_1, fn_1)], default=fn_1())
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(pred_1, fn_1)], default=fn_1()
)
self.assertRaises(TypeError, type_error_default)
......@@ -308,7 +328,9 @@ class TestMutiTask(unittest.TestCase):
loss = paddle.mean(sum, name="f_2_loss")
adagrad.minimize(loss)
layers.case(pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2)
paddle.static.nn.control_flow.case(
pred_fn_pairs=[(switch_id == one, fn_1)], default=fn_2
)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
......
......@@ -19,6 +19,8 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
paddle.enable_static()
def execute(main_program, startup_program):
if paddle.is_compiled_with_cuda():
......@@ -153,7 +155,7 @@ class TestDeviceGuard(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with paddle.static.device_guard("cpu"):
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
i = paddle.increment(x=i, value=1)
paddle.assign(paddle.less_than(x=i, y=loop_len), cond)
......
......@@ -20,6 +20,8 @@ import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
paddle.enable_static()
def build_and_run_program(place, batch_size, beam_size, stop_gradient=False):
fluid.default_startup_program().random_seed = 1
......@@ -37,7 +39,7 @@ def build_and_run_program(place, batch_size, beam_size, stop_gradient=False):
shape=[1], dtype="int64", value=10, force_cpu=True
)
cond = paddle.less_than(x=step_idx, y=max_len)
while_op = layers.While(cond)
while_op = paddle.static.nn.control_flow.While(cond)
scores = layers.array_write(x, step_idx)
with while_op.block():
bs = layers.cast(paddle.shape(x)[0], "int64")
......
......@@ -103,8 +103,8 @@ class TestEagerDeletionWhileOpBase(unittest.TestCase):
array_len2.stop_gradient = True
cond2 = paddle.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......
......@@ -21,7 +21,14 @@ from sequence.test_sequence_pool import (
compute_seqpool_sqrt,
compute_seqpool_sum,
)
from test_reorder_lod_tensor import convert_to_offset
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestFusionSeqPoolConcatOp(OpTest):
......
......@@ -22,7 +22,14 @@ from sequence.test_sequence_pool import (
compute_seqpool_sum,
)
from test_cvm_op import cvm_compute
from test_reorder_lod_tensor import convert_to_offset
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestFusionSeqPoolCVMConcatOp(OpTest):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from test_imperative_base import new_program_scope
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph_utils as dygraph_utils
from paddle.fluid import core
from paddle.fluid.dygraph.layer_object_helper import LayerObjectHelper
from paddle.fluid.framework import _in_legacy_dygraph, _test_eager_guard
from paddle.fluid.layer_helper import LayerHelper
class MyLayer(fluid.Layer):
def __init__(self):
super().__init__()
def forward(self, inputs):
x = fluid.layers.relu(inputs)
self._x_for_debug = x
x = paddle.multiply(x, x)
x = paddle.sum(x)
return [x]
class MLP(fluid.Layer):
def __init__(self, input_size):
super().__init__()
self._linear1 = paddle.nn.Linear(
input_size,
3,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.1)
),
bias_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.1)
),
)
self._linear2 = paddle.nn.Linear(
3,
4,
weight_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.1)
),
bias_attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.1)
),
)
def forward(self, inputs):
x = self._linear1(inputs)
x = self._linear2(x)
x = paddle.sum(x)
return x
class SimpleRNNCell(fluid.Layer):
def __init__(self, step_input_size, hidden_size, output_size, param_attr):
super().__init__()
self.step_input_size = step_input_size
self.hidden_size = hidden_size
self.output_size = output_size
self._dtype = core.VarDesc.VarType.FP32
self.param_attr = param_attr
i2h_param_shape = [self.step_input_size, self.hidden_size]
h2h_param_shape = [self.hidden_size, self.hidden_size]
h2o_param_shape = [self.output_size, self.hidden_size]
self._i2h_w = None
self._i2h_w = self.create_parameter(
attr=self.param_attr,
shape=i2h_param_shape,
dtype=self._dtype,
is_bias=False,
)
self._h2h_w = self.create_parameter(
attr=self.param_attr,
shape=h2h_param_shape,
dtype=self._dtype,
is_bias=False,
)
self._h2o_w = self.create_parameter(
attr=self.param_attr,
shape=h2o_param_shape,
dtype=self._dtype,
is_bias=False,
)
def forward(self, input, pre_hidden):
tmp_i2h = paddle.fluid.layers.nn.mul(input, self._i2h_w)
tmp_h2h = paddle.fluid.layers.nn.mul(pre_hidden, self._h2h_w)
hidden = paddle.add(tmp_h2h, tmp_i2h)
hidden = self._helper.append_activation(hidden, act='tanh')
out = paddle.fluid.layers.nn.mul(hidden, self._h2o_w)
softmax_out = paddle.nn.functional.softmax(out)
reduce_out = paddle.sum(softmax_out)
return reduce_out, hidden
class SimpleRNN(fluid.Layer):
def __init__(self):
super().__init__()
self.seq_len = 4
self._cell = SimpleRNNCell(
3,
3,
3,
fluid.ParamAttr(initializer=fluid.initializer.Constant(value=0.1)),
)
def forward(self, inputs):
outs = list()
pre_hiddens = list()
init_hidden = self.create_parameter(
attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.1)
),
shape=[1, 3],
dtype='float32',
is_bias=False,
)
pre_hidden = init_hidden
for i in range(self.seq_len):
input = paddle.slice(inputs, axes=[1], starts=[i], ends=[i + 1])
input = paddle.reshape(input, shape=[1, 3])
out_softmax, pre_hidden = self._cell(input, pre_hidden)
outs.append(out_softmax)
return outs, pre_hiddens
class TestImperative(unittest.TestCase):
def functional_dygraph_context(self):
self.assertFalse(fluid.dygraph.enabled())
fluid.enable_dygraph()
self.assertTrue(fluid.dygraph.enabled())
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
var_inp = paddle.to_tensor(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out1 = out.numpy()
out.backward()
dy_grad1 = mlp._linear1.weight.gradient()
fluid.disable_dygraph()
self.assertFalse(fluid.dygraph.enabled())
with fluid.dygraph.guard():
self.assertTrue(fluid.dygraph.enabled())
var_inp = paddle.to_tensor(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out2 = out.numpy()
out.backward()
dy_grad2 = mlp._linear1.weight.gradient()
self.assertFalse(fluid.dygraph.enabled())
np.testing.assert_array_equal(dy_out1, dy_out2)
np.testing.assert_array_equal(dy_grad1, dy_grad2)
def test_functional_dygraph_context(self):
with _test_eager_guard():
self.functional_dygraph_context()
self.functional_dygraph_context()
def functional_paddle_imperative_dygraph_context(self):
self.assertFalse(paddle.in_dynamic_mode())
paddle.disable_static()
self.assertTrue(paddle.in_dynamic_mode())
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
var_inp = paddle.to_tensor(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out1 = out.numpy()
out.backward()
dy_grad1 = mlp._linear1.weight.gradient()
paddle.enable_static()
self.assertFalse(paddle.in_dynamic_mode())
paddle.disable_static()
self.assertTrue(paddle.in_dynamic_mode())
var_inp = paddle.to_tensor(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out2 = out.numpy()
out.backward()
dy_grad2 = mlp._linear1.weight.gradient()
paddle.enable_static()
self.assertFalse(paddle.in_dynamic_mode())
np.testing.assert_array_equal(dy_out1, dy_out2)
np.testing.assert_array_equal(dy_grad1, dy_grad2)
def test_functional_paddle_imperative_dygraph_context(self):
with _test_eager_guard():
self.functional_paddle_imperative_dygraph_context()
self.functional_paddle_imperative_dygraph_context()
def func_isinstance(self):
var = fluid.layers.data(shape=[1], name='x', dtype='float32')
self.assertTrue(isinstance(var, fluid.Variable))
with fluid.dygraph.guard():
if not _in_legacy_dygraph():
var_base = paddle.to_tensor(np.array([3, 4, 5]))
self.assertTrue(isinstance(var_base, core.eager.Tensor))
else:
var_base = paddle.to_tensor(np.array([3, 4, 5]))
self.assertTrue(isinstance(var_base, core.VarBase))
self.assertTrue(isinstance(var_base, fluid.Variable))
def test_isinstance(self):
with _test_eager_guard():
self.func_isinstance()
self.func_isinstance()
def func_create_varbase(self):
x = np.ones([2, 2], np.float32)
y = np.zeros([3, 3], np.float32)
t = fluid.Tensor()
t.set(x, fluid.CPUPlace())
if not _in_legacy_dygraph():
egr_tmp = fluid.core.eager.Tensor(
value=x, place=fluid.core.CPUPlace()
)
egr_tmp2 = fluid.core.eager.Tensor(y, fluid.core.CPUPlace())
egr_tmp3 = paddle.to_tensor(x)
egr_tmp4 = fluid.core.eager.Tensor(y)
egr_tmp5 = fluid.core.eager.Tensor(value=x)
egr_tmp6 = fluid.core.eager.Tensor(t)
np.testing.assert_array_equal(x, egr_tmp.numpy())
np.testing.assert_array_equal(y, egr_tmp2.numpy())
np.testing.assert_array_equal(x, egr_tmp3.numpy())
np.testing.assert_array_equal(y, egr_tmp4.numpy())
np.testing.assert_array_equal(x, egr_tmp5.numpy())
np.testing.assert_array_equal(x, egr_tmp6.numpy())
else:
tmp = fluid.core.VarBase(value=x, place=fluid.core.CPUPlace())
tmp2 = fluid.core.VarBase(y, fluid.core.CPUPlace())
tmp3 = paddle.to_tensor(x)
tmp4 = fluid.core.VarBase(y)
tmp5 = fluid.core.VarBase(value=x)
tmp6 = fluid.core.VarBase(t)
np.testing.assert_array_equal(x, tmp.numpy())
np.testing.assert_array_equal(y, tmp2.numpy())
np.testing.assert_array_equal(x, tmp3.numpy())
np.testing.assert_array_equal(y, tmp4.numpy())
np.testing.assert_array_equal(x, tmp5.numpy())
np.testing.assert_array_equal(x, tmp6.numpy())
def test_create_varbase(self):
with fluid.dygraph.guard():
with _test_eager_guard():
self.func_create_varbase()
self.func_create_varbase()
def test_no_grad_guard(self):
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = paddle.nn.Linear(2, 2)
self.assertIsNone(l0.weight._grad_ivar())
l1 = paddle.nn.Linear(2, 2)
with fluid.dygraph.no_grad():
self.assertTrue(l1.weight.stop_gradient is False)
tmp = l1.weight * 2
self.assertTrue(tmp.stop_gradient)
x = paddle.to_tensor(data)
y = paddle.add(l0(x), tmp)
o = l1(y)
o.backward()
self.assertIsNone(tmp._grad_ivar())
self.assertIsNotNone(l0.weight._grad_ivar())
def test_paddle_imperative_no_grad_guard(self):
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = paddle.nn.Linear(2, 2)
self.assertIsNone(l0.weight._grad_ivar())
l1 = paddle.nn.Linear(2, 2)
with paddle.no_grad():
self.assertTrue(l1.weight.stop_gradient is False)
tmp = l1.weight * 2
self.assertTrue(tmp.stop_gradient)
x = paddle.to_tensor(data)
y = paddle.add(l0(x), tmp)
o = l1(y)
o.backward()
self.assertIsNone(tmp._grad_ivar())
self.assertIsNotNone(l0.weight._grad_ivar())
def test_paddle_imperative_set_grad_enabled(self):
data = np.array([[2, 3], [4, 5]]).astype('float32')
with fluid.dygraph.guard():
l0 = paddle.nn.Linear(2, 2)
self.assertIsNone(l0.weight._grad_ivar())
l1 = paddle.nn.Linear(2, 2)
with paddle.set_grad_enabled(False):
self.assertTrue(l1.weight.stop_gradient is False)
tmp = l1.weight * 2
with paddle.set_grad_enabled(True):
tmp2 = l1.weight * 2
self.assertTrue(tmp.stop_gradient)
self.assertTrue(tmp2.stop_gradient is False)
x = paddle.to_tensor(data)
y = paddle.add(l0(x), tmp2)
o = l1(y)
o.backward()
self.assertIsNone(tmp._grad_ivar())
self.assertIsNotNone(tmp2._grad_ivar())
self.assertIsNotNone(l0.weight._grad_ivar())
def test_paddle_imperative_is_grad_enabled(self):
with fluid.dygraph.guard():
with paddle.set_grad_enabled(False):
self.assertTrue(paddle.is_grad_enabled() is False)
with paddle.set_grad_enabled(True):
self.assertTrue(paddle.is_grad_enabled())
def func_sum_op(self):
x = np.ones([2, 2], np.float32)
with fluid.dygraph.guard():
inputs = []
for _ in range(10):
tmp = paddle.to_tensor(x)
tmp.stop_gradient = False
inputs.append(tmp)
ret = paddle.add_n(inputs)
loss = paddle.sum(ret)
loss.backward()
with fluid.dygraph.guard():
inputs2 = []
for _ in range(10):
tmp = paddle.to_tensor(x)
tmp.stop_gradient = False
inputs2.append(tmp)
ret2 = paddle.add_n(inputs2)
loss2 = paddle.sum(ret2)
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
loss2.backward()
np.testing.assert_allclose(ret.numpy(), x * 10, rtol=1e-05)
np.testing.assert_allclose(inputs[0].gradient(), x, rtol=1e-05)
np.testing.assert_allclose(ret2.numpy(), x * 10, rtol=1e-05)
a = inputs2[0].gradient()
np.testing.assert_allclose(inputs2[0].gradient(), x, rtol=1e-05)
def test_sum_op(self):
with _test_eager_guard():
self.func_sum_op()
self.func_sum_op()
def func_empty_var(self):
with fluid.dygraph.guard():
cur_program = fluid.Program()
cur_block = cur_program.current_block()
# Normally, we don't allow tensor with -1 shape being created in dygraph mode, this test is not good.
if _in_legacy_dygraph():
new_variable = cur_block.create_var(
name="X", shape=[-1, 23, 48], dtype='float32'
)
else:
new_variable = cur_block.create_var(
name="X", shape=[1, 23, 48], dtype='float32'
)
try:
new_variable.numpy()
except Exception as e:
assert type(e) == ValueError
try:
new_variable.backward()
except Exception as e:
assert type(e) == core.EnforceNotMet
try:
new_variable.clear_gradient()
except Exception as e:
assert type(e) == core.EnforceNotMet
def test_empty_var(self):
with _test_eager_guard():
self.func_empty_var()
self.func_empty_var()
def func_empty_grad(self):
with fluid.dygraph.guard():
x = np.ones([2, 2], np.float32)
new_var = paddle.to_tensor(x)
self.assertIsNone(new_var.gradient())
try:
new_var.clear_gradient()
except Exception as e:
assert type(e) == core.EnforceNotMet
with fluid.dygraph.guard():
cur_program = fluid.Program()
cur_block = cur_program.current_block()
# Normally, we don't allow tensor with -1 shape being created in dygraph mode, this test is not good.
if _in_legacy_dygraph():
new_variable = cur_block.create_var(
name="X", shape=[-1, 23, 48], dtype='float32'
)
else:
new_variable = cur_block.create_var(
name="X", shape=[1, 23, 48], dtype='float32'
)
try:
new_variable.gradient()
except Exception as e:
assert type(e) == ValueError
def test_empty_grad(self):
with _test_eager_guard():
self.func_empty_grad()
self.func_empty_grad()
def func_set_persistable(self):
with fluid.dygraph.guard():
x = np.ones([2, 2], np.float32)
new_var = paddle.to_tensor(x)
self.assertFalse(new_var.persistable)
new_var.persistable = True
self.assertTrue(new_var.persistable)
def test_set_persistable(self):
with _test_eager_guard():
self.func_set_persistable()
self.func_set_persistable()
def func_layer(self):
with fluid.dygraph.guard():
l = fluid.Layer("l")
self.assertRaises(NotImplementedError, l.forward, [])
def test_layer(self):
with _test_eager_guard():
self.func_layer()
self.func_layer()
def func_layer_in_out(self):
np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
with fluid.dygraph.guard():
var_inp = paddle.to_tensor(np_inp)
var_inp.stop_gradient = False
l = MyLayer()
x = l(var_inp)[0]
self.assertIsNotNone(x)
dy_out = x.numpy()
x.backward()
dy_grad = l._x_for_debug.gradient()
with fluid.dygraph.guard():
var_inp2 = paddle.to_tensor(np_inp)
var_inp2.stop_gradient = False
l2 = MyLayer()
x2 = l2(var_inp2)[0]
self.assertIsNotNone(x2)
dy_out2 = x2.numpy()
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
x2.backward()
dy_grad2 = l2._x_for_debug.gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[3], append_batch_size=False
)
l = MyLayer()
x = l(inp)[0]
param_grads = fluid.backward.append_backward(
x, parameter_list=[l._x_for_debug.name]
)[0]
exe = fluid.Executor(
fluid.CPUPlace()
if not core.is_compiled_with_cuda()
else fluid.CUDAPlace(0)
)
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
fetch_list=[x.name, param_grads[1].name],
)
np.testing.assert_array_equal(dy_out, static_out)
np.testing.assert_array_equal(dy_grad, static_grad)
np.testing.assert_array_equal(dy_out2, static_out)
np.testing.assert_array_equal(dy_grad2, static_grad)
def test_layer_in_out(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
with _test_eager_guard():
self.func_layer_in_out()
self.func_layer_in_out()
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})
def func_mlp(self):
np_inp = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
with fluid.dygraph.guard():
var_inp = paddle.to_tensor(np_inp)
mlp = MLP(input_size=2)
out = mlp(var_inp)
dy_out = out.numpy()
out.backward()
dy_grad = mlp._linear1.weight.gradient()
with fluid.dygraph.guard():
var_inp2 = paddle.to_tensor(np_inp)
mlp2 = MLP(input_size=2)
out2 = mlp2(var_inp2)
dy_out2 = out2.numpy()
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
out2.backward()
dy_grad2 = mlp2._linear1.weight.gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[2, 2], append_batch_size=False
)
mlp = MLP(input_size=2)
out = mlp(inp)
param_grads = fluid.backward.append_backward(
out, parameter_list=[mlp._linear1.weight.name]
)[0]
exe = fluid.Executor(
fluid.CPUPlace()
if not core.is_compiled_with_cuda()
else fluid.CUDAPlace(0)
)
exe.run(fluid.default_startup_program())
static_out, static_grad = exe.run(
feed={inp.name: np_inp},
fetch_list=[out.name, param_grads[1].name],
)
np.testing.assert_allclose(dy_out, static_out, rtol=1e-05)
np.testing.assert_allclose(dy_grad, static_grad, rtol=1e-05)
np.testing.assert_allclose(dy_out2, static_out, rtol=1e-05)
np.testing.assert_allclose(dy_grad2, static_grad, rtol=1e-05)
params = mlp.parameters(True)
self.assertEqual("linear_0.w_0", params[0].name)
self.assertEqual("linear_0.b_0", params[1].name)
self.assertEqual("linear_1.w_0", params[2].name)
self.assertEqual("linear_1.b_0", params[3].name)
self.assertEqual(len(params), 4)
sublayers = mlp.sublayers()
self.assertEqual(mlp._linear1, sublayers[0])
self.assertEqual(mlp._linear2, sublayers[1])
self.assertEqual(len(sublayers), 2)
def test_mlp(self):
with _test_eager_guard():
self.func_mlp()
self.func_mlp()
def test_gradient_accumulation(self):
def test_single_api(sort_sum_gradient):
fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient})
x = paddle.to_tensor(5.0, stop_gradient=False)
for i in range(10):
y = paddle.pow(x, 4.0)
y.backward()
self.assertEqual(x.grad.numpy(), (i + 1) * 500)
x.clear_gradient()
self.assertEqual(x.grad.numpy(), 0.0)
for i in range(10):
y = paddle.pow(x, 4.0)
y.backward()
self.assertEqual(x.grad.numpy(), (i + 1) * 500)
x.clear_grad()
self.assertEqual(x.grad.numpy(), 0.0)
def test_simple_net(sort_sum_gradient):
fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient})
x = paddle.to_tensor(5.0, stop_gradient=False)
y = paddle.to_tensor(2.0, stop_gradient=False)
z = paddle.to_tensor(3.0, stop_gradient=False)
def fun(x, y, z):
loss1 = x * x * y
loss2 = x * z
loss1.backward(retain_graph=True)
loss2.backward(retain_graph=True)
np.testing.assert_array_equal(x.grad.numpy(), [23.0])
np.testing.assert_array_equal(y.grad.numpy(), [25.0])
np.testing.assert_array_equal(z.grad.numpy(), [5.0])
x.clear_grad()
y.clear_grad()
z.clear_grad()
dx = paddle.grad([loss1], x, create_graph=True)[0]
loss = loss1 + loss2 + dx
# loss = x*x*y + x*z + 2*x*y
return loss
loss = fun(x, y, z)
loss.backward(retain_graph=True)
# x.grad = 2*x*y + z + 2*y = 27
np.testing.assert_array_equal(x.grad.numpy(), [27])
loss.backward(retain_graph=True)
np.testing.assert_array_equal(x.grad.numpy(), [54])
loss.backward()
np.testing.assert_array_equal(x.grad.numpy(), [81])
with self.assertRaises(RuntimeError):
loss.backward()
loss1 = x * x * y
loss2 = x * z
dx = paddle.grad([loss1], x, create_graph=True)[0]
loss = loss1 + loss2 + dx
loss.backward()
np.testing.assert_array_equal(dx.grad.numpy(), [1])
np.testing.assert_array_equal(x.grad.numpy(), [108])
def test_mlp(sort_sum_gradient):
fluid.set_flags({'FLAGS_sort_sum_gradient': sort_sum_gradient})
input_size = 5
paddle.seed(1)
mlp1 = MLP(input_size=input_size)
# generate the gradient of each step
mlp2 = MLP(input_size=input_size)
expected_weight1_grad = 0.0
expected_bias1_grad = 0.0
expected_weight2_grad = 0.0
expected_bias2_grad = 0.0
for batch_id in range(100):
x = paddle.uniform([10, input_size])
detach_x = x.detach()
clear_loss = mlp2(detach_x)
clear_loss.backward()
expected_weight1_grad = (
expected_weight1_grad + mlp2._linear1.weight.grad.numpy()
)
expected_bias1_grad = (
expected_bias1_grad + mlp2._linear1.bias.grad.numpy()
)
expected_weight2_grad = (
expected_weight2_grad + mlp2._linear2.weight.grad.numpy()
)
expected_bias2_grad = (
expected_bias2_grad + mlp2._linear2.bias.grad.numpy()
)
loss = mlp1(x)
loss.backward()
np.testing.assert_array_equal(loss.grad.numpy(), [1])
np.testing.assert_allclose(
mlp1._linear1.weight.grad.numpy(),
expected_weight1_grad,
rtol=1e-05,
)
np.testing.assert_allclose(
mlp1._linear1.bias.grad.numpy(),
expected_bias1_grad,
rtol=1e-05,
)
np.testing.assert_allclose(
mlp1._linear2.weight.grad.numpy(),
expected_weight2_grad,
rtol=1e-05,
)
np.testing.assert_allclose(
mlp1._linear2.bias.grad.numpy(),
expected_bias2_grad,
rtol=1e-05,
)
mlp2.clear_gradients()
np.testing.assert_array_equal(clear_loss.grad.numpy(), [1])
if ((batch_id + 1) % 10) % 2 == 0:
mlp1.clear_gradients()
expected_weight1_grad = 0.0
expected_bias1_grad = 0.0
expected_weight2_grad = 0.0
expected_bias2_grad = 0.0
elif ((batch_id + 1) % 10) % 2 == 1:
mlp1.clear_gradients()
mlp1._linear1.weight._set_grad_ivar(
paddle.ones([input_size, 3])
)
mlp1._linear2.weight._set_grad_ivar(paddle.ones([3, 4]))
expected_weight1_grad = 1.0
expected_bias1_grad = 0.0
expected_weight2_grad = 1.0
expected_bias2_grad = 0.0
with fluid.dygraph.guard():
test_single_api(False)
test_single_api(True)
test_simple_net(False)
test_simple_net(True)
test_mlp(False)
test_mlp(True)
def func_dygraph_vs_static(self):
np_inp1 = np.random.rand(4, 3, 3)
np_inp2 = np.random.rand(4, 3, 3)
# dynamic graph
with fluid.dygraph.guard():
inp1 = paddle.to_tensor(np_inp1)
inp2 = paddle.to_tensor(np_inp2)
if np.sum(np_inp1) < np.sum(np_inp2):
x = paddle.add(inp1, inp2)
else:
x = paddle.subtract(inp1, inp2)
dygraph_result = x.numpy()
# static graph
with new_program_scope():
inp_data1 = fluid.layers.data(
name='inp1', shape=[3, 3], dtype=np.float32
)
inp_data2 = fluid.layers.data(
name='inp2', shape=[3, 3], dtype=np.float32
)
a = paddle.expand(
paddle.reshape(paddle.sum(inp_data1), [1, 1]),
[4, -1],
)
b = paddle.expand(
paddle.reshape(paddle.sum(inp_data2), [1, 1]),
[4, -1],
)
cond = paddle.less_than(x=a, y=b)
ie = fluid.layers.IfElse(cond)
with ie.true_block():
d1 = ie.input(inp_data1)
d2 = ie.input(inp_data2)
d3 = paddle.add(d1, d2)
ie.output(d3)
with ie.false_block():
d1 = ie.input(inp_data1)
d2 = ie.input(inp_data2)
d3 = paddle.subtract(d1, d2)
ie.output(d3)
out = ie()
exe = fluid.Executor(
fluid.CPUPlace()
if not core.is_compiled_with_cuda()
else fluid.CUDAPlace(0)
)
static_result = exe.run(
fluid.default_main_program(),
feed={'inp1': np_inp1, 'inp2': np_inp2},
fetch_list=out,
)[0]
np.testing.assert_allclose(dygraph_result, static_result, rtol=1e-05)
def test_dygraph_vs_static(self):
with _test_eager_guard():
self.func_dygraph_vs_static()
self.func_dygraph_vs_static()
def func_rnn(self):
np_inp = np.array(
[
[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0],
[7.0, 8.0, 9.0],
[10.0, 11.0, 12.0],
]
)
np_inp = np_inp.reshape((1, 4, 3))
np_inp = np_inp.astype(np.float32)
with fluid.dygraph.guard():
var_inp = paddle.to_tensor(np_inp)
var_inp = paddle.reshape(var_inp, shape=[1, 4, 3])
simple_rnn = SimpleRNN()
outs, pre_hiddens = simple_rnn.forward(var_inp)
dy_out = outs[3].numpy()
outs[3].backward()
dy_grad_h2o = simple_rnn._cell._h2o_w.gradient()
dy_grad_h2h = simple_rnn._cell._h2h_w.gradient()
dy_grad_i2h = simple_rnn._cell._i2h_w.gradient()
with fluid.dygraph.guard():
var_inp2 = paddle.to_tensor(np_inp)
var_inp2 = paddle.reshape(var_inp2, shape=[1, 4, 3])
simple_rnn2 = SimpleRNN()
outs2, pre_hiddens2 = simple_rnn2.forward(var_inp2)
dy_out2 = outs2[3].numpy()
fluid.set_flags({'FLAGS_sort_sum_gradient': True})
outs2[3].backward()
dy_grad_h2o2 = simple_rnn2._cell._h2o_w.gradient()
dy_grad_h2h2 = simple_rnn2._cell._h2h_w.gradient()
dy_grad_i2h2 = simple_rnn2._cell._i2h_w.gradient()
with new_program_scope():
inp = fluid.layers.data(
name="inp", shape=[1, 4, 3], append_batch_size=False
)
simple_rnn = SimpleRNN()
outs, pre_hiddens = simple_rnn(inp)
param_grads = fluid.backward.append_backward(outs[3])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
(
static_out,
static_grad_h2o,
static_grad_h2h,
static_grad_i2h,
) = exe.run(
feed={inp.name: np_inp},
fetch_list=[
outs[3].name,
param_grads[0][1].name,
param_grads[1][1].name,
param_grads[2][1].name,
],
)
np.testing.assert_array_equal(dy_out, static_out)
np.testing.assert_array_equal(dy_grad_h2o, static_grad_h2o)
np.testing.assert_array_equal(dy_grad_h2h, static_grad_h2h)
np.testing.assert_array_equal(dy_grad_i2h, static_grad_i2h)
np.testing.assert_array_equal(dy_out2, static_out)
np.testing.assert_array_equal(dy_grad_h2o2, static_grad_h2o)
np.testing.assert_array_equal(dy_grad_h2h2, static_grad_h2h)
np.testing.assert_array_equal(dy_grad_i2h2, static_grad_i2h)
def test_rnn(self):
with _test_eager_guard():
self.func_rnn()
self.func_rnn()
def func_layer_attrs(self):
layer = fluid.dygraph.Layer("test")
layer.test_attr = 1
self.assertFalse(hasattr(layer, "whatever"))
self.assertTrue(hasattr(layer, "test_attr"))
self.assertEqual(layer.test_attr, 1)
my_layer = MyLayer()
my_layer.w1 = my_layer.create_parameter([3, 3])
my_layer.add_parameter('w2', None)
self.assertEqual(len(my_layer.parameters()), 1)
self.assertRaises(TypeError, my_layer.__setattr__, 'w1', 'str')
my_layer.w1 = None
self.assertEqual(len(my_layer.parameters()), 0)
my_layer.l1 = paddle.nn.Linear(3, 3)
self.assertEqual(len(my_layer.sublayers()), 1)
self.assertRaises(TypeError, my_layer.__setattr__, 'l1', 'str')
my_layer.l1 = None
self.assertEqual(len(my_layer.sublayers()), 0)
def test_layer_attrs(self):
with _test_eager_guard():
self.func_layer_attrs()
self.func_layer_attrs()
class TestDygraphUtils(unittest.TestCase):
def func_append_activation_in_dygraph_exception(self):
with new_program_scope():
np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32)
a = fluid.layers.data("a", [10, 20])
func = dygraph_utils._append_activation_in_dygraph
self.assertRaises(AssertionError, func, a, act="sigmoid")
def test_append_activation_in_dygraph_exception(self):
with _test_eager_guard():
self.func_append_activation_in_dygraph_exception()
self.func_append_activation_in_dygraph_exception()
def func_append_activation_in_dygraph1(self):
a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
func = dygraph_utils._append_activation_in_dygraph
with fluid.dygraph.guard():
a = paddle.to_tensor(a_np)
res1 = func(a, act="hard_sigmoid")
res2 = paddle.nn.functional.hardsigmoid(a, slope=0.2)
np.testing.assert_array_equal(res1.numpy(), res2.numpy())
def test_append_activation_in_dygraph1(self):
with _test_eager_guard():
self.func_append_activation_in_dygraph1()
self.func_append_activation_in_dygraph1()
def func_append_activation_in_dygraph2(self):
a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
func = dygraph_utils._append_activation_in_dygraph
with fluid.dygraph.guard():
a = paddle.to_tensor(a_np)
res1 = func(a, act="sigmoid", use_mkldnn=True, use_cudnn=True)
res2 = paddle.nn.functional.sigmoid(a)
np.testing.assert_allclose(res1.numpy(), res2.numpy(), rtol=1e-05)
def test_append_activation_in_dygraph2(self):
with _test_eager_guard():
self.func_append_activation_in_dygraph2()
self.func_append_activation_in_dygraph2()
def func_append_activation_in_dygraph3(self):
a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
helper = LayerObjectHelper(fluid.unique_name.generate("test"))
func = helper.append_activation
with fluid.dygraph.guard():
a = paddle.to_tensor(a_np)
res1 = func(a, act="sigmoid", use_cudnn=True)
res2 = paddle.nn.functional.sigmoid(a)
np.testing.assert_array_equal(res1.numpy(), res2.numpy())
def test_append_activation_in_dygraph3(self):
with _test_eager_guard():
self.func_append_activation_in_dygraph3()
self.func_append_activation_in_dygraph3()
def func_append_activation_in_dygraph_use_mkldnn(self):
a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32)
helper = LayerHelper(
fluid.unique_name.generate("test"), act="relu", use_mkldnn=True
)
func = helper.append_activation
with fluid.dygraph.guard():
a = paddle.to_tensor(a_np)
res1 = func(a)
res2 = fluid.layers.relu(a)
np.testing.assert_array_equal(res1.numpy(), res2.numpy())
def test_append_activation_in_dygraph_use_mkldnn(self):
with _test_eager_guard():
self.func_append_activation_in_dygraph_use_mkldnn()
self.func_append_activation_in_dygraph_use_mkldnn()
def func_append_activation_in_dygraph_global_use_mkldnn(self):
a_np = np.random.uniform(-2, 2, (10, 20, 30)).astype(np.float32)
helper = LayerHelper(fluid.unique_name.generate("test"), act="relu")
func = helper.append_activation
with fluid.dygraph.guard(fluid.core.CPUPlace()):
a = paddle.to_tensor(a_np)
fluid.set_flags({'FLAGS_use_mkldnn': True})
try:
res1 = func(a)
finally:
fluid.set_flags({'FLAGS_use_mkldnn': False})
res2 = fluid.layers.relu(a)
np.testing.assert_array_equal(res1.numpy(), res2.numpy())
def test_append_activation_in_dygraph_global_use_mkldnn(self):
with _test_eager_guard():
self.func_append_activation_in_dygraph_global_use_mkldnn()
self.func_append_activation_in_dygraph_global_use_mkldnn()
def func_append_bias_in_dygraph_exception(self):
with new_program_scope():
np_inp = np.random.random(size=(10, 20, 30)).astype(np.float32)
a = fluid.layers.data("a", [10, 20])
func = dygraph_utils._append_bias_in_dygraph
self.assertRaises(AssertionError, func, a)
def test_append_bias_in_dygraph_exception(self):
with _test_eager_guard():
self.func_append_bias_in_dygraph_exception()
self.func_append_bias_in_dygraph_exception()
def func_append_bias_in_dygraph(self):
a_np = np.random.random(size=(10, 20, 30)).astype(np.float32)
func = dygraph_utils._append_bias_in_dygraph
with fluid.dygraph.guard():
a = paddle.to_tensor(a_np)
res1 = func(a, bias=a)
res2 = paddle.add(a, a)
np.testing.assert_array_equal(res1.numpy(), res2.numpy())
def test_append_bias_in_dygraph(self):
with _test_eager_guard():
self.func_append_bias_in_dygraph()
self.func_append_bias_in_dygraph()
class TestDygraphGuardWithError(unittest.TestCase):
def func_without_guard(self):
with fluid.dygraph.guard():
x = paddle.to_tensor(np.zeros([10, 10]))
with self.assertRaisesRegexp(
TypeError, "Please use `with fluid.dygraph.guard()"
):
y = paddle.matmul(x, x)
def test_without_guard(self):
with _test_eager_guard():
self.func_without_guard()
self.func_without_guard()
class TestMetaclass(unittest.TestCase):
def func_metaclass(self):
self.assertEqual(type(MyLayer).__name__, 'type')
self.assertNotEqual(type(MyLayer).__name__, 'pybind11_type')
if not _in_legacy_dygraph():
self.assertEqual(
type(paddle.fluid.core.eager.Tensor).__name__, 'type'
)
else:
self.assertEqual(
type(paddle.fluid.core.VarBase).__name__, 'pybind11_type'
)
def test_metaclass(self):
with _test_eager_guard():
self.func_metaclass()
self.func_metaclass()
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -24,6 +24,8 @@ from paddle.fluid import core, unique_name
LOADED_VAR_SUFFIX = ".load_0"
paddle.enable_static()
def while_softmax_regression(img):
def cond(i, times, pred):
......@@ -37,7 +39,7 @@ def while_softmax_regression(img):
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
times = fluid.layers.fill_constant(shape=[1], dtype='int64', value=5)
pred = fluid.layers.fc(input=img, size=10, act='softmax')
i, times, pred = fluid.layers.while_loop(
i, times, pred = paddle.static.nn.while_loop(
cond=cond, body=body, loop_vars=[i, times, pred]
)
return pred
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# nlp model stack of op operate on lod. It's a classical test case in optimize pass.
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid import Program, compiler, program_guard
from paddle.fluid.executor import Executor
from paddle.fluid.optimizer import MomentumOptimizer
class TestIrMemoryOptimizeIfElseOp(unittest.TestCase):
def check_network_convergence(
self, use_cuda=True, use_mem_opt=False, iter_num=5
):
paddle.seed(100)
paddle.framework.random._manual_program_seed(100)
prog = Program()
startup_prog = Program()
with program_guard(prog, startup_prog):
image = layers.data(name='x', shape=[784], dtype='float32')
label = layers.data(name='y', shape=[1], dtype='int64')
limit = layers.fill_constant(shape=[1], dtype='int64', value=5)
cond = paddle.less_than(x=label, y=limit)
ie = layers.IfElse(cond)
with ie.true_block():
true_image = ie.input(image)
hidden = layers.fc(input=true_image, size=100, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
with ie.false_block():
false_image = ie.input(image)
hidden = layers.fc(input=false_image, size=200, act='tanh')
prob = layers.fc(input=hidden, size=10, act='softmax')
ie.output(prob)
prob = ie()
loss = layers.cross_entropy(input=prob[0], label=label)
avg_loss = paddle.mean(loss)
optimizer = MomentumOptimizer(learning_rate=0.001, momentum=0.9)
optimizer.minimize(avg_loss, startup_prog)
train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=200
)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = Executor(place)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy._use_device = (
core.DeviceType.CUDA if use_cuda else core.DeviceType.CPU
)
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = use_mem_opt
train_cp = compiler.CompiledProgram(fluid.default_main_program())
train_cp = train_cp.with_data_parallel(
loss_name=avg_loss.name,
exec_strategy=exec_strategy,
build_strategy=build_strategy,
)
fetch_list = [avg_loss.name]
exe.run(startup_prog)
PASS_NUM = 100
loop = 0
ret = []
for pass_id in range(PASS_NUM):
for data in train_reader():
x_data = np.array([x[0] for x in data]).astype("float32")
y_data = np.array([x[1] for x in data]).astype("int64")
y_data = y_data.reshape((y_data.shape[0], 1))
outs = exe.run(
train_cp,
feed={'x': x_data, 'y': y_data},
fetch_list=[avg_loss],
)
loop += 1
ret.append(outs[0])
if iter_num == loop:
return ret
return ret
def test_ifelse(self):
ret1 = self.check_network_convergence(False, True)
print(ret1)
ret2 = self.check_network_convergence(False, False)
print(ret2)
np.testing.assert_allclose(ret1, ret2, rtol=1e-05)
if fluid.core.is_compiled_with_cuda():
ret1 = self.check_network_convergence(True, True)
print(ret1)
ret2 = self.check_network_convergence(True, False)
print(ret2)
np.testing.assert_allclose(ret1, ret2, rtol=1e-05)
if __name__ == "__main__":
unittest.main()
......@@ -1387,7 +1387,7 @@ class TestLayer(LayerTest):
def body(i):
return i + 1
out = layers.while_loop(cond, body, [i])
out = paddle.static.nn.while_loop(cond, body, [i])
static_ret = self.get_static_graph_result(feed={}, fetch_list=out)
with self.dynamic_graph():
......@@ -1400,14 +1400,14 @@ class TestLayer(LayerTest):
def body1(i):
return i + 1
dy_ret = layers.while_loop(cond1, body1, [i])
dy_ret = paddle.static.nn.while_loop(cond1, body1, [i])
with self.assertRaises(ValueError):
j = layers.fill_constant(shape=[1], dtype='int64', value=0)
def body2(i):
return i + 1, i + 2
layers.while_loop(cond1, body2, [j])
paddle.static.nn.while_loop(cond1, body2, [j])
np.testing.assert_array_equal(static_ret[0], dy_ret[0].numpy())
......@@ -1659,10 +1659,12 @@ class TestLayer(LayerTest):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
out_2 = paddle.static.nn.case(
pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]
)
place = (
fluid.CUDAPlace(0)
......@@ -1682,10 +1684,10 @@ class TestLayer(LayerTest):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(
out_2 = paddle.static.nn.case(
pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]
)
eager_dynamic_res1 = out_1.numpy()
......@@ -1699,10 +1701,12 @@ class TestLayer(LayerTest):
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
out_1 = layers.case(
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3
)
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
out_2 = paddle.static.nn.case(
pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)]
)
dynamic_res1 = out_1.numpy()
dynamic_res2 = out_2.numpy()
......@@ -1725,17 +1729,17 @@ class TestLayer(LayerTest):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3,
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3,
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)],
)
......@@ -1759,17 +1763,17 @@ class TestLayer(LayerTest):
shape=[1], dtype='int32', value=2
)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3,
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3,
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)],
)
......@@ -1781,17 +1785,17 @@ class TestLayer(LayerTest):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3,
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3,
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)],
)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy
from paddle.fluid import Program, core, program_guard
from paddle.fluid.executor import Executor
from paddle.fluid.layers import data
from paddle.fluid.layers.control_flow import lod_rank_table
class TestLoDRankTable(unittest.TestCase):
def test_lod_rank_table(self):
x = data(name='x', shape=[100])
cpu = core.CPUPlace()
rank_table = lod_rank_table(x=x, level=1)
rank_table.persistable = True
exe = Executor(cpu)
scope = core.Scope()
tensor = core.LoDTensor()
tensor.set(numpy.random.random(size=(17, 100)), cpu)
tensor.set_recursive_sequence_lengths(
[[1, 2], [5, 1, 1], [3, 1, 5, 1, 3, 3, 1]]
)
exe.run(scope=scope, feed={'x': tensor})
var = scope.find_var(rank_table.name)
table = var.get_lod_rank_table()
self.assertEqual([(0, 5), (1, 1), (2, 1)], list(table.items()))
class TestLoDRankTableError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
x = numpy.random.random((2, 4)).astype("float32")
def test_Variable():
rank_table = lod_rank_table(x=x, level=1)
self.assertRaises(TypeError, test_Variable)
def test_list_Variable():
rank_table = lod_rank_table(x=[x], level=1)
self.assertRaises(TypeError, test_list_Variable)
x = data(name='x', shape=[10], dtype='float32', lod_level=1)
out = lod_rank_table(x=x, level=0)
out = lod_rank_table(x=[x], level=0)
if __name__ == '__main__':
unittest.main()
......@@ -101,7 +101,7 @@ def static(
mod_two = paddle.remainder(id, two) == 0
if loss_in_switch:
avg_loss = layers.case(
avg_loss = paddle.static.nn.case(
[(mod_two, lambda: fn_1(adam, None, prediction, label))],
lambda: fn_2(sgd, None, prediction, label),
)
......@@ -112,7 +112,7 @@ def static(
logits=prediction, label=label
)
avg_loss_2 = paddle.mean(loss_2)
avg_loss = layers.case(
avg_loss = paddle.static.nn.case(
[(mod_two, lambda: fn_1(adam, avg_loss_1))],
lambda: fn_2(sgd, avg_loss_2),
)
......@@ -264,7 +264,7 @@ class TestMultiOptimizersMultiCardsError(unittest.TestCase):
cond = layers.fill_constant([1], 'bool', True)
layers.case(
paddle.static.nn.case(
[(cond, lambda: fn_1(adam, avg_loss))],
lambda: fn_2(sgd, avg_loss),
)
......
......@@ -46,7 +46,7 @@ class TestProfiler(unittest.TestCase):
until = layers.fill_constant([1], dtype='int64', value=10)
data_arr = layers.array_write(hidden1, i)
cond = paddle.less_than(x=counter, y=until)
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
hidden_n = fluid.layers.fc(input=hidden1, size=64, act='relu')
layers.array_write(hidden_n, i, data_arr)
......
......@@ -100,7 +100,7 @@ def cond_net(use_feed=None):
two = fluid.layers.fill_constant([1], 'int32', 2)
pred = two == 0
avg_loss = fluid.layers.case(
avg_loss = paddle.static.nn.case(
[(pred, lambda: loss1(prediction, label))],
lambda: loss2(prediction, label),
)
......@@ -132,7 +132,7 @@ def optimization_in_cond_net(with_optimize=False):
sgd = fluid.optimizer.SGD(learning_rate=0.1)
two = fluid.layers.fill_constant([1], 'int32', 2)
pred = two == 0
avg_loss = fluid.layers.case(
avg_loss = paddle.static.nn.case(
[(pred, lambda: loss1(sgd, prediction, label, with_optimize))],
lambda: loss2(sgd, prediction, label, with_optimize),
)
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.layers.control_flow import lod_rank_table
def convert_to_offset(lod):
offset = [[0] for i in lod]
for i, level in enumerate(lod):
for seq_len in level:
offset[i].append(offset[i][-1] + seq_len)
return offset
class TestReorderLoDTensor(unittest.TestCase):
num_seq = 5
# [name, shape, lod_level] pair indicating data info of source and target
data_desc = (['input', [9], 0], ['ref', [5], 1])
@classmethod
def setUpClass(cls):
cls.set_program()
@classmethod
def set_program(cls):
dat = fluid.layers.data(
name=cls.data_desc[0][0], shape=cls.data_desc[0][1]
)
dat.stop_gradient = False
rank_dat = fluid.layers.data(
name=cls.data_desc[1][0], shape=cls.data_desc[1][1]
)
table = lod_rank_table(rank_dat)
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=dat, rank_table=table
)
loss = paddle.sum(new_dat)
fluid.backward.append_backward(loss=loss)
cls.fetch_list = [new_dat, cls.data_desc[0][0] + '@GRAD']
def run_program(self):
outputs = []
input_grads = []
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.set_inputs(place)
exe = fluid.Executor(place)
output, input_grad = exe.run(
fluid.default_main_program(),
feed=self.inputs,
fetch_list=self.fetch_list,
return_numpy=False,
)
outputs.append(output)
input_grads.append(input_grad)
self.actual_outputs = outputs
self.actual_grads = input_grads
def set_data(self):
self.data = {}
for desc in self.data_desc:
data_name = desc[0]
data_shape = desc[1]
data_lod_level = desc[2]
data_lod = []
for i in range(data_lod_level):
lod_level_i = np.random.randint(
low=1,
high=5,
size=self.num_seq
if i == 0
else sum(lod_level_i), # noqa: F821
).tolist()
data_lod.append(lod_level_i)
data_value = np.random.random(
size=[sum(data_lod[-1]) if data_lod else self.num_seq]
+ data_shape
).astype('float32')
self.data[data_name] = (data_value, data_lod)
def set_inputs(self, place):
self.inputs = {}
for desc in self.data_desc:
tensor = fluid.Tensor()
tensor.set(self.data[desc[0]][0], place)
if self.data[desc[0]][1]:
tensor.set_recursive_sequence_lengths(self.data[desc[0]][1])
self.inputs[desc[0]] = tensor
def reorder(self):
level = 0
# compute the rank_table according to ref_lod
ref_lod = self.data[self.data_desc[1][0]][1][level]
rank_table = [] # list of (index, length)
for i in range(len(ref_lod)):
rank_table.append((i, ref_lod[i]))
rank_table = sorted(
rank_table, key=functools.cmp_to_key(lambda x, y: y[1] - x[1])
)
# compute the input sequence info according to input_lod
input_value, input_lod = self.data[self.data_desc[0][0]]
offset_lod = convert_to_offset(input_lod)
input_table = [] # list of (offset, length, sub_lod)
if offset_lod:
for i in range(len(offset_lod[level]) - 1):
start_idx = i
end_idx = i + 1
sub_lod = []
for lod_level_i in offset_lod[level:]:
sub_lod_i = []
for idx in range(start_idx, end_idx):
sub_lod_i.append(
lod_level_i[idx + 1] - lod_level_i[idx]
)
sub_lod.append(sub_lod_i)
start_idx = lod_level_i[start_idx]
end_idx = lod_level_i[end_idx]
input_table.append((start_idx, end_idx - start_idx, sub_lod))
else:
input_table = [(i, 1, []) for i in range(len(rank_table))]
# reorder by rank_table
output_value = np.zeros_like(input_value)
output_lod = []
offset = 0
for index, length in rank_table:
input_seq_start = input_table[index][0]
input_seq_len = input_table[index][1]
input_seq_end = input_seq_start + input_seq_len
output_value[offset : offset + input_seq_len] = input_value[
input_seq_start:input_seq_end
]
offset += input_seq_len
input_seq_sub_lod = input_table[index][2]
if len(output_lod) == 0:
output_lod = [[] for i in input_seq_sub_lod]
for i, level in enumerate(input_seq_sub_lod):
output_lod[i].extend(level)
return output_value, output_lod
def test_reorder_lod_tensor(self):
self.data_desc[0][-1] = 2 # input is lod_tensor
self.set_data()
self.run_program()
# check output
expect_output, expect_output_lod = self.reorder()
for actual_output in self.actual_outputs:
np.testing.assert_allclose(
np.array(actual_output), expect_output, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_output_lod, actual_output.recursive_sequence_lengths()
)
# check gradient
expect_grad = np.ones_like(self.data[self.data_desc[0][0]][0])
expect_grad_lod = self.data[self.data_desc[0][0]][1]
for actual_grad in self.actual_grads:
np.testing.assert_allclose(
np.array(actual_grad), expect_grad, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_grad_lod, actual_grad.recursive_sequence_lengths()
)
def test_reorder_tensor(self):
self.data_desc[0][-1] = 0 # input is tensor
self.set_data()
self.run_program()
# check output
expect_output, expect_output_lod = self.reorder()
for actual_output in self.actual_outputs:
np.testing.assert_allclose(
np.array(actual_output), expect_output, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_output_lod, actual_output.recursive_sequence_lengths()
)
# check gradient
expect_grad = np.ones_like(self.data[self.data_desc[0][0]][0])
expect_grad_lod = self.data[self.data_desc[0][0]][1]
for actual_grad in self.actual_grads:
np.testing.assert_allclose(
np.array(actual_grad), expect_grad, rtol=1e-05, atol=0.001
)
self.assertEqual(
expect_grad_lod, actual_grad.recursive_sequence_lengths()
)
# compare outputs between LodTensors with explicit and implicit lod
# use the same data but set the input lod explicitly
input_lod = [[1] * len(self.data[self.data_desc[0][0]][0])]
self.inputs[self.data_desc[0][0]].set_recursive_sequence_lengths(
input_lod
)
# preserve the output of LodTensor with implicit lod to compare
expect_outputs = [
np.array(actual_output) for actual_output in self.actual_outputs
]
self.run_program()
for actual_output, expect_output in zip(
self.actual_outputs, expect_outputs
):
np.testing.assert_allclose(
np.array(actual_output), expect_output, rtol=1e-05, atol=0.001
)
class TestReorderLoDTensorError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
def test_Variable():
# The input must be Variable.
x1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
table1 = np.array([0.9383, 0.1983, 3.2, 1.2]).astype("float64")
new_dat = fluid.layers.reorder_lod_tensor_by_rank(
x=x1, rank_table=table1
)
self.assertRaises(TypeError, test_Variable)
def test_type():
x2 = fluid.layers.data(name='x1', shape=[4], dtype='float32')
table2 = fluid.layers.data(
name='table2', shape=[4], dtype='int32'
)
new_dat2 = fluid.layers.reorder_lod_tensor_by_rank(
x=x2, rank_table=table2
)
self.assertRaises(TypeError, test_type)
if __name__ == '__main__':
unittest.main()
......@@ -156,7 +156,7 @@ class TestSetValueItemSliceInWhile(TestSetValueApi):
return i, x
i = paddle.zeros(shape=(1,), dtype='int32')
i, x = paddle.fluid.layers.while_loop(cond, body, [i, x])
i, x = paddle.static.nn.while_loop(cond, body, [i, x])
def _get_answer(self):
self.data[0] = self.value
......
......@@ -17,11 +17,14 @@ from functools import partial
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
paddle.enable_static()
class TestAPISwitchCase(unittest.TestCase):
def test_return_single_var(self):
......@@ -42,29 +45,29 @@ class TestAPISwitchCase(unittest.TestCase):
index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
# call fn_1
out_0 = layers.switch_case(
out_0 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
# call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3)
)
# call default fn_3
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=((1, fn_1), (2, fn_2)),
default=fn_3,
)
# no default, call fn_2
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)]
)
# no default, call fn_2 but branch_index is 5
out_4 = layers.switch_case(
out_4 = paddle.static.nn.switch_case(
branch_index=index_5,
branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)],
)
......@@ -132,7 +135,9 @@ class TestAPISwitchCase(unittest.TestCase):
with program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3)
out = paddle.static.nn.switch_case(
index_1, ((1, fn_1), (2, fn_2)), fn_3
)
place = (
fluid.CUDAPlace(0)
......@@ -153,7 +158,7 @@ class TestAPISwitchCase(unittest.TestCase):
class TestAPISwitchCase_Nested(unittest.TestCase):
def test_nested_switch_case(self):
def fn_1(x=1):
out = layers.switch_case(
out = paddle.static.nn.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=x
),
......@@ -169,7 +174,7 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
return out
def fn_2(x=2):
out = layers.switch_case(
out = paddle.static.nn.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=2
),
......@@ -186,7 +191,7 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
return out
def fn_3():
out = layers.switch_case(
out = paddle.static.nn.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=3
),
......@@ -209,14 +214,14 @@ class TestAPISwitchCase_Nested(unittest.TestCase):
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)
out_1 = layers.switch_case(
out_1 = paddle.static.nn.switch_case(
branch_index=index_1, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
out_2 = layers.switch_case(
out_2 = paddle.static.nn.switch_case(
branch_index=index_2, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
out_3 = layers.switch_case(
out_3 = paddle.static.nn.switch_case(
branch_index=index_3, branch_fns={1: fn_1, 2: fn_2, 3: fn_3}
)
......@@ -277,7 +282,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The type of 'branch_index' in Op(switch_case) must be Variable
def type_error_branch_index():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=1, branch_fns=[(1, fn_1)], default=fn_3
)
......@@ -285,7 +290,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The data type of 'branch_index' in Op(switch_case) must be int32, int64 or uint8
def dtype_error_branch_index():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_float32,
branch_fns=[(1, fn_1)],
default=fn_3,
......@@ -295,7 +300,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The type of 'branch_fns' in Op(switch_case) must be list, tuple or dict
def type_error_branch_fns():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=1, default=fn_3
)
......@@ -303,7 +308,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The elements' type of 'branch_fns' in Op(switch_case) must be tuple
def type_error_index_fn_pair_1():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=[1], default=fn_3
)
......@@ -311,7 +316,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The tuple's size of 'branch_fns' in Op(switch_case) must be 2
def type_error_index_fn_pair_2():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=[(1, 2, 3)], default=fn_3
)
......@@ -319,7 +324,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The key's type of 'branch_fns' in Op(switch_case) must be int
def type_error_key():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32, branch_fns=[(2.3, 2)], default=fn_3
)
......@@ -327,7 +332,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The key in 'branch_fns' must be unique
def value_error_key():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32,
branch_fns=[(2, fn_1), (2, fn_2)],
default=fn_3,
......@@ -337,7 +342,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The type of function in 'branch_fns' must be callable
def type_error_fn():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32,
branch_fns=[(1, 1), (2, fn_2)],
default=fn_3,
......@@ -347,7 +352,7 @@ class TestAPISwitchCase_Error(unittest.TestCase):
# The default in Op(case) must be callable
def type_error_default():
layers.switch_case(
paddle.static.nn.switch_case(
branch_index=key_int32,
branch_fns=[(1, fn_1), (2, fn_2)],
default=1,
......
......@@ -21,6 +21,8 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
paddle.enable_static()
class TestTensorArrayToTensorError(unittest.TestCase):
"""Tensor_array_to_tensor error message enhance"""
......@@ -288,7 +290,9 @@ class TestTensorArrayToTensorAPI(unittest.TestCase):
fluid.layers.array_write(prev, i, array)
return i + 1, end, array
_, _, array = fluid.layers.while_loop(cond, body, [i, ten, array])
_, _, array = paddle.static.nn.while_loop(
cond, body, [i, ten, array]
)
self.assertTrue(paddle.tensor.array_length(array), 10)
last = fluid.layers.fill_constant(shape=[1], dtype='int64', value=9)
......
......@@ -40,7 +40,7 @@ class TestApiWhileLoop(unittest.TestCase):
i = layers.fill_constant(shape=[1], dtype='int64', value=0)
one = layers.fill_constant(shape=[1], dtype='int64', value=1)
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
out = layers.while_loop(cond, body, (i,))
out = paddle.static.nn.while_loop(cond, body, (i,))
place = (
fluid.CUDAPlace(0)
......@@ -69,7 +69,7 @@ class TestApiWhileLoop(unittest.TestCase):
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
mem = fluid.data(name='mem', shape=[10], dtype='float32')
one = layers.fill_constant(shape=[10], dtype='float32', value=1)
out = layers.while_loop(cond, body, [i, mem])
out = paddle.static.nn.while_loop(cond, body, [i, mem])
data = np.random.rand(10).astype('float32')
data_one = np.ones(10).astype('float32')
......@@ -122,7 +122,13 @@ class TestApiWhileLoop(unittest.TestCase):
}
]
i, ten, test_dict, test_list, test_list_dict = layers.while_loop(
(
i,
ten,
test_dict,
test_list,
test_list_dict,
) = paddle.static.nn.while_loop(
cond, body, [i, ten, test_dict, test_list, test_list_dict]
)
place = (
......@@ -171,7 +177,7 @@ class TestApiWhileLoop_Nested(unittest.TestCase):
j = layers.increment(j)
return [j, init, sums]
result = layers.while_loop(
result = paddle.static.nn.while_loop(
internal_cond, internal_body, [j, init, sums]
)
j = result[0]
......@@ -192,7 +198,7 @@ class TestApiWhileLoop_Nested(unittest.TestCase):
loop_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
ones = layers.fill_constant(shape=[3, 3], dtype='float32', value=1)
out = layers.while_loop(
out = paddle.static.nn.while_loop(
external_cond, external_body, [i, j, init, sums]
)
......@@ -236,7 +242,7 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
x = fluid.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
out = layers.while_loop(cond, body, [i, x])
out = paddle.static.nn.while_loop(cond, body, [i, x])
mean = paddle.mean(out[1])
append_backward(mean)
......@@ -277,7 +283,7 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
x = fluid.data(name='x', shape=[1], dtype='float32')
x.stop_gradient = False
out = layers.while_loop(cond, body, [i, x])
out = paddle.static.nn.while_loop(cond, body, [i, x])
mean = paddle.mean(out[1])
append_backward(mean)
......@@ -328,7 +334,7 @@ class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
outer_sum_1 = paddle.add(x=x, y=outer_sum_0)
i = layers.increment(x=i, in_place=True)
layers.array_write(outer_sum_1, i=i, array=mem_array)
j, x, mem_array = layers.while_loop(
j, x, mem_array = paddle.static.nn.while_loop(
internal_cond, internal_body, [j, x, mem_array]
)
return [i, j, x, mem_array]
......@@ -357,7 +363,7 @@ class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
out = layers.while_loop(
out = paddle.static.nn.while_loop(
external_cond, external_body, [i, j, x, mem_array]
)
......@@ -405,7 +411,7 @@ class TestApiWhileLoopWithSwitchCase(unittest.TestCase):
data_add_one = paddle.add(x=i, y=one)
return data_add_one
return layers.switch_case(
return paddle.static.nn.switch_case(
branch_index=i,
branch_fns={2: fn_add_three, 5: fn_square},
default=fn_add_one,
......@@ -418,7 +424,7 @@ class TestApiWhileLoopWithSwitchCase(unittest.TestCase):
ten = layers.fill_constant(shape=[1], dtype='int64', value=10)
three = layers.fill_constant(shape=[1], dtype='int64', value=3)
one = layers.fill_constant(shape=[1], dtype='int64', value=1)
out = layers.while_loop(cond, body, [i])
out = paddle.static.nn.while_loop(cond, body, [i])
place = (
fluid.CUDAPlace(0)
......@@ -488,13 +494,13 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The type of `cond` in Op(while_loop) must be callable
def type_error_cond():
out = layers.while_loop(data, body, [data_1d])
out = paddle.static.nn.while_loop(data, body, [data_1d])
self.assertRaises(TypeError, type_error_cond)
# The type of `body` in Op(while_loop) must be callable
def type_error_body():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, data, [data_1d]
)
......@@ -502,25 +508,31 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The type of `loop_vars` in Op(while_loop) must be list or tuple
def type_error_loop_vars():
out = layers.while_loop(cond_returns_bool_tensor, body, data_1d)
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, body, data_1d
)
self.assertRaises(TypeError, type_error_loop_vars)
# The value of `loop_vars` is empty
def value_error_loop_vars():
out = layers.while_loop(cond_returns_bool_tensor, body, [])
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, body, []
)
self.assertRaises(ValueError, value_error_loop_vars)
# The type of `cond` returns in Op(while_loop) must be Variable
def type_error_cond_returns_not_variable():
out = layers.while_loop(cond_returns_constant, body, [data_1d])
out = paddle.static.nn.while_loop(
cond_returns_constant, body, [data_1d]
)
self.assertRaises(TypeError, type_error_cond_returns_not_variable)
# The type of `cond` returns in Op(while_loop) must be a bollean variable
def type_error_cond_returns_not_boolean():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_not_bool_tensor, body, [data_1d]
)
......@@ -528,13 +540,15 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The shape of `cond` returns in Op(while_loop) must be 1
def type_error_shape_cond_returns_2d():
out = layers.while_loop(cond_returns_2d_tensor, body, [data_2d])
out = paddle.static.nn.while_loop(
cond_returns_2d_tensor, body, [data_2d]
)
self.assertRaises(TypeError, type_error_shape_cond_returns_2d)
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_length():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_bool_tensor, body_returns_error_length, [data]
)
......@@ -542,7 +556,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
def value_error_body_returns_error_type():
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_receives_two_args, body_returns_error_type, [data, ten]
)
......@@ -555,7 +569,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
shape=[2, 2], dtype='int64', value=1
)
}
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_with_mutable_dict,
body_returns_with_mutable_dict,
[data, test_dict],
......@@ -569,7 +583,7 @@ class TestApiWhileLoop_Error(unittest.TestCase):
test_list = [
layers.fill_constant(shape=[2, 2], dtype='int64', value=1)
]
out = layers.while_loop(
out = paddle.static.nn.while_loop(
cond_returns_with_mutable_list,
body_returns_with_mutable_list,
[data, test_list],
......@@ -597,7 +611,7 @@ class TestApiWhileLoopSliceInBody(unittest.TestCase):
z = fluid.layers.fill_constant([1], 'int32', 0)
x_shape = paddle.shape(x)
i = fluid.layers.fill_constant([1], 'int32', 0)
z, _ = fluid.layers.while_loop(cond, body, [z, i])
z, _ = paddle.static.nn.while_loop(cond, body, [z, i])
place = (
fluid.CUDAPlace(0)
......
......@@ -56,8 +56,8 @@ class TestWhileOp(unittest.TestCase):
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = paddle.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......@@ -122,10 +122,10 @@ class TestWhileOp(unittest.TestCase):
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = paddle.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
class BadInputTest(unittest.TestCase):
......@@ -157,7 +157,7 @@ class TestIgnoreVarNameInWhile(unittest.TestCase):
i = layers.fill_constant(shape=[1], value=0, dtype='int32')
num = layers.fill_constant(shape=[1], value=5, dtype='int32')
i, ten, shuffle_temp, y = layers.while_loop(
i, ten, shuffle_temp, y = paddle.static.nn.while_loop(
cond, body_func, [i, num, temp, y]
)
......
......@@ -159,7 +159,7 @@ class TestDeviceGuard(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with paddle.static.device_guard("cpu"):
while_op = fluid.layers.While(cond=cond)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
i = paddle.increment(x=i, value=1)
paddle.assign(paddle.less_than(x=i, y=loop_len), cond)
......
......@@ -55,8 +55,8 @@ class TestWhileOp(unittest.TestCase):
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = paddle.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
while_op = paddle.static.nn.control_flow.While(cond=cond)
while_op2 = paddle.static.nn.control_flow.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
......@@ -121,10 +121,10 @@ class TestWhileOp(unittest.TestCase):
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = paddle.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
paddle.static.nn.control_flow.While(cond=cond)
if __name__ == '__main__':
......
......@@ -21,10 +21,14 @@ from .common import deform_conv2d # noqa: F401
from .common import conv3d # noqa: F401
from .common import conv2d_transpose # noqa: F401
from .common import conv3d_transpose # noqa: F401
from .control_flow import (
case,
while_loop,
switch_case,
)
from .common import bilinear_tensor_product # noqa: F401
from .common import py_func # noqa: F401
from ...tensor.creation import create_parameter # noqa: F401
from ...fluid.layers import case # noqa: F401
from ...fluid.layers import cond # noqa: F401
from ...fluid.layers import conv2d # noqa: F401
from ...fluid.layers import crf_decoding # noqa: F401
......@@ -34,8 +38,6 @@ from .loss import nce # noqa: F401
from .common import prelu # noqa: F401
from ...fluid.layers import row_conv # noqa: F401
from ...fluid.layers import spectral_norm # noqa: F401
from ...fluid.layers import switch_case # noqa: F401
from ...fluid.layers import while_loop # noqa: F401
from ...fluid.input import embedding # noqa: F401
from ...fluid.contrib.layers import sparse_embedding # noqa: F401
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from functools import partial, reduce
import paddle
import paddle.fluid.core as core
from paddle.common_ops_import import (
LayerHelper,
_non_static_mode,
check_type,
check_variable_and_dtype,
convert_dtype,
)
from paddle.fluid.framework import Operator, Program, Variable
# Temporary solution, it will be deleted later
from paddle.fluid.layers.control_flow import cond
from paddle.fluid.layers.utils import (
assert_same_structure,
copy_mutable_vars,
hold_mutable_vars,
is_sequence,
map_structure,
)
class BlockGuard:
"""
BlockGuard class.
BlockGuard class is used to create a sub-block in a program by
using the Python `with` keyword.
"""
def __init__(self, main_program):
if not isinstance(main_program, Program):
raise TypeError("BlockGuard takes a program")
self.main_program = main_program
def __enter__(self):
self.main_program._create_block()
def __exit__(self, exc_type, exc_val, exc_tb):
self.main_program._rollback()
if exc_type is not None:
return False # re-raise exception
return True
class WhileGuard(BlockGuard):
def __init__(self, while_op):
if not isinstance(while_op, While):
raise TypeError("WhileGuard takes a while op")
super().__init__(while_op.helper.main_program)
self.while_op = while_op
def __enter__(self):
self.while_op.status = While.IN_WHILE_BLOCK
return super().__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
self.while_op.status = While.AFTER_WHILE_BLOCK
self.while_op._complete()
return super().__exit__(exc_type, exc_val, exc_tb)
def get_inputs_outputs_in_block(
current_block, inner_inputs, inner_outputs, helper
):
"""
Find inputs and outputs in current control flow block.
:param current_block: Current control flow block.
:param inner_inputs: Input var name of ops in current block.
:param inner_outputs: Output var name of ops in current block.
:return: inner_inputs, inner_outputs
"""
def is_ignore_vars(op, var_name):
# NOTE(dev): There are some persistable var created in some non-standard API
# such as "contrib.layers.shuffle_batch". It create a "Seed" used both in
# Input and Output. This var shall not be considered as a loop_var in
# control_flow.
IGNORE_VAR_NAMES = {"shuffle_batch": ["shuffle_batch_seed"]}
if op.type in IGNORE_VAR_NAMES:
var_names = IGNORE_VAR_NAMES[op.type]
for name in var_names:
if name in var_name:
return True
return False
# Step1: update inner_inputs and inner_outputs
# NOTE: Here assumes that all variables are input or output of Ops,
# but some variables are created without appendding a real op.
# For example, in `arr = create_array(dtype)`, `arr` is not a output of a op.
for op in current_block.ops:
assert isinstance(op, Operator)
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in inner_outputs and not is_ignore_vars(
op, in_var_name
):
inner_inputs.add(in_var_name)
for oname in op.output_names:
for out_var_name in op.output(oname):
inner_outputs.add(out_var_name)
# Step2: Remove LOD_TENSOR_ARRAY created in current control flow block.
remove_inner_inputs = set()
parent_block = helper.main_program.block(current_block.parent_idx)
for in_var_name in inner_inputs:
parent_block_var = parent_block._find_var_recursive(in_var_name)
current_block_var = None
if current_block.has_var(in_var_name):
current_block_var = current_block.var(in_var_name)
if (
not parent_block_var
and current_block_var
and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
remove_inner_inputs.add(in_var_name)
inner_inputs = inner_inputs - remove_inner_inputs
return inner_inputs, inner_outputs
class While:
"""
:api_attr: Static Graph
while loop control flow. Repeat while body until cond is False.
Note:
A new OP :ref:`api_fluid_layers_while_loop` is highly recommended instead of ``While`` if the shape of parameter ``cond`` is [1].
OP :ref:`api_fluid_layers_while_loop` is easier to use and is called with less code but does the same thing as ``While`` .
Notice:
Local variables created in ``While`` are similar to that created in while of C++, and cannot be referenced externally.
As a result, they cannot be obtained through ``fetch_list`` of ``Executor``. If you would like to access the variable
out of ``while`` , PaddlePaddle provides ``assign`` API to assign local variables to external. Please refer to example
code 2 or refer to `issue#22724 <https://github.com/PaddlePaddle/Paddle/issues/22724>`_.
Args:
cond(Variable): A Tensor whose data type is bool controlling whether to continue looping.
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` .
Examples 1:
.. code-block:: python
import paddle
import numpy as np
paddle.enable_static()
i = paddle.full(shape=[1], dtype='int64', fill_value=0) # loop counter
loop_len = paddle.full(shape=[1],dtype='int64', fill_value=10) # loop length
cond = paddle.less_than(x=i, y=loop_len)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
i = paddle.increment(x=i, value=1)
paddle.assign(paddle.less_than(x=i, y=loop_len), output=cond)
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
res = exe.run(paddle.static.default_main_program(), feed={}, fetch_list=[i])
print(res) # [array([10])]
Examples 2:
.. code-block:: python
import paddle
import numpy as np
paddle.enable_static()
i = paddle.full(shape=[1], dtype='int64', fill_value=0)
loop_len = paddle.full(shape=[1], dtype='int64', fill_value=10)
one = paddle.full(shape=[1], dtype='float32', fill_value=1)
data = paddle.static.data(name='data', shape=[1], dtype='float32')
sums = paddle.full(shape=[1], dtype='float32', fill_value=0) # Define the variable to be obtained ouside of While, which name should be different from the variable inside the While to be obtained
cond = paddle.less_than(x=i, y=loop_len)
while_op = paddle.static.nn.control_flow.While(cond=cond)
with while_op.block():
sums_tensor = paddle.add(x=data, y=data)
paddle.assign(sums_tensor, sums) # Update the value of sums_tensor defined in While to the sums which defined outside of While through layers.assign
i = paddle.increment(x=i, value=1)
data = paddle.add(x=data, y=one)
paddle.assign(paddle.less_than(x=i, y=loop_len), output=cond)
feed_data = np.ones(1).astype('float32')
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())
res = exe.run(paddle.static.default_main_program(), feed={'data': feed_data}, fetch_list=sums)
print(res[0]) # [2.] # Because the data in While does not update the value outside the While, the value of sums is [2.] after the loop
"""
BEFORE_WHILE_BLOCK = 0
IN_WHILE_BLOCK = 1
AFTER_WHILE_BLOCK = 2
def __init__(self, cond, is_test=False, name=None):
self.helper = LayerHelper("while", name=name)
self.status = While.BEFORE_WHILE_BLOCK
check_variable_and_dtype(cond, 'cond', ['bool'], 'static.nn.While')
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError(
"condition expected shape as [1], but given shape as {0}.".format(
list(cond.shape)
)
)
self.cond_var = cond
self.is_test = is_test
def block(self):
return WhileGuard(self)
def _complete(self):
main_program = self.helper.main_program
while_block = main_program.current_block()
parent_block = main_program.block(
main_program.current_block().parent_idx
)
inner_outputs = {self.cond_var.name}
x_name_list = set()
x_name_list, inner_outputs = get_inputs_outputs_in_block(
while_block, x_name_list, inner_outputs, self.helper
)
out_vars = []
for inner_out_name in inner_outputs:
inner_var = parent_block._find_var_recursive(inner_out_name)
if inner_var:
out_vars.append(inner_var)
x_name_list |= set(map(lambda x: x.name, out_vars))
# NOTE(dev): cond_var has been contained in Input('Condition'), so
# we remove it from Input('X')
x_name_list -= {self.cond_var.name}
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES
)
parent_block.append_op(
type='while',
inputs={
'X': [
parent_block._var_recursive(x_name)
for x_name in x_name_list
],
'Condition': [self.cond_var],
},
outputs={'Out': out_vars, 'StepScopes': [step_scope]},
attrs={'sub_block': while_block, "is_test": self.is_test},
)
support_ret_buildin_type = (bool, float, int)
def assign_skip_lod_tensor_array(input, output):
"""
Assign input to output, but skip the process of copying LoDTensorArray unless it's created in while_block.
"""
def has_shape_diff(x_var, y_var):
if len(x_var.shape) != len(y_var.shape):
return True
for x_dim, y_dim in zip(x_var.shape, y_var.shape):
if x_dim != y_dim and -1 not in [x_dim, y_dim]:
return True
return False
if not isinstance(input, (Variable, core.VarBase)):
if isinstance(output, Variable) and isinstance(
input, support_ret_buildin_type
):
paddle.assign(input, output)
else:
output = input
return
if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
main_program = input.block.program
parent_block = main_program.block(
main_program.current_block().parent_idx
)
if parent_block and not parent_block._find_var_recursive(input.name):
paddle.assign(input, output)
else:
if (
isinstance(output, Variable)
and isinstance(input, Variable)
and has_shape_diff(input, output)
):
warnings.warn(
"In dy2static mode, we attemp to assign a variable with shape {} into a variable with shape{}, which is not always right.".format(
input.shape, output.shape
)
)
paddle.assign(input, output)
def while_loop(cond, body, loop_vars, is_test=False, name=None):
"""
:api_attr: Static Graph
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
Notice:
Local variables defined in ``body`` cannot be obtained through ``fetch_list`` of ``Executor`` , variables should
be defined outside ``body`` and placed in ``loop_vars`` for looping, then these variables can be fetched by ``fetch_list`` .
Args:
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
as many arguments as ``loop_vars`` .
body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
(length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
name(str, optional): Normally there is no need for users to set this property. For more information, please
refer to :ref:`api_guide_Name`. Default is None.
Returns:
A list or tuple of Tensors or LoDTensorArrays which returned by ``body`` .
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
def cond(i, ten):
return i < ten
def body(i, ten):
i = i + 1
return [i, ten]
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
with paddle.static.program_guard(main_program, startup_program):
i = paddle.full(shape=[1], fill_value=0, dtype='int64') # loop counter
ten = paddle.full(shape=[1], fill_value=10, dtype='int64') # loop length
i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
exe = paddle.static.Executor(paddle.CPUPlace())
res = exe.run(main_program, feed={}, fetch_list=[i])
print(res) # [array([10])]
"""
helper = LayerHelper('while_loop', **locals())
if not callable(cond):
raise TypeError("cond in while_loop should be callable")
if not callable(body):
raise TypeError("body in while_loop should be callable")
check_type(loop_vars, 'loop_vars', (list, tuple), 'static.nn.while_loop')
if len(loop_vars) == 0:
raise ValueError("loop_vars in while_loop should not be empty")
pre_cond = cond(*loop_vars)
check_variable_and_dtype(
pre_cond, 'var of cond returned', ['bool'], 'static.nn.while_loop'
)
if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
raise TypeError(
"the shape of the variable returned by cond should be [1],"
"but given shape as {0}.".format(list(pre_cond.shape))
)
if _non_static_mode():
now_cond = pre_cond.numpy()[0]
while now_cond:
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
if len(output_vars) != len(loop_vars):
raise ValueError(
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars"
)
now_cond = cond(*output_vars).numpy()[0]
map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
return loop_vars
while_loop_block = While(pre_cond, is_test, name)
has_mutable_vars_in_loop = hold_mutable_vars(loop_vars)
with while_loop_block.block():
# If a variable with mutable type is included in loop_vars, like `dict/list`,
# modifying it in the body function will cause origin variable to be modified
# synchronously. This will raise an assignment error out of while block.
# Here we make a copy of the mutable vars to avoid this problem.
if has_mutable_vars_in_loop:
new_loop_vars = copy_mutable_vars(loop_vars)
output_vars = body(*new_loop_vars)
else:
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
try:
loop_vars = _deal_with_undefined_var(output_vars, loop_vars)
assert_same_structure(output_vars, loop_vars, check_types=False)
except ValueError as e:
raise ValueError(
"body in while_loop should return the same arity "
"(length and structure) as loop_vars: {0}".format(e)
)
now_cond = cond(*output_vars)
map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
paddle.assign(now_cond, pre_cond)
return loop_vars
def _deal_with_undefined_var(output_vars, loop_vars):
"""Deal with undefined var cases, We create undefined variable based on the results of body().
In Dy2Static, we use undefined var to represent the var created in control flow. This function
expand the loop_vars and replace original loop_vars.
1. UndefinedVar = Variable # create a variable
2. UndefinedVar = None # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
3. UndefinedVar = List(int) # create a list of variable
4. UndefinedVar = value # create a variable
"""
from paddle.jit.dy2static.utils import (
UndefinedVar,
create_undefined_variable,
)
def create_var_like(o_var):
if (
isinstance(o_var, (Variable,) + support_ret_buildin_type)
or o_var is None
):
return create_undefined_variable()
if is_sequence(o_var):
"""
Create a complex container class inside the body of while, including Python list and python Dict
"""
return map_structure(lambda x: create_undefined_variable(), o_var)
if len(output_vars) != len(loop_vars):
raise ValueError("The length of loop_vars should be the same.")
results = []
for o_var, l_var in zip(output_vars, loop_vars):
if isinstance(l_var, UndefinedVar) or l_var is None:
results.append(create_var_like(o_var))
else:
results.append(l_var)
return results
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = (
"{what} of '{arg_name}' in {op_name} must be "
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value,
)
)
return error_message
def case(pred_fn_pairs, default=None, name=None):
'''
:api_attr: Static Graph
This operator works like an if-elif-elif-else chain.
Args:
pred_fn_pairs(list|tuple): A list or tuple of (pred, fn) pairs. ``pred`` is a boolean Tensor with shape [1], ``fn`` is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor|list(Tensor): Tensors returned by the callable from the first pair whose pred is True,
or Tensors returned by ``default`` if no pred in ``pred_fn_pairs`` is True and ``default`` is not None,
or Tensors returned by the last callable in ``pred_fn_pairs`` if no pred in ``pred_fn_pairs`` is True and ``default`` is None.
Raises:
TypeError: If the type of ``pred_fn_pairs`` is not list or tuple.
TypeError: If the type of elements in ``pred_fn_pairs`` is not tuple.
TypeError: If the size of tuples in ``pred_fn_pairs`` is not 2.
TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not a Tensor.
TypeError: If the second element of 2-tuple in ``pred_fn_pairs`` is not callable.
TypeError: If ``default`` is not None but it is not callable.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
def fn_1():
return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)
def fn_2():
return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[3], dtype='int32', fill_value=3)
main_program = paddle.static.default_startup_program()
startup_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[1], dtype='float32', fill_value=0.3)
y = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
z = paddle.full(shape=[1], dtype='float32', fill_value=0.2)
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
# Call fn_1 because pred_1 is True
out_1 = paddle.static.nn.case(
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
# because fn_3 is the last callable in pred_fn_pairs.
out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
print(res_1) # [[1. 1.]]
print(res_2) # [3 3 3]
'''
helper = LayerHelper('case', **locals())
def _case_check_args(pred_fn_pairs, default):
'''
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
'''
check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case')
for pred_fn in pred_fn_pairs:
if not isinstance(pred_fn, tuple):
raise TypeError(
_error_message(
"The elements' type",
"pred_fn_pairs",
"case",
tuple,
type(pred_fn),
)
)
if len(pred_fn) != 2:
raise TypeError(
_error_message(
"The tuple's size",
"pred_fn_pairs",
"case",
"2",
str(len(pred_fn)) + "-tuple",
)
)
pred, fn = pred_fn
if not isinstance(pred, Variable):
raise TypeError(
_error_message(
"The pred's type",
"pred_fn_pairs",
"case",
"boolean Variable",
type(pred),
)
)
if not callable(fn):
raise TypeError(
"The fn for {} of pred_fn_pairs in Op(case) must"
" be callable.".format(pred.name)
)
if default is None:
default_index = len(pred_fn_pairs) - 1 # pick the last one
default = pred_fn_pairs[default_index][1]
pred_fn_pairs = pred_fn_pairs[:default_index]
elif not callable(default):
raise TypeError("The default in Op(case) must be callable.")
return pred_fn_pairs, default
pred_fn_pairs, default = _case_check_args(pred_fn_pairs, default)
false_fn = default
for pred, true_fn in reversed(pred_fn_pairs):
false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)
final_fn = false_fn
return final_fn()
def switch_case(branch_index, branch_fns, default=None, name=None):
'''
:api_attr: Static Graph
This operator is like a C++ switch/case statement.
Args:
branch_index(Tensor): A Tensor with shape [1] to specify which branch to execute. The data type is ``int32``, ``int64`` or ``uint8``.
branch_fns(dict|list|tuple): If it's a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it's a dict, its key is a python integer and the value is a callable. All callables return the same structure of Tensors.
default(callable, optional): Callable that returns a structure of Tensors.
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor|list(Tensor): Tensors returned by the callable specified by ``branch_index`` in ``branch_fns``,
or Tensors returned by ``default`` if ``default`` is not None and no index matches in ``branch_fns``,
or Tensors returned by the callable with the max index in ``branch_fns`` if ``default`` is None and no index matches in ``branch_fns``.
Raises:
TypeError: If the type of ``branch_index`` is not Tensor.
TypeError: If the data type of ``branch_index`` is not ``int32``, ``int64`` or ``uint8``.
TypeError: If the type of ``branch_fns`` is not dict, list or tuple.
TypeError: If the elements of ``branch_fns`` is not 2-tuple.
TypeError: If the first element of 2-tuple in ``branch_fns`` is not integer.
ValueError: If the first element of 2-tuple in ``branch_fns`` is not unique.
TypeError: If the second element of 2-tuple in ``branch_fns`` is not callable.
TypeError: If ``default`` is not None but it is not callable.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
def fn_1():
return paddle.full(shape=[1, 2], dtype='float32', fill_value=1)
def fn_2():
return paddle.full(shape=[2, 2], dtype='int32', fill_value=2)
def fn_3():
return paddle.full(shape=[3], dtype='int32', fill_value=3)
main_program = paddle.static.default_startup_program()
startup_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
index_1 = paddle.full(shape=[1], dtype='int32', fill_value=1)
index_2 = paddle.full(shape=[1], dtype='int32', fill_value=2)
out_1 = paddle.static.nn.switch_case(
branch_index=index_1,
branch_fns={1: fn_1, 2: fn_2},
default=fn_3)
out_2 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(1, fn_1), (2, fn_2)],
default=fn_3)
# Argument default is None and no index matches. fn;,,_3 will be called because of the max index 7.
out_3 = paddle.static.nn.switch_case(
branch_index=index_2,
branch_fns=[(0, fn_1), (4, fn_2), (7, fn_3)])
exe = paddle.static.Executor(paddle.CPUPlace())
res_1, res_2, res_3 = exe.run(main_program, fetch_list=[out_1, out_2, out_3])
print(res_1) # [[1. 1.]]
print(res_2) # [[2 2] [2 2]]
print(res_3) # [3 3 3]
'''
helper = LayerHelper('switch_case', **locals())
def _check_args(branch_index, branch_fns, default):
check_variable_and_dtype(
branch_index,
'branch_index',
['uint8', 'int32', 'int64'],
'static.nn.switch_case',
)
if convert_dtype(branch_index.dtype) != "int64":
branch_index = paddle.cast(branch_index, "int64")
check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')
branch_fns = (
branch_fns.items() if isinstance(branch_fns, dict) else branch_fns
)
branch_fns = (
list(enumerate(branch_fns))
if all(callable(fn) for fn in branch_fns)
else branch_fns
)
keys_of_fns = []
for index_fn_pair in branch_fns:
if not isinstance(index_fn_pair, tuple):
raise TypeError(
_error_message(
"The elements' type",
"branch_fns",
"switch_case",
tuple,
type(branch_fns),
)
)
if len(index_fn_pair) != 2:
raise TypeError(
_error_message(
"The tuple's size",
"branch_fns",
"switch_case",
"2",
str(len(index_fn_pair)) + "-tuple",
)
)
key, fn = index_fn_pair
if not isinstance(key, int):
raise TypeError(
_error_message(
"The key's type",
"branch_fns",
"switch_case",
int,
type(key),
)
)
if key in keys_of_fns:
raise ValueError(
"The key in 'branch_fns' must be unique, but '{}' appears more than once.".format(
key
)
)
else:
keys_of_fns.append(key)
if not callable(fn):
raise TypeError(
_error_message(
"The type of function for key {}".format(key),
"branch_fns",
"switch_case",
"callable",
type(fn),
)
)
if default is None:
default = sorted(branch_fns)[-1][1]
branch_fns = sorted(branch_fns)[:-1]
elif not callable(default):
raise TypeError("The default in Op(case) must be callable.")
pred_fn_pairs = []
for index, fn in branch_fns:
new_index = paddle.full(shape=[1], dtype="int64", fill_value=index)
pred = paddle.equal(branch_index, new_index)
pred_fn_pairs.append((pred, fn))
return pred_fn_pairs, default
pred_fn_pairs, default = _check_args(branch_index, branch_fns, default)
false_fn = default
for pred, true_fn in pred_fn_pairs:
false_fn = partial(cond, pred=pred, true_fn=true_fn, false_fn=false_fn)
final_fn = false_fn
return final_fn()
......@@ -212,7 +212,6 @@ HIGH_PARALLEL_JOB_NEW = [
'check_reduce_rank_test',
'test_progressbar',
'test_seed_op',
'test_shrink_rnn_memory',
'test_fc_bf16_mkldnn_op',
'test_sequence_first_step',
'test_fusion_lstm_mkldnn_op',
......@@ -273,7 +272,6 @@ HIGH_PARALLEL_JOB_NEW = [
'test_fleet_graph_executor',
'decorator_test',
'test_collective_base',
'test_lod_rank_table',
'test_multi_gru_mkldnn_op',
'test_eager_deletion_conditional_block',
'op_proto_maker_test',
......@@ -868,7 +866,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_imperative_load_static_param',
'test_imperative_qat_user_defined',
'test_anchor_generator_op',
'test_if_else_op',
'test_prepare_op',
'test_conj_op',
'test_imperative_hook_for_layer',
......@@ -1099,7 +1096,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_sequence_mask',
'test_fill_op',
'test_imperative_deepcf',
'test_reorder_lod_tensor',
'test_multiply',
'test_partial_program',
'test_fetch_feed',
......@@ -1264,7 +1260,6 @@ FOURTH_HIGH_PARALLEL_JOB_NEW = [
'test_imperative_static_runner_mnist',
'test_nearest_interp_op',
'test_diag_embed',
'test_imperative_basic',
'test_merge_selectedrows_op',
'test_feed_data_check_shape_type',
'test_complex_trace_layer',
......@@ -1740,7 +1735,6 @@ CPU_PARALLEL_JOB = [
'test_simplify_with_basic_ops_pass',
'test_similarity_focus_op',
'test_shuffle_batch_op',
'test_shrink_rnn_memory',
'test_set_bool_attr',
'test_sequence_topk_avg_pooling',
'test_sequence_scatter_op',
......@@ -1846,7 +1840,6 @@ CPU_PARALLEL_JOB = [
'test_logger',
'test_lod_tensor_array_ops',
'test_lod_tensor_array',
'test_lod_rank_table',
'test_locality_aware_nms_op',
'test_load_vars_shape_check',
'test_load_op_xpu',
......@@ -2373,7 +2366,6 @@ TETRAD_PARALLEL_JOB = [
'test_trt_conv3d_op',
'test_parallel_executor_drop_scope',
'test_tensorrt_engine',
'test_ir_memory_optimize_ifelse_op',
'test_parallel_executor_mnist',
'test_load_state_dict_from_old_format',
'test_fuse_elewise_add_act_pass',
......@@ -2594,7 +2586,6 @@ TETRAD_PARALLEL_JOB = [
'test_imperative_hook_for_layer',
'test_complex_sum_layer',
'test_complex_cast',
'test_reorder_lod_tensor',
'test_complex_kron',
'test_complex_trace_layer',
'test_merge_selectedrows_op',
......@@ -2851,7 +2842,6 @@ TWO_PARALLEL_JOB = [
'test_imperative_data_parallel',
'test_norm_nn_grad',
'test_im2sequence_op',
'test_if_else_op',
'test_one_hot_v2_op',
'test_grid_sampler_op',
'test_pad_op',
......@@ -3068,7 +3058,6 @@ TWO_PARALLEL_JOB = [
'test_broadcast_tensors_op',
'test_pad3d_op',
'test_cumprod_op',
'test_imperative_basic',
'trt_fc_prelu_test',
'test_sigmoid_focal_loss',
'test_pixel_shuffle',
......
......@@ -263,7 +263,6 @@ STATIC_MODE_TESTING_LIST = [
'test_huber_loss_op',
'test_im2sequence_op',
'test_image_classification_layer',
'test_imperative_basic',
'test_imperative_deepcf',
'test_imperative_framework',
'test_imperative_gan',
......@@ -293,7 +292,6 @@ STATIC_MODE_TESTING_LIST = [
'test_inverse_op',
'test_io_save_load',
'test_iou_similarity_op',
'test_ir_memory_optimize_ifelse_op',
'test_ir_memory_optimize_pass',
'test_is_empty_op',
'test_isfinite_op',
......@@ -315,7 +313,6 @@ STATIC_MODE_TESTING_LIST = [
'test_load_vars_shape_check',
'test_locality_aware_nms_op',
'test_lod_array_length_op',
'test_lod_rank_table',
'test_lod_tensor_array_ops',
'test_log_loss_op',
'test_log_softmax',
......@@ -440,7 +437,6 @@ STATIC_MODE_TESTING_LIST = [
'test_registry',
'test_regularizer',
'test_regularizer_api',
'test_reorder_lod_tensor',
'test_reshape_op',
'test_reshape_bf16_op',
'test_retinanet_detection_output',
......@@ -472,7 +468,6 @@ STATIC_MODE_TESTING_LIST = [
'test_sgd_op',
'test_shape_op',
'test_shard_index_op',
'test_shrink_rnn_memory',
'test_shuffle_batch_op',
'test_shuffle_channel_op',
'test_sigmoid_cross_entropy_with_logits_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册