Unverified commit 72e1eb6b, authored by xiongkun, committed by GitHub

[CherryPick] Cherry pick #45916 #46031 #47299 (#47610)

* [ Dy2Static ] Fix bugs when select_input meets different shapes or undefined vars (#45916)

* Fix select_input errors when the branches have different shapes:
1. select_input_with_buildin_type directly returns the non-undefined-var branch when it meets an undefined var.
2. The output shape of select_input is inferred from its inputs.

* reverse the logic in select_input

* [warning] Added a warning message in the cond block when one branch returns a variable and the other returns None (#46031)

* [cherry-pick] Allow manually setting the py_reader name in the standalone executor (#45898) (#45931)

* Allow manually setting the py_reader name in the standalone executor

* [BugFix] while cond receives dict as input (#47299)

* Fix bugs when the while cond receives a dict as input

* add unittest

* change flatten -> _is_sequence_except_dict

* code format
Co-authored-by: feifei-111 <wuzhanfei@baidu.com>
Parent cfee9c13
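The #47299 fix above concerns while loops whose condition reads state carried in a Python dict. A minimal, hypothetical sketch of that pattern follows; the function and variable names are illustrative and are not taken from the PR's unit test, and the decorator-based conversion is only one way to reach this code path.

import paddle

# Hypothetical repro sketch: the loop state lives in a Python dict and the
# while condition reads a tensor out of that dict (the case targeted by #47299).
@paddle.jit.to_static
def count_up_to(limit):
    state = {'i': paddle.zeros([1], dtype='int64')}
    while state['i'] < limit:          # cond receives a dict-carried tensor
        state['i'] = state['i'] + 1
    return state['i']

out = count_up_to(paddle.full([1], 10, dtype='int64'))
print(out)  # expected to reach 10 once the dict-as-loop-state case is handled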
......@@ -18,31 +18,73 @@ from ..wrapped_decorator import signature_safe_contextmanager
from .layer_function_generator import autodoc, templatedoc
from .tensor import assign, cast, fill_constant
from .. import core
from ..framework import Program, Variable, Operator, _non_static_mode, static_only, _in_legacy_dygraph, in_dygraph_mode
from ..framework import (
Program,
Variable,
Operator,
_non_static_mode,
static_only,
_in_legacy_dygraph,
in_dygraph_mode,
)
from ..layer_helper import LayerHelper, unique_name
from .nn import logical_and, logical_not, logical_or
from .utils import assert_same_structure, map_structure, hold_mutable_vars, copy_mutable_vars, padding_to_same_structure, is_sequence, pack_sequence_as, flatten, to_sequence
from .utils import (
assert_same_structure,
map_structure,
hold_mutable_vars,
copy_mutable_vars,
padding_to_same_structure,
is_sequence,
pack_sequence_as,
flatten,
to_sequence,
)
import numpy
import warnings
import six
from functools import reduce, partial
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..data_feeder import (
convert_dtype,
check_variable_and_dtype,
check_type,
check_dtype,
)
from ... import compat as cpt
from ..backward import _infer_var_data_type_shape_
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than',
'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal',
'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN',
'reorder_lod_tensor_by_rank', 'Print', 'Assert', 'is_empty', 'case',
'switch_case', 'while_loop'
'While',
'Switch',
'increment',
'array_write',
'create_array',
'less_than',
'less_equal',
'greater_than',
'greater_equal',
'equal',
'not_equal',
'array_read',
'array_length',
'cond',
'IfElse',
'DynamicRNN',
'StaticRNN',
'reorder_lod_tensor_by_rank',
'Print',
'Assert',
'is_empty',
'case',
'switch_case',
'while_loop',
]
def select_output(input, outputs, mask):
"""
**select_output**
**select_output**
This API takes in one input and multiple outputs and an integer mask. It
selects the output specified by the mask and copy the input to selected
output. It is useful in control flow.
......@@ -61,12 +103,11 @@ def select_output(input, outputs, mask):
check_variable_and_dtype(mask, 'mask', ['int32'], 'select_output')
check_type(outputs, 'outputs', (list, tuple), 'select_output')
helper.append_op(type='select_output',
inputs={
'X': input,
'Mask': mask
},
outputs={'Out': outputs})
helper.append_op(
type='select_output',
inputs={'X': input, 'Mask': mask},
outputs={'Out': outputs},
)
return outputs
......@@ -85,14 +126,15 @@ def _select_input_infer_shape(first_shape, second_shape):
)
return second_shape
out_shape = list(
map(lambda a, b: a if a == b else -1, first_shape, second_shape))
map(lambda a, b: a if a == b else -1, first_shape, second_shape)
)
return out_shape
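# A worked example (illustrative, not part of this file) of the shape-merge rule
# above: dimensions that agree are kept, and any mismatch becomes the dynamic
# dimension -1, so the select_input output can hold either branch's result.
#
#   _select_input_infer_shape([3, 4, 5], [3, 7, 5])   # -> [3, -1, 5]
#   _select_input_infer_shape([3, -1], [3, 4])        # -> [3, -1]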
def select_input(inputs, mask):
"""
**select_input**
This API takes in multiple inputs and uses an integer mask to select one
input to output. It is useful in control flow.
......@@ -109,32 +151,34 @@ def select_input(inputs, mask):
check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')
# Select input should expand the shape. If it is - 1 and valid number, use - 1 first. If the dim is different, an error will be reported directly
#assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"
# assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"
output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
output_dtype = inputs[1].dtype
output_type = inputs[1].type
out = helper.create_variable(dtype=output_dtype,
shape=output_shape,
type=output_type)
helper.append_op(type='select_input',
inputs={
'X': inputs,
'Mask': mask
},
outputs={'Out': out})
out = helper.create_variable(
dtype=output_dtype, shape=output_shape, type=output_type
)
helper.append_op(
type='select_input',
inputs={'X': inputs, 'Mask': mask},
outputs={'Out': out},
)
return out
def select_input_with_buildin_type(inputs, mask, name):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import (
to_static_variable,
)
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
false_var, true_var = inputs
if isinstance(false_var, UndefinedVar) and isinstance(
true_var, UndefinedVar):
""" None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None.
"""
true_var, UndefinedVar
):
"""None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None."""
return None
if isinstance(false_var, Variable) and isinstance(true_var, Variable):
......@@ -142,50 +186,63 @@ def select_input_with_buildin_type(inputs, mask, name):
return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")
f"Exceptions throwed while doing select_input on {name}:\n{e}"
)
elif (isinstance(false_var, (support_ret_buildin_type))
and isinstance(false_var, type(true_var))):
elif isinstance(false_var, (support_ret_buildin_type)) and isinstance(
false_var, type(true_var)
):
if false_var == true_var:
return false_var
else:
inputs = [
to_static_variable(false_var),
to_static_variable(true_var)
to_static_variable(true_var),
]
# Deal with the situations like this: false_var is int and true_var is Variable
elif ((isinstance(false_var, support_ret_buildin_type)
and isinstance(true_var, Variable))
or (isinstance(true_var, support_ret_buildin_type)
and isinstance(false_var, Variable))):
elif (
isinstance(false_var, support_ret_buildin_type)
and isinstance(true_var, Variable)
) or (
isinstance(true_var, support_ret_buildin_type)
and isinstance(false_var, Variable)
):
inputs = [to_static_variable(false_var), to_static_variable(true_var)]
warnings.warn(
"Return results from different branches in cond are not same type: "
"false_var returned by fasle_fn is '{}' and true_var of true_fn is "
"'{}'".format(type(false_var), type(true_var)))
elif ((isinstance(false_var, UndefinedVar)
and isinstance(true_var, (Variable, ) + support_ret_buildin_type))
or (isinstance(true_var, UndefinedVar)
and isinstance(false_var,
(Variable, ) + support_ret_buildin_type))):
"'{}'".format(type(false_var), type(true_var))
)
elif (
isinstance(false_var, UndefinedVar)
and isinstance(true_var, (Variable,) + support_ret_buildin_type)
) or (
isinstance(true_var, UndefinedVar)
and isinstance(false_var, (Variable,) + support_ret_buildin_type)
):
def create_var_if_not_undefined_var(a):
if isinstance(a, UndefinedVar): return a
if isinstance(a, UndefinedVar):
return a
return to_static_variable(a)
true_var, false_var = to_static_variable(true_var), to_static_variable(
false_var)
false_var
)
inputs = [false_var, true_var]
else:
raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'".
format(type(false_var), type(true_var)))
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'".format(
type(false_var), type(true_var)
)
)
try:
return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")
f"Exceptions throwed while doing select_input on {name}:\n{e}"
)
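# Illustrative summary (not part of this file) of how the helper above combines
# the two branch returns of a dy2static cond, following the code shown:
#
#   (UndefinedVar, UndefinedVar)        -> return None directly
#   (Variable,     Variable)            -> select_input on the two variables
#   (3, 3)  equal builtin values        -> return the Python value unchanged
#   (3, 5)  unequal builtin values      -> both wrapped via to_static_variable
#   (3, Variable)  mixed types          -> both wrapped, plus a warning about
#                                          returning different types
#   (UndefinedVar, Variable or builtin) -> both passed through to_static_variable
#   anything else                       -> TypeError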
def split_lod_tensor(input, mask, level=0):
......@@ -222,23 +279,26 @@ def split_lod_tensor(input, mask, level=0):
input=x, mask=y, level=level)
"""
check_type(input, 'input', (Variable, list, tuple, type(None)),
'fluid.layers.split_lod_tensor')
check_type(
input,
'input',
(Variable, list, tuple, type(None)),
'fluid.layers.split_lod_tensor',
)
check_type(mask, 'mask', (Variable, list), 'fluid.layers.split_lod_tensor')
check_type(level, 'level', int, 'fluid.layers.split_lod_tensor')
helper = LayerHelper('split_lod_tensor', **locals())
out_true = helper.create_variable_for_type_inference(dtype=input.dtype)
out_false = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(type='split_lod_tensor',
inputs={
'X': input,
'Mask': mask,
},
outputs={
'OutTrue': out_true,
'OutFalse': out_false
},
attrs={'level': level})
helper.append_op(
type='split_lod_tensor',
inputs={
'X': input,
'Mask': mask,
},
outputs={'OutTrue': out_true, 'OutFalse': out_false},
attrs={'level': level},
)
return out_true, out_false
......@@ -280,37 +340,48 @@ def merge_lod_tensor(in_true, in_false, x, mask, level=0):
in_true=out_true, in_false=out_false, mask=y, x=x, level=level)
"""
helper = LayerHelper('merge_lod_tensor', **locals())
check_type(x, 'x', (Variable, list, tuple, type(None)),
'fluid.layers.merge_lod_tensor')
check_type(
x,
'x',
(Variable, list, tuple, type(None)),
'fluid.layers.merge_lod_tensor',
)
check_type(mask, 'mask', (Variable, list), 'fluid.layers.merge_lod_tensor')
check_type(in_true, 'in_true', (Variable, list, tuple, type(None)),
'fluid.layers.merge_lod_tensor')
check_type(in_false, 'in_false', (Variable, list, tuple, type(None)),
'fluid.layers.merge_lod_tensor')
check_type(
in_true,
'in_true',
(Variable, list, tuple, type(None)),
'fluid.layers.merge_lod_tensor',
)
check_type(
in_false,
'in_false',
(Variable, list, tuple, type(None)),
'fluid.layers.merge_lod_tensor',
)
out = helper.create_variable_for_type_inference(dtype=in_true.dtype)
helper.append_op(type='merge_lod_tensor',
inputs={
'X': x,
'Mask': mask,
'InTrue': in_true,
'InFalse': in_false
},
outputs={'Out': out},
attrs={'level': level})
helper.append_op(
type='merge_lod_tensor',
inputs={'X': x, 'Mask': mask, 'InTrue': in_true, 'InFalse': in_false},
outputs={'Out': out},
attrs={'level': level},
)
return out
@static_only
def Print(input,
first_n=-1,
message=None,
summarize=20,
print_tensor_name=True,
print_tensor_type=True,
print_tensor_shape=True,
print_tensor_layout=True,
print_tensor_lod=True,
print_phase='both'):
def Print(
input,
first_n=-1,
message=None,
summarize=20,
print_tensor_name=True,
print_tensor_type=True,
print_tensor_shape=True,
print_tensor_layout=True,
print_tensor_lod=True,
print_phase='both',
):
'''
:api_attr: Static Graph
......@@ -334,7 +405,7 @@ def Print(input,
print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
print_phase (str): Which phase to displace, including 'forward',
'backward' and 'both'. Default: 'both'. If set to 'backward', will
'backward' and 'both'. Default: 'both'. If set to 'backward', will
only print the gradients of input tensor; If set to 'both', will
both print the input tensor itself and the gradients of input tensor.
......@@ -348,11 +419,11 @@ def Print(input,
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
out = paddle.static.Print(x, message="The content of input layer:")
......@@ -368,26 +439,31 @@ def Print(input,
# - dtype: long
# - data: [3 3 3 3 3 3]
'''
check_variable_and_dtype(input, 'input',
['float32', 'float64', 'int32', 'int64', 'bool'],
'fluid.layers.Print')
check_variable_and_dtype(
input,
'input',
['float32', 'float64', 'int32', 'int64', 'bool'],
'fluid.layers.Print',
)
helper = LayerHelper('print' + "_" + input.name, **locals())
output = helper.create_variable_for_type_inference(input.dtype)
helper.append_op(type='print',
inputs={'In': input},
outputs={'Out': output},
attrs={
'first_n': first_n,
'summarize': summarize,
'message': message or "",
'print_tensor_name': print_tensor_name,
'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape,
'print_tensor_layout': print_tensor_layout,
'print_tensor_lod': print_tensor_lod,
'print_phase': print_phase.upper()
})
helper.append_op(
type='print',
inputs={'In': input},
outputs={'Out': output},
attrs={
'first_n': first_n,
'summarize': summarize,
'message': message or "",
'print_tensor_name': print_tensor_name,
'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape,
'print_tensor_layout': print_tensor_layout,
'print_tensor_lod': print_tensor_lod,
'print_phase': print_phase.upper(),
},
)
return output
......@@ -454,12 +530,11 @@ def Assert(cond, data=None, summarize=20, name=None):
layer_name = name if name else ('assert_' + cond.name)
helper = LayerHelper(layer_name, **locals())
op = helper.append_op(type="assert",
inputs={
"Cond": cond,
"Data": [] if data is None else list(data)
},
attrs={"summarize": summarize})
op = helper.append_op(
type="assert",
inputs={"Cond": cond, "Data": [] if data is None else list(data)},
attrs={"summarize": summarize},
)
return op
......@@ -509,8 +584,9 @@ class BlockGuardWithCompletion(BlockGuard):
return False
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn._complete_op()
return super(BlockGuardWithCompletion,
self).__exit__(exc_type, exc_val, exc_tb)
return super(BlockGuardWithCompletion, self).__exit__(
exc_type, exc_val, exc_tb
)
class StaticRNNMemoryLink(object):
......@@ -576,12 +652,13 @@ class StaticRNN(object):
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
# mark hidden as output
# mark hidden as output
rnn.step_output(hidden)
# get StaticrNN final output
result = rnn()
"""
BEFORE_RNN_BLOCK = 0
IN_RNN_BLOCK = 1
AFTER_RNN_BLOCK = 2
......@@ -607,13 +684,15 @@ class StaticRNN(object):
if self.status != StaticRNN.IN_RNN_BLOCK:
raise ValueError("You must invoke {0} in rnn block".format(method))
def memory(self,
init=None,
shape=None,
batch_ref=None,
init_value=0.0,
init_batch_dim_idx=0,
ref_batch_dim_idx=1):
def memory(
self,
init=None,
shape=None,
batch_ref=None,
init_value=0.0,
init_batch_dim_idx=0,
ref_batch_dim_idx=1,
):
"""
Create a memory variable for static rnn.
If the :code:`init` is not None, :code:`memory` will be initialized by
......@@ -639,97 +718,118 @@ class StaticRNN(object):
Examples 1:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
Examples 2:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
boot_memory = fluid.layers.data(name='boot', shape=[hidden_size], dtype='float32', lod_level=1)
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# init memory
prev = rnn.memory(init=boot_memory)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# update hidden with prev
rnn.update_memory(prev, hidden)
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
boot_memory = fluid.layers.data(name='boot', shape=[hidden_size], dtype='float32', lod_level=1)
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# init memory
prev = rnn.memory(init=boot_memory)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# update hidden with prev
rnn.update_memory(prev, hidden)
"""
self._assert_in_rnn_block_('memory')
check_type(init, "init", (Variable, type(None)),
"fluid.layers.StaticRNN.memory")
check_type(shape, "shape", (list, tuple, type(None)),
"fluid.layers.StaticRNN.memory")
check_type(batch_ref, "batch_ref", (Variable, type(None)),
"fluid.layers.StaticRNN.memory")
check_type(
init,
"init",
(Variable, type(None)),
"fluid.layers.StaticRNN.memory",
)
check_type(
shape,
"shape",
(list, tuple, type(None)),
"fluid.layers.StaticRNN.memory",
)
check_type(
batch_ref,
"batch_ref",
(Variable, type(None)),
"fluid.layers.StaticRNN.memory",
)
if init is None:
if shape is None or batch_ref is None:
raise ValueError(
"if init is None, memory at least need shape and batch_ref")
"if init is None, memory at least need shape and batch_ref"
)
parent_block = self._parent_block()
var_name = unique_name.generate_with_ignorable_key("@".join(
[self.helper.name, "memory_boot"]))
boot_var = parent_block.create_var(name=var_name,
shape=shape,
dtype=batch_ref.dtype,
persistable=False)
parent_block.append_op(type="fill_constant_batch_size_like",
inputs={'Input': [batch_ref]},
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': boot_var.shape,
'dtype': boot_var.dtype,
'input_dim_idx': ref_batch_dim_idx,
'output_dim_idx': init_batch_dim_idx
})
var_name = unique_name.generate_with_ignorable_key(
"@".join([self.helper.name, "memory_boot"])
)
boot_var = parent_block.create_var(
name=var_name,
shape=shape,
dtype=batch_ref.dtype,
persistable=False,
)
parent_block.append_op(
type="fill_constant_batch_size_like",
inputs={'Input': [batch_ref]},
outputs={'Out': [boot_var]},
attrs={
'value': init_value,
'shape': boot_var.shape,
'dtype': boot_var.dtype,
'input_dim_idx': ref_batch_dim_idx,
'output_dim_idx': init_batch_dim_idx,
},
)
return self.memory(init=boot_var)
else:
pre_mem = self.helper.create_variable(
name=unique_name.generate_with_ignorable_key("@".join(
[self.helper.name, "mem"])),
name=unique_name.generate_with_ignorable_key(
"@".join([self.helper.name, "mem"])
),
dtype=init.dtype,
shape=init.shape)
self.memories[pre_mem.name] = StaticRNNMemoryLink(init=init,
pre_mem=pre_mem)
shape=init.shape,
)
self.memories[pre_mem.name] = StaticRNNMemoryLink(
init=init, pre_mem=pre_mem
)
return pre_mem
def step_input(self, x):
......@@ -746,29 +846,29 @@ class StaticRNN(object):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
"""
self._assert_in_rnn_block_('step_input')
......@@ -778,10 +878,9 @@ class StaticRNN(object):
elif x.shape[0] != -1 and self.seq_len != x.shape[0]:
raise ValueError("Static RNN only take fix seq_len input")
ipt = self.helper.create_variable(name=x.name,
dtype=x.dtype,
shape=list(x.shape[1:]),
type=x.type)
ipt = self.helper.create_variable(
name=x.name, dtype=x.dtype, shape=list(x.shape[1:]), type=x.type
)
self.inputs.append(ipt)
return ipt
......@@ -798,47 +897,50 @@ class StaticRNN(object):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
rnn.step_output(hidden)
result = rnn()
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
rnn.step_output(hidden)
result = rnn()
"""
self._assert_in_rnn_block_('step_output')
check_type(o, "o", Variable, "fluid.layers.StaticRNN.step_output")
tmp_o = self.helper.create_variable_for_type_inference(dtype=o.dtype)
self.helper.append_op(type='rnn_memory_helper',
inputs={'X': [o]},
outputs={'Out': tmp_o},
attrs={'dtype': o.dtype})
self.helper.append_op(
type='rnn_memory_helper',
inputs={'X': [o]},
outputs={'Out': tmp_o},
attrs={'dtype': o.dtype},
)
out_var = self._parent_block().create_var(name=tmp_o.name,
shape=[self.seq_len] +
list(tmp_o.shape),
dtype=tmp_o.dtype)
out_var = self._parent_block().create_var(
name=tmp_o.name,
shape=[self.seq_len] + list(tmp_o.shape),
dtype=tmp_o.dtype,
)
self.outputs.append(out_var)
......@@ -855,33 +957,33 @@ class StaticRNN(object):
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
# mark each step's hidden and word as output
rnn.output(hidden, word)
result = rnn()
import paddle.fluid as fluid
import paddle.fluid.layers as layers
vocab_size, hidden_size=10000, 200
x = fluid.data(name="x", shape=[None, 1, 1], dtype='int64')
# create word sequence
x_emb = layers.embedding(
input=x,
size=[vocab_size, hidden_size],
dtype='float32',
is_sparse=False)
# transform batch size to dim 1
x_emb = layers.transpose(x_emb, perm=[1, 0, 2])
rnn = fluid.layers.StaticRNN()
with rnn.step():
# mark created x_emb as input, each step process a word
word = rnn.step_input(x_emb)
# create prev memory parameter, batch size comes from word
prev = rnn.memory(shape=[-1, hidden_size], batch_ref = word)
hidden = fluid.layers.fc(input=[word, prev], size=hidden_size, act='relu')
# use hidden to update prev
rnn.update_memory(prev, hidden)
# mark each step's hidden and word as output
rnn.output(hidden, word)
result = rnn()
"""
for each in outputs:
self.step_output(each)
......@@ -954,7 +1056,8 @@ class StaticRNN(object):
]
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
type=core.VarDesc.VarType.STEP_SCOPES
)
inlinks = [parent_block.var(i.name) for i in self.inputs]
outlinks = self.outputs
......@@ -966,39 +1069,41 @@ class StaticRNN(object):
for _, mem in six.iteritems(self.memories):
boot_memories.append(mem.init)
pre_memories.append(mem.pre_mem.name)
assert mem.mem is not None, "%s should be updated in every step." % (
mem.init.name)
assert (
mem.mem is not None
), "%s should be updated in every step." % (mem.init.name)
mem_var = rnn_block.var(mem.mem.name)
assert isinstance(mem_var, Variable)
new_mem = self.helper.create_variable_for_type_inference(
dtype=mem_var.dtype)
rnn_block.append_op(type='rnn_memory_helper',
inputs={'X': [mem_var]},
outputs={'Out': [new_mem]},
attrs={'dtype': mem_var.dtype})
dtype=mem_var.dtype
)
rnn_block.append_op(
type='rnn_memory_helper',
inputs={'X': [mem_var]},
outputs={'Out': [new_mem]},
attrs={'dtype': mem_var.dtype},
)
memories.append(new_mem.name)
parent_block.append_op(type='recurrent',
inputs={
'inputs': inlinks,
'initial_states': boot_memories,
'parameters': parameters
},
outputs={
'outputs': outlinks,
'step_scopes': [step_scope]
},
attrs={
'has_states': len(pre_memories) > 0,
'ex_states': pre_memories,
'states': memories,
'sub_block': rnn_block
})
parent_block.append_op(
type='recurrent',
inputs={
'inputs': inlinks,
'initial_states': boot_memories,
'parameters': parameters,
},
outputs={'outputs': outlinks, 'step_scopes': [step_scope]},
attrs={
'has_states': len(pre_memories) > 0,
'ex_states': pre_memories,
'states': memories,
'sub_block': rnn_block,
},
)
class WhileGuard(BlockGuard):
def __init__(self, while_op):
if not isinstance(while_op, While):
raise TypeError("WhileGuard takes a while op")
......@@ -1017,8 +1122,9 @@ class WhileGuard(BlockGuard):
return super(WhileGuard, self).__exit__(exc_type, exc_val, exc_tb)
def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
helper):
def get_inputs_outputs_in_block(
current_block, inner_inputs, inner_outputs, helper
):
"""
Find inputs and outputs in current control flow block.
:param current_block: Current control flow block.
......@@ -1049,7 +1155,8 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in inner_outputs and not is_ignore_vars(
op, in_var_name):
op, in_var_name
):
inner_inputs.add(in_var_name)
for oname in op.output_names:
......@@ -1065,8 +1172,11 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
current_block_var = None
if current_block.has_var(in_var_name):
current_block_var = current_block.var(in_var_name)
if not parent_block_var and current_block_var and \
current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if (
not parent_block_var
and current_block_var
and current_block_var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
remove_inner_inputs.add(in_var_name)
inner_inputs = inner_inputs - remove_inner_inputs
......@@ -1077,7 +1187,7 @@ def get_inputs_outputs_in_block(current_block, inner_inputs, inner_outputs,
class While(object):
"""
:api_attr: Static Graph
while loop control flow. Repeat while body until cond is False.
Note:
......@@ -1097,7 +1207,7 @@ class While(object):
Examples 1:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
......@@ -1156,8 +1266,10 @@ class While(object):
check_variable_and_dtype(cond, 'cond', ['bool'], 'fluid.layers.While')
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
raise TypeError(
"condition expected shape as [1], but given shape as {0}.".
format(list(cond.shape)))
"condition expected shape as [1], but given shape as {0}.".format(
list(cond.shape)
)
)
self.cond_var = cond
self.is_test = is_test
......@@ -1168,12 +1280,14 @@ class While(object):
main_program = self.helper.main_program
while_block = main_program.current_block()
parent_block = main_program.block(
main_program.current_block().parent_idx)
main_program.current_block().parent_idx
)
inner_outputs = {self.cond_var.name}
x_name_list = set()
x_name_list, inner_outputs = get_inputs_outputs_in_block(
while_block, x_name_list, inner_outputs, self.helper)
while_block, x_name_list, inner_outputs, self.helper
)
out_vars = []
for inner_out_name in inner_outputs:
......@@ -1187,23 +1301,21 @@ class While(object):
x_name_list -= {self.cond_var.name}
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
type=core.VarDesc.VarType.STEP_SCOPES
)
parent_block.append_op(
type='while',
inputs={
'X':
[parent_block._var_recursive(x_name) for x_name in x_name_list],
'Condition': [self.cond_var]
'X': [
parent_block._var_recursive(x_name)
for x_name in x_name_list
],
'Condition': [self.cond_var],
},
outputs={
'Out': out_vars,
'StepScopes': [step_scope]
},
attrs={
'sub_block': while_block,
"is_test": self.is_test
})
outputs={'Out': out_vars, 'StepScopes': [step_scope]},
attrs={'sub_block': while_block, "is_test": self.is_test},
)
support_ret_buildin_type = (bool, float, six.integer_types)
......@@ -1215,14 +1327,17 @@ def assign_skip_lod_tensor_array(input, output):
"""
def has_shape_diff(x_var, y_var):
if len(x_var.shape) != len(y_var.shape): return True
if len(x_var.shape) != len(y_var.shape):
return True
for x_dim, y_dim in zip(x_var.shape, y_var.shape):
if x_dim != y_dim and -1 not in [x_dim, y_dim]: return True
if x_dim != y_dim and -1 not in [x_dim, y_dim]:
return True
return False
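# Worked examples (illustrative, not part of this file) of has_shape_diff above:
# -1 acts as a wildcard, so only fixed dimensions that disagree, or a rank
# mismatch, count as a difference.
#
#   x_var.shape = [3, -1], y_var.shape = [3, 4]   ->  False
#   x_var.shape = [3, 5],  y_var.shape = [3, 4]   ->  True  (triggers the warning below)
#   x_var.shape = [3],     y_var.shape = [3, 4]   ->  True  (rank mismatch)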
if not isinstance(input, (Variable, core.VarBase)):
if isinstance(output, Variable) and isinstance(
input, support_ret_buildin_type):
input, support_ret_buildin_type
):
assign(input, output)
else:
output = input
......@@ -1231,15 +1346,21 @@ def assign_skip_lod_tensor_array(input, output):
if input.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY:
main_program = input.block.program
parent_block = main_program.block(
main_program.current_block().parent_idx)
main_program.current_block().parent_idx
)
if parent_block and not parent_block._find_var_recursive(input.name):
assign(input, output)
else:
if isinstance(output, Variable) and isinstance(
input, Variable) and has_shape_diff(input, output):
if (
isinstance(output, Variable)
and isinstance(input, Variable)
and has_shape_diff(input, output)
):
warnings.warn(
"In dy2static mode, we attemp to assign a variable with shape {} into a variable with shape{}, which is not always right."
.format(input.shape, output.shape))
"In dy2static mode, we attemp to assign a variable with shape {} into a variable with shape{}, which is not always right.".format(
input.shape, output.shape
)
)
assign(input, output)
......@@ -1255,7 +1376,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
Args:
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
as many arguments as ``loop_vars`` .
as many arguments as ``loop_vars`` .
body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
(length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
......@@ -1285,7 +1406,7 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
i = paddle.full(shape=[1], fill_value=0, dtype='int64') # loop counter
ten = paddle.full(shape=[1], fill_value=10, dtype='int64') # loop length
i, ten = paddle.static.nn.while_loop(cond, body, [i, ten])
exe = paddle.static.Executor(paddle.CPUPlace())
res = exe.run(main_program, feed={}, fetch_list=[i])
print(res) # [array([10])]
......@@ -1301,23 +1422,26 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
raise ValueError("loop_vars in while_loop should not be empty")
pre_cond = cond(*loop_vars)
check_variable_and_dtype(pre_cond, 'var of cond returned', ['bool'],
'fluid.layers.while_loop')
check_variable_and_dtype(
pre_cond, 'var of cond returned', ['bool'], 'fluid.layers.while_loop'
)
if reduce(lambda a, b: a * b, pre_cond.shape, 1) != 1:
raise TypeError(
"the shape of the variable returned by cond should be [1],"
"but given shape as {0}.".format(list(pre_cond.shape)))
"but given shape as {0}.".format(list(pre_cond.shape))
)
if _non_static_mode():
now_cond = pre_cond.numpy()[0]
while (now_cond):
while now_cond:
output_vars = body(*loop_vars)
if not isinstance(output_vars, (list, tuple)):
output_vars = [output_vars]
if len(output_vars) != len(loop_vars):
raise ValueError(
"body in while_loop should return the same arity "
"(length and structure) and types as loop_vars")
"(length and structure) and types as loop_vars"
)
now_cond = cond(*output_vars).numpy()[0]
map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
return loop_vars
......@@ -1342,7 +1466,8 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
except ValueError as e:
raise ValueError(
"body in while_loop should return the same arity "
"(length and structure) as loop_vars: {0}".format(e))
"(length and structure) as loop_vars: {0}".format(e)
)
now_cond = cond(*output_vars)
map_structure(assign_skip_lod_tensor_array, output_vars, loop_vars)
assign(now_cond, pre_cond)
......@@ -1350,22 +1475,27 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
def _deal_with_undefined_var(output_vars, loop_vars):
""" Deal with undefined var cases, We create undefined variable based on the results of body().
In Dy2Static, we use undefined var to represent the var created in control flow. This function
expand the loop_vars and replace original loop_vars.
1. UndefinedVar = Variable # create a variable
2. UndefinedVar = None # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
3. UndefinedVar = List(int) # create a list of variable
4. UndefinedVar = value # create a variable
"""Deal with undefined var cases, We create undefined variable based on the results of body().
In Dy2Static, we use undefined var to represent the var created in control flow. This function
expand the loop_vars and replace original loop_vars.
1. UndefinedVar = Variable # create a variable
2. UndefinedVar = None # create a undefined var with RETURN_NO_VALUE_MAGIC_NUM
3. UndefinedVar = List(int) # create a list of variable
4. UndefinedVar = value # create a variable
"""
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import (
UndefinedVar,
create_undefined_variable,
)
def create_var_like(o_var):
if isinstance(o_var,
(Variable, ) + support_ret_buildin_type) or o_var is None:
if (
isinstance(o_var, (Variable,) + support_ret_buildin_type)
or o_var is None
):
return create_undefined_variable()
if is_sequence(o_var):
"""
"""
Create a complex container class inside the body of while, including Python list and python Dict
"""
return map_structure(lambda x: create_undefined_variable(), o_var)
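# Illustrative mapping (not part of this file) performed by create_var_like above,
# following the docstring's cases:
#
#   Variable / int / float / bool / None  -> one undefined placeholder variable
#   [1, 2] or {'k': v} (any sequence)     -> the same structure, with each leaf
#                                            replaced by an undefined placeholder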
......@@ -1433,16 +1563,21 @@ def lod_rank_table(x, level=0):
check_type(x, 'x', (Variable, list), 'lod_rank_table')
if isinstance(x, (list)):
for i, input_x in enumerate(x):
check_type(input_x, 'input[' + str(i) + ']', Variable,
'lod_rank_table')
check_type(
input_x, 'input[' + str(i) + ']', Variable, 'lod_rank_table'
)
helper = LayerHelper("lod_rank_table", **locals())
table = helper.create_variable(type=core.VarDesc.VarType.LOD_RANK_TABLE,
name=unique_name.generate("lod_rank_table"))
helper.append_op(type='lod_rank_table',
inputs={'X': x},
outputs={'Out': table},
attrs={'level': level})
table = helper.create_variable(
type=core.VarDesc.VarType.LOD_RANK_TABLE,
name=unique_name.generate("lod_rank_table"),
)
helper.append_op(
type='lod_rank_table',
inputs={'X': x},
outputs={'Out': table},
attrs={'level': level},
)
return table
......@@ -1465,9 +1600,11 @@ def max_sequence_len(rank_table):
"""
helper = LayerHelper("max_seqence_len", **locals())
res = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(type="max_sequence_len",
inputs={"RankTable": rank_table},
outputs={"Out": res})
helper.append_op(
type="max_sequence_len",
inputs={"RankTable": rank_table},
outputs={"Out": res},
)
return res
......@@ -1503,24 +1640,32 @@ def lod_tensor_to_array(x, table):
check_type(x, 'x', (Variable, list), 'lod_tensor_to_array')
if isinstance(x, (list)):
for i, input_x in enumerate(x):
check_type(input_x, 'input[' + str(i) + ']', Variable,
'lod_tensor_to_array')
check_type(
input_x,
'input[' + str(i) + ']',
Variable,
'lod_tensor_to_array',
)
check_type(table, 'table', (Variable, list), 'lod_tensor_to_array')
if isinstance(table, (list)):
for i, table_x in enumerate(table):
check_type(table_x, 'table[' + str(i) + ']', Variable,
'lod_tensor_to_array')
check_type(
table_x,
'table[' + str(i) + ']',
Variable,
'lod_tensor_to_array',
)
helper = LayerHelper("lod_tensor_to_array", **locals())
array = helper.create_variable(
name=unique_name.generate("lod_tensor_to_array"),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=x.dtype)
helper.append_op(type='lod_tensor_to_array',
inputs={
'X': x,
'RankTable': table
},
outputs={'Out': array})
dtype=x.dtype,
)
helper.append_op(
type='lod_tensor_to_array',
inputs={'X': x, 'RankTable': table},
outputs={'Out': array},
)
return array
......@@ -1549,22 +1694,29 @@ def array_to_lod_tensor(x, table):
check_type(x, 'x', (Variable, list), 'array_to_lod_tensor')
if isinstance(x, (list)):
for i, input_x in enumerate(x):
check_type(input_x, 'input[' + str(i) + ']', Variable,
'array_to_lod_tensor')
check_type(
input_x,
'input[' + str(i) + ']',
Variable,
'array_to_lod_tensor',
)
check_type(table, 'table', (Variable, list), 'array_to_lod_tensor')
if isinstance(table, (list)):
for i, table_x in enumerate(table):
check_type(table_x, 'table[' + str(i) + ']', Variable,
'array_to_lod_tensor')
check_type(
table_x,
'table[' + str(i) + ']',
Variable,
'array_to_lod_tensor',
)
helper = LayerHelper("array_to_lod_tensor", **locals())
tmp = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="array_to_lod_tensor",
inputs={
'X': x,
'RankTable': table
},
outputs={'Out': tmp})
helper.append_op(
type="array_to_lod_tensor",
inputs={'X': x, 'RankTable': table},
outputs={'Out': tmp},
)
return tmp
......@@ -1592,17 +1744,20 @@ def increment(x, value=1.0, in_place=True):
if in_dygraph_mode():
return _C_ops.increment_(x, value)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'increment')
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'increment'
)
helper = LayerHelper("increment", **locals())
if not in_place:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
out = x
helper.append_op(type='increment',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'step': float(value)})
helper.append_op(
type='increment',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={'step': float(value)},
)
return out
......@@ -1618,8 +1773,8 @@ def array_write(x, i, array=None):
Tensor or LoDTensor. Data type: float32, float64, int32, int64.
i (Variable): 1-D Tensor with shape [1], which represents the position into which
``x`` is written. Data type: int64.
array (LoDTensorArray, optional): The LoDTensorArray into which ``x`` is written.
The default value is None, when a new LoDTensorArray will be created and returned
array (LoDTensorArray, optional): The LoDTensorArray into which ``x`` is written.
The default value is None, when a new LoDTensorArray will be created and returned
as a result.
Returns:
......@@ -1651,8 +1806,8 @@ def array_write(x, i, array=None):
# the output is 2-D Tensor with shape [3,2], which is tmp above.
# dtype is the corresponding C++ data type, which may vary in different environments.
# Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables.
"""
......@@ -1670,8 +1825,8 @@ def array_write(x, i, array=None):
if array is None:
array = create_array(x.dtype)
assert isinstance(
array,
list), "The 'array' in array_write must be a list in dygraph mode"
array, list
), "The 'array' in array_write must be a list in dygraph mode"
assert i <= len(
array
), "The index 'i' should not be greater than the length of 'array' in dygraph mode"
......@@ -1685,29 +1840,31 @@ def array_write(x, i, array=None):
check_type(x, 'x', (Variable), 'array_write')
helper = LayerHelper('array_write', **locals())
if array is not None:
if not isinstance(
array, Variable
) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if (
not isinstance(array, Variable)
or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
raise TypeError(
"array should be tensor array vairable in array_write Op")
"array should be tensor array vairable in array_write Op"
)
if array is None:
array = helper.create_variable(
name="{0}.out".format(helper.name),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=x.dtype)
helper.append_op(type='write_to_array',
inputs={
'X': [x],
'I': [i]
},
outputs={'Out': [array]})
dtype=x.dtype,
)
helper.append_op(
type='write_to_array',
inputs={'X': [x], 'I': [i]},
outputs={'Out': [array]},
)
return array
def create_array(dtype, initialized_list=None):
"""
This OP creates an LOD_TENSOR_ARRAY. It is used as
the input of :ref:`api_fluid_layers_array_read` and
the input of :ref:`api_fluid_layers_array_read` and
:ref:`api_fluid_layers_array_write`. Also it can be used
with :ref:`api_fluid_layers_While` to create RNN network.
......@@ -1731,16 +1888,20 @@ def create_array(dtype, initialized_list=None):
if initialized_list is not None:
if not isinstance(initialized_list, (list, tuple)):
raise TypeError(
"Require type(initialized_list) should be list/tuple, but received {}"
.format(type(initialized_list)))
"Require type(initialized_list) should be list/tuple, but received {}".format(
type(initialized_list)
)
)
array = list(initialized_list)
# NOTE: Only support plain list like [x, y,...], not support nested list in static mode.
for val in array:
if not isinstance(val, Variable):
raise TypeError(
"All values in `initialized_list` should be Variable, but recevied {}."
.format(type(val)))
"All values in `initialized_list` should be Variable, but recevied {}.".format(
type(val)
)
)
if _non_static_mode():
return array
......@@ -1749,7 +1910,8 @@ def create_array(dtype, initialized_list=None):
tensor_array = helper.create_variable(
name="{0}.out".format(helper.name),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=dtype)
dtype=dtype,
)
for val in array:
array_write(x=val, i=array_length(tensor_array), array=tensor_array)
......@@ -1786,10 +1948,12 @@ def less_than(x, y, force_cpu=None, cond=None, name=None):
print(result) # [True, False, False, False]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_than")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "less_than"
)
check_variable_and_dtype(
y, "y", ["float32", "float64", "int32", "int64"], "less_than"
)
if cond is not None:
check_type(cond, "cond", Variable, "less_than")
if force_cpu != None:
......@@ -1804,13 +1968,12 @@ def less_than(x, y, force_cpu=None, cond=None, name=None):
if force_cpu is not None:
attrs['force_cpu'] = force_cpu
helper.append_op(type='less_than',
inputs={
'X': [x],
'Y': [y]
},
outputs={'Out': [cond]},
attrs=attrs)
helper.append_op(
type='less_than',
inputs={'X': [x], 'Y': [y]},
outputs={'Out': [cond]},
attrs=attrs,
)
return cond
......@@ -1818,13 +1981,13 @@ def less_than(x, y, force_cpu=None, cond=None, name=None):
def less_equal(x, y, cond=None, name=None):
"""
:alias_main: paddle.less_equal
:alias: paddle.less_equal,paddle.tensor.less_equal,paddle.tensor.logic.less_equal
:old_api: paddle.fluid.layers.less_equal
:alias: paddle.less_equal,paddle.tensor.less_equal,paddle.tensor.logic.less_equal
:old_api: paddle.fluid.layers.less_equal
This OP returns the truth value of :math:`x <= y` elementwise, which is equivalent function to the overloaded operator `<=`.
Args:
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *less_equal*.
if cond is None, a new Varibale will be created to store the result.
......@@ -1845,10 +2008,12 @@ def less_equal(x, y, cond=None, name=None):
out1 = label<= limit #out1=[True, False]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"less_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"less_equal")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "less_equal"
)
check_variable_and_dtype(
y, "y", ["float32", "float64", "int32", "int64"], "less_equal"
)
if cond is not None:
check_type(cond, "cond", Variable, "less_equal")
......@@ -1859,13 +2024,12 @@ def less_equal(x, y, cond=None, name=None):
attrs = dict()
helper.append_op(type='less_equal',
inputs={
'X': [x],
'Y': [y]
},
outputs={'Out': [cond]},
attrs=attrs)
helper.append_op(
type='less_equal',
inputs={'X': [x], 'Y': [y]},
outputs={'Out': [cond]},
attrs=attrs,
)
return cond
......@@ -1873,13 +2037,13 @@ def less_equal(x, y, cond=None, name=None):
def greater_than(x, y, cond=None, name=None):
"""
:alias_main: paddle.greater_than
:alias: paddle.greater_than,paddle.tensor.greater_than,paddle.tensor.logic.greater_than
:old_api: paddle.fluid.layers.greater_than
:alias: paddle.greater_than,paddle.tensor.greater_than,paddle.tensor.logic.greater_than
:old_api: paddle.fluid.layers.greater_than
This OP returns the truth value of :math:`x > y` elementwise, which is equivalent function to the overloaded operator `>`.
Args:
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *greater_than*.
if cond is None, a new Varibale will be created to store the result.
......@@ -1899,10 +2063,12 @@ def greater_than(x, y, cond=None, name=None):
out = fluid.layers.greater_than(x=label, y=limit) #out=[False, True]
out1 = label > limit #out1=[False, True]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"greater_than")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"greater_than")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "greater_than"
)
check_variable_and_dtype(
y, "y", ["float32", "float64", "int32", "int64"], "greater_than"
)
if cond is not None:
check_type(cond, "cond", Variable, "greater_than")
......@@ -1916,13 +2082,12 @@ def greater_than(x, y, cond=None, name=None):
if in_dygraph_mode():
return _C_ops.greater_than(x, y, -1)
else:
helper.append_op(type='greater_than',
inputs={
'X': [x],
'Y': [y]
},
outputs={'Out': [cond]},
attrs=attrs)
helper.append_op(
type='greater_than',
inputs={'X': [x], 'Y': [y]},
outputs={'Out': [cond]},
attrs=attrs,
)
return cond
......@@ -1930,13 +2095,13 @@ def greater_than(x, y, cond=None, name=None):
def greater_equal(x, y, cond=None, name=None):
"""
:alias_main: paddle.greater_equal
:alias: paddle.greater_equal,paddle.tensor.greater_equal,paddle.tensor.logic.greater_equal
:old_api: paddle.fluid.layers.greater_equal
:alias: paddle.greater_equal,paddle.tensor.greater_equal,paddle.tensor.logic.greater_equal
:old_api: paddle.fluid.layers.greater_equal
This OP returns the truth value of :math:`x >= y` elementwise, which is equivalent function to the overloaded operator `>=`.
Args:
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *greater_equal*.
if cond is None, a new Varibale will be created to store the result.
......@@ -1958,10 +2123,12 @@ def greater_equal(x, y, cond=None, name=None):
out_1 = label >= limit #out1=[True, False]
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"greater_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"greater_equal")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "greater_equal"
)
check_variable_and_dtype(
y, "y", ["float32", "float64", "int32", "int64"], "greater_equal"
)
if cond is not None:
check_type(cond, "cond", Variable, "greater_equal")
......@@ -1972,13 +2139,12 @@ def greater_equal(x, y, cond=None, name=None):
attrs = dict()
helper.append_op(type='greater_equal',
inputs={
'X': [x],
'Y': [y]
},
outputs={'Out': [cond]},
attrs=attrs)
helper.append_op(
type='greater_equal',
inputs={'X': [x], 'Y': [y]},
outputs={'Out': [cond]},
attrs=attrs,
)
return cond
......@@ -1989,7 +2155,7 @@ def equal(x, y, cond=None, name=None):
Args:
x(Variable): Tensor, data type is float32, float64, int32, int64.
y(Variable): Tensor, data type is float32, float64, int32, int64.
cond(Variable, optional): Optional output which can be any created
cond(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of *equal*.
if cond is None, a new Varibale will be created to store the result.
name(str, optional): The default value is None. Normally there is no need for
......@@ -2015,10 +2181,12 @@ def equal(x, y, cond=None, name=None):
default_axis = -1
return _C_ops.equal(x, y, default_axis)
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"equal")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "equal"
)
check_variable_and_dtype(
y, "y", ["float32", "float64", "int32", "int64"], "equal"
)
if cond is not None:
check_type(cond, "cond", Variable, "equal")
......@@ -2027,25 +2195,22 @@ def equal(x, y, cond=None, name=None):
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
helper.append_op(type='equal',
inputs={
'X': [x],
'Y': [y]
},
outputs={'Out': [cond]})
helper.append_op(
type='equal', inputs={'X': [x], 'Y': [y]}, outputs={'Out': [cond]}
)
return cond
def not_equal(x, y, cond=None, name=None):
"""
:alias_main: paddle.not_equal
:alias: paddle.not_equal,paddle.tensor.not_equal,paddle.tensor.logic.not_equal
:old_api: paddle.fluid.layers.not_equal
:alias: paddle.not_equal,paddle.tensor.not_equal,paddle.tensor.logic.not_equal
:old_api: paddle.fluid.layers.not_equal
This OP returns the truth value of :math:`x != y` elementwise, which is equivalent function to the overloaded operator `!=`.
Args:
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
x(Variable): First input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *not_equal*.
if cond is None, a new Varibale will be created to store the result.
......@@ -2059,15 +2224,17 @@ def not_equal(x, y, cond=None, name=None):
.. code-block:: python
import paddle.fluid as fluid
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
limit = fluid.layers.fill_constant(shape=[1], value=1, dtype='int64')
out = fluid.layers.not_equal(x=label, y=limit)
"""
check_variable_and_dtype(x, "x", ["float32", "float64", "int32", "int64"],
"not_equal")
check_variable_and_dtype(y, "y", ["float32", "float64", "int32", "int64"],
"not_equal")
check_variable_and_dtype(
x, "x", ["float32", "float64", "int32", "int64"], "not_equal"
)
check_variable_and_dtype(
y, "y", ["float32", "float64", "int32", "int64"], "not_equal"
)
if cond is not None:
check_type(cond, "cond", Variable, "not_equal")
......@@ -2076,20 +2243,17 @@ def not_equal(x, y, cond=None, name=None):
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
helper.append_op(type='not_equal',
inputs={
'X': [x],
'Y': [y]
},
outputs={'Out': [cond]})
helper.append_op(
type='not_equal', inputs={'X': [x], 'Y': [y]}, outputs={'Out': [cond]}
)
return cond
def array_read(array, i):
"""
This OP is used to read data at the specified position from the input array
This OP is used to read data at the specified position from the input array
:ref:`api_fluid_LoDTensorArray` . ``array`` is the input array and ``i``
is the specified read position. This OP is often used together with
is the specified read position. This OP is often used together with
:ref:`api_fluid_layers_array_write` OP.
Case 1:
......@@ -2142,14 +2306,14 @@ def array_read(array, i):
# the output is 2-D Tensor with shape [3,2].
# dtype is the corresponding C++ data type, which may vary in different environments.
# Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables.
"""
if _non_static_mode():
assert isinstance(
array,
list), "The 'array' in array_read must be list in dygraph mode"
array, list
), "The 'array' in array_read must be list in dygraph mode"
assert isinstance(
i, Variable
), "The index 'i' in array_read must be Variable in dygraph mode"
......@@ -2161,17 +2325,17 @@ def array_read(array, i):
check_variable_and_dtype(i, 'i', ['int64'], 'array_read')
helper = LayerHelper('array_read', **locals())
if not isinstance(
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if (
not isinstance(array, Variable)
or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
raise TypeError("array should be tensor array variable")
out = helper.create_variable_for_type_inference(dtype=array.dtype)
helper.append_op(type='read_from_array',
inputs={
'X': [array],
'I': [i]
},
outputs={'Out': [out]})
helper.append_op(
type='read_from_array',
inputs={'X': [array], 'I': [i]},
outputs={'Out': [out]},
)
return out
......@@ -2205,21 +2369,19 @@ def shrink_memory(x, i, table):
check_type(i, 'i', Variable, 'shrink_memory')
check_type(table, 'table', Variable, 'shrink_memory')
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='shrink_rnn_memory',
inputs={
'X': [x],
'I': [i],
'RankTable': [table]
},
outputs={'Out': [out]},
attrs={})
helper.append_op(
type='shrink_rnn_memory',
inputs={'X': [x], 'I': [i], 'RankTable': [table]},
outputs={'Out': [out]},
attrs={},
)
return out
def array_length(array):
"""
This OP is used to get the length of the input array :ref:`api_fluid_LoDTensorArray` .
It can be used together with :ref:`api_fluid_layers_array_read` , :ref:`api_fluid_layers_array_write` ,
:ref:`api_fluid_layers_While` OP to traverse, read and write LoDTensorArray.
Args:
......@@ -2253,33 +2415,35 @@ def array_length(array):
# shape: [1,]
# dtype: l
# data: 11,
# 1-D Tensor with shape [1], whose value is 11. It means that the length of LoDTensorArray
# is 11.
# dtype is the corresponding C++ data type, which may vary in different environments.
# Eg: if the data type of tensor is int64, then the corresponding C++ data type is int64_t,
# so the dtype value is typeid(int64_t).Name(), which is 'x' on MacOS, 'l' on Linux,
# and '__int64' on Windows. They both represent 64-bit integer variables.
"""
if _non_static_mode():
assert isinstance(
array,
list), "The 'array' in array_write must be a list in dygraph mode"
array, list
), "The 'array' in array_write must be a list in dygraph mode"
return len(array)
if not isinstance(
array,
Variable) or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY:
if (
not isinstance(array, Variable)
or array.type != core.VarDesc.VarType.LOD_TENSOR_ARRAY
):
raise TypeError(
"array should be tensor array variable in array_length Op")
"array should be tensor array variable in array_length Op"
)
helper = LayerHelper('array_length', **locals())
tmp = helper.create_variable_for_type_inference(dtype='int64')
tmp.stop_gradient = True
helper.append_op(type='lod_array_length',
inputs={'X': [array]},
outputs={'Out': [tmp]})
helper.append_op(
type='lod_array_length', inputs={'X': [array]}, outputs={'Out': [tmp]}
)
return tmp
......@@ -2301,8 +2465,9 @@ class ConditionalBlockGuard(BlockGuard):
def __exit__(self, exc_type, exc_val, exc_tb):
self.block.complete()
return super(ConditionalBlockGuard,
self).__exit__(exc_type, exc_val, exc_tb)
return super(ConditionalBlockGuard, self).__exit__(
exc_type, exc_val, exc_tb
)
class ConditionalBlock(object):
......@@ -2348,10 +2513,9 @@ class ConditionalBlock(object):
intermediate = set()
params = set()
params, intermediate = get_inputs_outputs_in_block(inside_block,
params,
intermediate,
helper=self.helper)
params, intermediate = get_inputs_outputs_in_block(
inside_block, params, intermediate, helper=self.helper
)
# Todo(liym27) Here assume that all params are in recursive parent block
# but when minimize() called in control flow, some params may be in
......@@ -2367,25 +2531,25 @@ class ConditionalBlock(object):
out_list.append(inner_var)
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
type=core.VarDesc.VarType.STEP_SCOPES
)
conditional_block_op = parent_block.append_op(
type='conditional_block',
inputs={
'Cond': self.inputs,
'Input': param_list,
},
outputs={
'Out': out_list,
'Scope': [step_scope]
},
outputs={'Out': out_list, 'Scope': [step_scope]},
attrs={
'sub_block': inside_block,
'is_scalar_condition': self.is_scalar_condition
})
'is_scalar_condition': self.is_scalar_condition,
},
)
if self.need_append_conditional_block_grad(inside_block):
self.append_conditional_block_grad(parent_block, inside_block,
conditional_block_op)
self.append_conditional_block_grad(
parent_block, inside_block, conditional_block_op
)
def need_append_conditional_block_grad(self, inside_block):
grad_sub_block_idx = inside_block.backward_block_idx
......@@ -2393,10 +2557,13 @@ class ConditionalBlock(object):
# if inside_block have grad_block and grad_block is not itself,
# we will append conditional block grad.
return grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
return (
grad_sub_block_idx != -1 and grad_sub_block_idx != inside_block_idx
)
def append_conditional_block_grad(self, parent_block, inside_block,
conditional_block_op):
def append_conditional_block_grad(
self, parent_block, inside_block, conditional_block_op
):
'''
Append op `conditional_block_grad` manually.
When `optimizer.minimize/append_backward` is called in Paddle control flow,
......@@ -2435,8 +2602,8 @@ class ConditionalBlock(object):
param_list.append(cpt.to_text(inner_var.name))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
conditional_block_op.desc, cpt.to_text(set()),
[grad_sub_block.desc])
conditional_block_op.desc, cpt.to_text(set()), [grad_sub_block.desc]
)
# append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
......@@ -2446,13 +2613,18 @@ class ConditionalBlock(object):
new_op_desc._set_attr(op_role_attr_name, backward)
# set input and output manually
new_op_desc.set_input('Input', param_list)
new_op_desc.set_output('Input@GRAD',
[param + "@GRAD" for param in param_list])
new_op_desc.set_output(
'Input@GRAD', [param + "@GRAD" for param in param_list]
)
new_vars = set()
for grad_var_name in new_op_desc.output_arg_names():
if grad_sub_block.desc.has_var_recursive(cpt.to_bytes(
grad_var_name)) or grad_var_name == core.empty_var_name():
if (
grad_sub_block.desc.has_var_recursive(
cpt.to_bytes(grad_var_name)
)
or grad_var_name == core.empty_var_name()
):
continue
grad_sub_block.desc.var(cpt.to_bytes(grad_var_name))
new_vars.add(grad_var_name)
......@@ -2475,16 +2647,20 @@ def copy_var_to_parent_block(var, layer_helper):
return var
prog = layer_helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0, "Got wrong parent block index when assigning var to parent scope in control_flow"
assert (
parent_idx >= 0
), "Got wrong parent block index when assigning var to parent scope in control_flow"
parent_block = prog.block(parent_idx)
if var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
and parent_block._find_var_recursive(var.name):
if (
var.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY
and parent_block._find_var_recursive(var.name)
):
parent_block_var = var
else:
parent_block_var = parent_block.create_var(dtype=var.dtype,
shape=var.shape,
type=var.type)
parent_block_var = parent_block.create_var(
dtype=var.dtype, shape=var.shape, type=var.type
)
assign(var, parent_block_var)
return parent_block_var
......@@ -2500,8 +2676,8 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
or both return ``None`` if the user doesn't want to return anything. A nest
structure of tensors in PaddlePaddle is tensor(s), or tuple of tensors, or
list of tensors.
Note:
1. The tuples or lists returned by ``true_fn`` and ``false_fn`` must have
the same shape because of dataflow model of PaddlePaddle while the
tensors in the tuples or the lists can have different shapes.
......@@ -2509,7 +2685,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
2. This API could be used under both static mode or dygraph mode. If it
is in dygraph mode, the API only runs one branch based on condition.
3. If it is in static mode, any tensors or operations created outside
or inside of ``true_fn`` and ``false_fn`` will be in net building
regardless of which branch is selected at runtime. This has frequently
surprised users who expected a lazy semantics. For example:
......@@ -2538,9 +2714,9 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
name(str, optional): The default value is ``None`` . Normally users
don't have to set this parameter. For more information, please
refer to :ref:`api_guide_Name` .
return_names(sequence of string, optional): The default value is ``None`` .
Normally users don't have to set this parameter. A sequence of strings
to represent the names of the returned vars. The structure of the sequence must
match the return values of true_fn and false_fn.
Returns:
......@@ -2586,7 +2762,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
# ret is a tuple containing 2 tensors
# ret[0] = [[1 1]]
# ret[1] = [[ True True True]
# [ True True True]]
"""
if _non_static_mode():
......@@ -2597,15 +2773,19 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
if true_fn is not None:
if not callable(true_fn):
raise TypeError(
"The true_fn in cond must be callable, but received {}".
format(type(true_fn).__name__))
"The true_fn in cond must be callable, but received {}".format(
type(true_fn).__name__
)
)
return true_fn()
else:
if false_fn is not None:
if not callable(false_fn):
raise TypeError(
"The false_fn in cond must be callable, but received {}"
.format(type(false_fn).__name__))
"The false_fn in cond must be callable, but received {}".format(
type(false_fn).__name__
)
)
return false_fn()
return None
......@@ -2619,25 +2799,32 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
if not callable(true_fn):
raise TypeError(
"The true_fn in cond must be callable, but received {}".format(
type(true_fn).__name__))
type(true_fn).__name__
)
)
true_cond_block = ConditionalBlock([pred], is_scalar_condition=True)
with true_cond_block.block():
origin_true_output = true_fn()
if origin_true_output is not None:
true_output = map_structure(copy_to_parent_func,
origin_true_output)
true_output = map_structure(
copy_to_parent_func, origin_true_output
)
if false_fn is not None:
if not callable(false_fn):
raise TypeError(
"The false_fn in cond must be callable, but received {}".format(
type(false_fn).__name__))
false_cond_block = ConditionalBlock([logical_not(pred)],
is_scalar_condition=True)
type(false_fn).__name__
)
)
false_cond_block = ConditionalBlock(
[logical_not(pred)], is_scalar_condition=True
)
with false_cond_block.block():
origin_false_output = false_fn()
if origin_false_output is not None:
false_output = map_structure(copy_to_parent_func,
origin_false_output)
false_output = map_structure(
copy_to_parent_func, origin_false_output
)
if true_output is None and false_output is None:
return None
......@@ -2645,48 +2832,108 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
if true_output is None:
raise ValueError(
"Incompatible return values of true_fn and false_fn in cond: "
"true_fn returns None while false_fn returns non-None")
"true_fn returns None while false_fn returns non-None"
)
if false_output is None:
raise ValueError(
"Incompatible return values of true_fn and false_fn in cond: "
"true_fn returns non-None while false_fn returns None")
"true_fn returns non-None while false_fn returns None"
)
# Merge true and false output if they are not None
if return_names is None:
return_names = ["no name"] * len(to_sequence(true_output))
is_dy2staic = False
return_names = ["no name"] * len(_to_sequence_except_dict(true_output))
else:
"""
dy2static will set the return_names and expand the return values to UndefinedVar.
"""
is_dy2staic = True
# TODO: expand_undefined_var will replace None to Undefinedvar(), to fix cases like:
# a = None
# if condition:
# a = 1
# Because we can not use variable to express 'None'
true_output, false_output = expand_undefined_var(
true_output, false_output, return_names)
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)
if len(to_sequence(true_output)) != len(to_sequence(false_output)):
true_output, false_output, return_names
)
if len(_to_sequence_except_dict(true_output)) != len(
_to_sequence_except_dict(false_output)
):
raise ValueError(
"true fn returns {} vars, but false fn returns {} vars, which is not equals"
.format(len(to_sequence(true_output)),
len(to_sequence(false_output))))
for true_out, false_out, return_name in zip(to_sequence(true_output),
to_sequence(false_output),
to_sequence(return_names)):
"true fn returns {} vars, but false fn returns {} vars, which is not equals".format(
len(_to_sequence_except_dict(true_output)),
len(_to_sequence_except_dict(false_output)),
)
)
for true_out, false_out, return_name in zip(
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(false_output),
_to_sequence_except_dict(return_names),
):
try:
assert_same_structure(true_out, false_out, check_types=False)
except ValueError as e:
raise ValueError(
"Incompatible return values of `{}` in true_fn and false_fn in cond: {}"
.format(return_name, e))
"Incompatible return values of `{}` in true_fn and false_fn in cond: {}".format(
return_name, e
)
)
def check_ret_none(seq_true, seq_false, seq_names):
for f_true, f_false, f_name in zip(seq_true, seq_false, seq_names):
f_true = flatten(f_true)
f_false = flatten(f_false)
for idx in range(len(f_true)):
if (
f_true[idx] is None
and f_false[idx] is not None
or f_false[idx] is None
and f_true[idx] is not None
):
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branches, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
f_name,
type(f_true[idx]),
f_true[idx],
type(f_false[idx]),
f_false[idx],
)
)
check_ret_none(
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(false_output),
_to_sequence_except_dict(return_names),
)
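# Illustrative note (not part of the original change): check_ret_none warns
# on mismatches such as the hypothetical branches
#     true_fn:  lambda: [x, [y, None]]
#     false_fn: lambda: [x, [None, z]]
# where x, y, z stand for any Tensors. Flattening each returned element shows
# that one branch sets part of a return value to None while the other branch
# returns a real value at the same position, so the warning above fires.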
if is_dy2staic:
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output
)
mask = cast(pred, dtype='int32')
merge_func = lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask, name)
merge_func = (
lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask, name
)
)
def merge_every_var_list(false_vars, true_vars, name):
return map_structure(partial(merge_func, name), false_vars, true_vars)
merged_output = list(
map(merge_every_var_list, to_sequence(false_output),
to_sequence(true_output), to_sequence(return_names)))
map(
merge_every_var_list,
_to_sequence_except_dict(false_output),
_to_sequence_except_dict(true_output),
_to_sequence_except_dict(return_names),
)
)
merged_output = pack_sequence_as(false_output, flatten(merged_output))
return merged_output
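# Illustrative sketch (not part of the original change; a minimal example
# under assumed usage, not the documented API surface): because dicts are
# treated as single objects by _to_sequence_except_dict below, both branches
# may return a dict with matching keys, e.g.
#
#     import paddle
#     import paddle.fluid as fluid
#     paddle.enable_static()
#     a = paddle.full([1], 1.0)
#     b = paddle.full([1], 2.0)
#     out = fluid.layers.cond(
#         a < b,
#         lambda: {'sum': a + b, 'diff': a - b},
#         lambda: {'sum': b, 'diff': a},
#     )
#     # `out` is a dict whose values are merged by select_input according to
#     # the runtime value of the predicate.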
......@@ -2695,7 +2942,8 @@ def change_none_to_undefinedvar(nest1, nest2):
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
def map_fn(x):
if x is None: return UndefinedVar("padding")
if x is None:
return UndefinedVar("padding")
return x
nest1_out = pack_sequence_as(nest1, list(map(map_fn, flatten(nest1))))
......@@ -2703,42 +2951,100 @@ def change_none_to_undefinedvar(nest1, nest2):
return nest1_out, nest2_out
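# Illustrative note (not part of the original change): for hypothetical
# Tensors x and y,
#     change_none_to_undefinedvar([x, None], [None, y])
# replaces every None on either side with UndefinedVar("padding") while
# leaving real variables untouched, so both nests keep the same structure.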
def _to_sequence_except_dict(x):
"""
In this function, dict is not viewed as sequence.
"""
if isinstance(x, dict):
return [x]
return to_sequence(x)
def _is_sequence_except_dict(x):
"""
In this function, dict is not viewed as sequence.
"""
if isinstance(x, dict):
return False
return is_sequence(x)
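# Illustrative note (not part of the original change): these helpers make the
# surrounding cond logic treat a dict return value as one object instead of
# flattening it, e.g.
#     _to_sequence_except_dict({'a': 1, 'b': 2})   # -> [{'a': 1, 'b': 2}]
#     _is_sequence_except_dict({'a': 1, 'b': 2})   # -> False
#     _to_sequence_except_dict((1, 2))             # falls back to to_sequence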
def expand_undefined_var(nest1, nest2, names):
"""TODO: make this function recursive.
nest1: Var1, (UndefinedVar, [1,2,3])
nest2: Var2, ([1,2,3,4], UndefinedVar)
In this case, we should not expand recursively.
"""
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_VALUE_PREFIX
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import (
RETURN_VALUE_PREFIX,
)
def pack_undefined_var_as(seq):
return pack_sequence_as(seq,
[UndefinedVar("padding") for i in flatten(seq)])
return pack_sequence_as(
seq, [UndefinedVar("padding") for i in flatten(seq)]
)
def map_fn(n1, n2, name):
if not name.startswith(RETURN_VALUE_PREFIX) and (isinstance(
n1, UndefinedVar) or n1 is None):
def map_fn(n1, n2, name, order):
if not name.startswith(RETURN_VALUE_PREFIX) and (
isinstance(n1, UndefinedVar) or n1 is None
):
if n1 is None and n2 is not None:
if order == 0:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branches, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
name, type(n1), n1, type(n2), n2
)
)
else:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branches, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
name, type(n2), n2, type(n1), n1
)
)
return pack_undefined_var_as(n2)
return n1
nest1_out = list(
map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names)))
map(
map_fn,
_to_sequence_except_dict(nest1),
_to_sequence_except_dict(nest2),
_to_sequence_except_dict(names),
[0 for i in _to_sequence_except_dict(names)],
)
)
nest2_out = list(
map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names)))
if not is_sequence(nest1): nest1_out = nest1_out[0]
if not is_sequence(nest2): nest2_out = nest2_out[0]
map(
map_fn,
_to_sequence_except_dict(nest2),
_to_sequence_except_dict(nest1),
_to_sequence_except_dict(names),
[1 for i in _to_sequence_except_dict(names)],
)
)
if not _is_sequence_except_dict(nest1):
nest1_out = nest1_out[0]
if not _is_sequence_except_dict(nest2):
nest2_out = nest2_out[0]
return nest1_out, nest2_out
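# Illustrative note (not part of the original change): assuming return names
# that do not start with the dy2static RETURN_VALUE_PREFIX, and hypothetical
# values a and b, a call like
#     expand_undefined_var((None, a), (b, a), ('x', 'y'))
# replaces the None on the first side with UndefinedVar("padding") placeholders
# shaped like the other branch's value and warns that the variable is set
# differently in the two branches; elements defined on both sides are returned
# as-is.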
def _error_message(what, arg_name, op_name, right_value, error_value):
error_message = "{what} of '{arg_name}' in {op_name} must be " \
error_message = (
"{what} of '{arg_name}' in {op_name} must be "
"{right_value}, but received: {error_value}.".format(
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value)
what=what,
arg_name=arg_name,
op_name=op_name,
right_value=right_value,
error_value=error_value,
)
)
return error_message
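# Illustrative note (not part of the original change): for example,
#     _error_message("The pred's type", "pred_fn_pairs", "case",
#                    "boolean Variable", type(1))
# produces "The pred's type of 'pred_fn_pairs' in case must be boolean
# Variable, but received: <class 'int'>."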
......@@ -2819,24 +3125,42 @@ def case(pred_fn_pairs, default=None, name=None):
for pred_fn in pred_fn_pairs:
if not isinstance(pred_fn, tuple):
raise TypeError(
_error_message("The elements' type", "pred_fn_pairs",
"case", tuple, type(pred_fn)))
_error_message(
"The elements' type",
"pred_fn_pairs",
"case",
tuple,
type(pred_fn),
)
)
if len(pred_fn) != 2:
raise TypeError(
_error_message("The tuple's size", "pred_fn_pairs", "case",
"2",
str(len(pred_fn)) + "-tuple"))
_error_message(
"The tuple's size",
"pred_fn_pairs",
"case",
"2",
str(len(pred_fn)) + "-tuple",
)
)
pred, fn = pred_fn
if not isinstance(pred, Variable):
raise TypeError(
_error_message("The pred's type", "pred_fn_pairs", "case",
"boolean Variable", type(pred)))
_error_message(
"The pred's type",
"pred_fn_pairs",
"case",
"boolean Variable",
type(pred),
)
)
if not callable(fn):
raise TypeError(
"The fn for {} of pred_fn_pairs in Op(case) must"
" be callable.".format(pred.name))
" be callable.".format(pred.name)
)
if default is None:
default_index = len(pred_fn_pairs) - 1 # pick the last one
......@@ -2862,11 +3186,11 @@ class Switch(object):
"""
:api_attr: Static Graph
This class is used to implement Switch branch control function.
Switch branch contains several case branches and one default branch.
Switch control flow checks whether the case branch conditions are satisfied in turn,
and only executes the statement after the first case branch that satisfies the conditions.
If there is no case branch that satisfies the condition,
only the statement following the default branch is executed.
Note:
......@@ -2875,7 +3199,7 @@ class Switch(object):
Member Functions:
case(condition): The case branch of Switch whose parameter cond is a scalar Variable of bool type. Only if the cond of the current case branch is True and the cond of the previous case branch is False, the statement after the case branch will be executed, and the statements after the other case branches will not be executed.
default(): The default branch of Switch. When cond of all case branches is False, the statement after default branch is executed.
Case and default functions can only be used inside the scope of Switch, as shown below:
......@@ -2897,7 +3221,7 @@ class Switch(object):
Examples:
.. code-block:: python
import paddle.fluid as fluid
lr = fluid.layers.create_global_var(
......@@ -2938,8 +3262,11 @@ class Switch(object):
raise ValueError("case should be called inside with")
check_variable_and_dtype(
condition, 'condition', ['bool'],
'the member function case of fluid.layers.Switch')
condition,
'condition',
['bool'],
'the member function case of fluid.layers.Switch',
)
if len(self.pre_not_conditions) == 0:
cond_block = ConditionalBlock([condition], is_scalar_condition=True)
......@@ -2948,12 +3275,14 @@ class Switch(object):
else:
pre_cond_num = len(self.pre_not_conditions)
pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
new_not_cond = logical_and(x=pre_not_cond,
y=logical_not(x=condition))
new_not_cond = logical_and(
x=pre_not_cond, y=logical_not(x=condition)
)
self.pre_not_conditions.append(new_not_cond)
cond_block = ConditionalBlock(
[logical_and(x=pre_not_cond, y=condition)],
is_scalar_condition=True)
is_scalar_condition=True,
)
return ConditionalBlockGuard(cond_block)
......@@ -2963,7 +3292,8 @@ class Switch(object):
raise ValueError("there should be at least one condition")
cond_block = ConditionalBlock(
[self.pre_not_conditions[pre_cond_num - 1]],
is_scalar_condition=True)
is_scalar_condition=True,
)
return ConditionalBlockGuard(cond_block)
def __enter__(self):
......@@ -2983,7 +3313,6 @@ class Switch(object):
class IfElseBlockGuard(object):
def __init__(self, is_true, ifelse):
if not isinstance(ifelse, IfElse):
raise TypeError("ifelse must be an instance of IfElse class")
......@@ -3004,7 +3333,11 @@ class IfElseBlockGuard(object):
self.cond_block = self.cond_block.block()
def __enter__(self):
self.ie.status = IfElse.IN_IF_ELSE_TRUE_BLOCKS if self.is_true else IfElse.IN_IF_ELSE_FALSE_BLOCKS
self.ie.status = (
IfElse.IN_IF_ELSE_TRUE_BLOCKS
if self.is_true
else IfElse.IN_IF_ELSE_FALSE_BLOCKS
)
self.cond_block.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
......@@ -3031,7 +3364,7 @@ class IfElse(object):
IfElse OP is different from other OPs in usage, which may confuse some users. Here is a simple example to illustrate this OP.
.. code-block:: python
# The following code completes the function: subtract 10 from the data greater than 0 in x, add 10 to the data less than 0 in x, and sum all the data.
import numpy as np
import paddle.fluid as fluid
......@@ -3041,7 +3374,7 @@ class IfElse(object):
x_d = np.array([[3], [1], [-2], [-3]]).astype(np.float32)
y_d = np.zeros((4, 1)).astype(np.float32)
# Compare the size of x, y pairs of elements, output cond, cond is shape [4, 1], data type bool 2-D tensor.
# Based on the input data x_d, y_d, it can be inferred that the data in cond are [[true], [true], [false], [false]].
cond = fluid.layers.greater_than(x, y)
......@@ -3060,7 +3393,7 @@ class IfElse(object):
ie.output(out_1)
# According to cond condition, the data processed in the two blocks are merged. The output here is output, the type is List, and the element type in List is Variable.
output = ie() # [array([[-7.], [-9.], [ 8.], [ 7.]], dtype=float32)]
# Get the first Variable in the output List and add all elements.
out = fluid.layers.reduce_sum(output[0])
......@@ -3070,7 +3403,7 @@ class IfElse(object):
res = exe.run(fluid.default_main_program(), feed={"x":x_d, "y":y_d}, fetch_list=[out])
print(res)
# [array([-1.], dtype=float32)]
Args:
cond (Variable): cond is a 2-D Tensor with shape [N, 1] and data type bool, representing the corresponding execution conditions of N input data. The data type is bool.
......@@ -3081,7 +3414,7 @@ class IfElse(object):
Internal Functions:
The block is constructed by calling the ``with ie. true_block()`` function in the object, and the computational logic under condition true is put into the block. If no corresponding block is constructed, the input data in the corresponding conditional dimension is unchanged.
The block is constructed by calling the ``with ie. false_block()`` function in the object, and the computational logic under condition false is put into the block. If no corresponding block is constructed, the input data in the corresponding conditional dimension is unchanged.
``Out = ie. input (x)`` will take out the data of the corresponding conditional dimension in X and put it into out, supporting the internal processing of multiple inputs in block.
......@@ -3091,6 +3424,7 @@ class IfElse(object):
There is a ``__call__()`` function inside the object, that is, by calling ``output = ie()``, all the outputs inside the True and False blocks are fused as the whole output; the output type is a list, and the type of each element in the list is Variable.
"""
OUT_IF_ELSE_BLOCKS = 0
IN_IF_ELSE_TRUE_BLOCKS = 1
IN_IF_ELSE_FALSE_BLOCKS = 2
......@@ -3112,24 +3446,27 @@ class IfElse(object):
if id(x) not in self.input_table:
parent_block = self._parent_block()
out_true = parent_block.create_var(
name=unique_name.generate_with_ignorable_key('ifelse_input' +
self.helper.name),
dtype=x.dtype)
name=unique_name.generate_with_ignorable_key(
'ifelse_input' + self.helper.name
),
dtype=x.dtype,
)
out_false = parent_block.create_var(
name=unique_name.generate_with_ignorable_key('ifelse_input' +
self.helper.name),
dtype=x.dtype)
parent_block.append_op(type='split_lod_tensor',
inputs={
'X': x,
'Mask': self.cond,
},
outputs={
'OutTrue': out_true,
'OutFalse': out_false
},
attrs={'level': 0})
name=unique_name.generate_with_ignorable_key(
'ifelse_input' + self.helper.name
),
dtype=x.dtype,
)
parent_block.append_op(
type='split_lod_tensor',
inputs={
'X': x,
'Mask': self.cond,
},
outputs={'OutTrue': out_true, 'OutFalse': out_false},
attrs={'level': 0},
)
self.input_table[id(x)] = (out_true, out_false)
else:
out_true, out_false = self.input_table[id(x)]
......@@ -3153,17 +3490,21 @@ class IfElse(object):
if self.status == self.OUT_IF_ELSE_BLOCKS:
raise ValueError("output can only be invoked in the sub-block")
out_table = self.output_table[1 if self.status ==
self.IN_IF_ELSE_TRUE_BLOCKS else 0]
out_table = self.output_table[
1 if self.status == self.IN_IF_ELSE_TRUE_BLOCKS else 0
]
parent_block = self._parent_block()
for each_out in outs:
check_type(each_out, "each output", Variable,
"fluid.layers.IfElse.output")
check_type(
each_out, "each output", Variable, "fluid.layers.IfElse.output"
)
# create outside tensor
outside_out = parent_block.create_var(
name=unique_name.generate_with_ignorable_key("_".join(
[self.helper.name, 'output'])),
dtype=each_out.dtype)
name=unique_name.generate_with_ignorable_key(
"_".join([self.helper.name, 'output'])
),
dtype=each_out.dtype,
)
out_table.append(outside_out)
# assign local var to outside
......@@ -3174,8 +3515,9 @@ class IfElse(object):
raise ValueError("IfElse::__call__ must be out of sub-block")
false_len, true_len = list(map(len, self.output_table))
if false_len == 0 and true_len == 0:
raise ValueError("Must invoke true_block/false_block before "
"__call__")
raise ValueError(
"Must invoke true_block/false_block before " "__call__"
)
elif false_len != true_len and false_len != 0 and true_len != 0:
raise ValueError("The output side must be same")
elif false_len == 0 or true_len == 0:
......@@ -3186,11 +3528,14 @@ class IfElse(object):
rlist = []
for false_var, true_var in zip(*self.output_table):
rlist.append(
merge_lod_tensor(in_true=true_var,
in_false=false_var,
mask=self.cond,
x=self.cond,
level=0))
merge_lod_tensor(
in_true=true_var,
in_false=false_var,
mask=self.cond,
x=self.cond,
level=0,
)
)
return rlist
......@@ -3261,6 +3606,7 @@ class DynamicRNN(object):
# Get RNN's result of the last time step
last = fluid.layers.sequence_last_step(out)
"""
BEFORE_RNN = 0
IN_RNN = 1
AFTER_RNN = 2
......@@ -3378,39 +3724,44 @@ class DynamicRNN(object):
if self.lod_rank_table is None:
self.lod_rank_table = parent_block.create_var(
name=unique_name.generate('lod_rank_table'),
type=core.VarDesc.VarType.LOD_RANK_TABLE)
type=core.VarDesc.VarType.LOD_RANK_TABLE,
)
self.lod_rank_table.stop_gradient = True
parent_block.append_op(type='lod_rank_table',
inputs={"X": x},
outputs={"Out": self.lod_rank_table},
attrs={"level": level})
parent_block.append_op(
type='lod_rank_table',
inputs={"X": x},
outputs={"Out": self.lod_rank_table},
attrs={"level": level},
)
self.max_seq_len = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_max_seq_len'),
dtype='int64')
dtype='int64',
)
self.max_seq_len.stop_gradient = False
parent_block.append_op(type='max_sequence_len',
inputs={'RankTable': self.lod_rank_table},
outputs={"Out": self.max_seq_len})
parent_block.append_op(
type='max_sequence_len',
inputs={'RankTable': self.lod_rank_table},
outputs={"Out": self.max_seq_len},
)
self.cond.stop_gradient = True
parent_block.append_op(type='less_than',
inputs={
'X': self.step_idx,
'Y': self.max_seq_len
},
outputs={'Out': self.cond},
attrs={'force_cpu': True})
parent_block.append_op(
type='less_than',
inputs={'X': self.step_idx, 'Y': self.max_seq_len},
outputs={'Out': self.cond},
attrs={'force_cpu': True},
)
input_array = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_input_array'),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=x.dtype)
dtype=x.dtype,
)
self.input_array.append((input_array, x.dtype))
parent_block.append_op(type='lod_tensor_to_array',
inputs={
'X': x,
'RankTable': self.lod_rank_table
},
outputs={'Out': input_array})
parent_block.append_op(
type='lod_tensor_to_array',
inputs={'X': x, 'RankTable': self.lod_rank_table},
outputs={'Out': input_array},
)
return array_read(array=input_array, i=self.step_idx)
def static_input(self, x):
......@@ -3543,18 +3894,19 @@ class DynamicRNN(object):
check_type(x, 'x', Variable, 'fluid.layers.DynamicRNN.static_input()')
if self.lod_rank_table is None:
raise RuntimeError(
"static_input() must be called after step_input().")
"static_input() must be called after step_input()."
)
parent_block = self._parent_block_()
x_reordered = parent_block.create_var(
name=unique_name.generate("dynamic_rnn_static_input_reordered"),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=x.dtype)
parent_block.append_op(type='reorder_lod_tensor_by_rank',
inputs={
'X': [x],
'RankTable': [self.lod_rank_table]
},
outputs={'Out': [x_reordered]})
dtype=x.dtype,
)
parent_block.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x], 'RankTable': [self.lod_rank_table]},
outputs={'Out': [x_reordered]},
)
return shrink_memory(x_reordered, self.step_idx, self.lod_rank_table)
@signature_safe_contextmanager
......@@ -3569,10 +3921,9 @@ class DynamicRNN(object):
"""
if self.status != DynamicRNN.BEFORE_RNN:
raise ValueError("rnn.block() can only be invoke once")
self.step_idx = fill_constant(shape=[1],
dtype='int64',
value=0,
force_cpu=True)
self.step_idx = fill_constant(
shape=[1], dtype='int64', value=0, force_cpu=True
)
self.step_idx.stop_gradient = False
self.status = DynamicRNN.IN_RNN
with self.while_op.block():
......@@ -3582,15 +3933,18 @@ class DynamicRNN(object):
for new_mem, mem_array in self.mem_link:
array_write(x=new_mem, i=self.step_idx, array=mem_array)
less_than(x=self.step_idx,
y=self.max_seq_len,
force_cpu=True,
cond=self.cond)
less_than(
x=self.step_idx,
y=self.max_seq_len,
force_cpu=True,
cond=self.cond,
)
self.status = DynamicRNN.AFTER_RNN
for each_array in self.output_array:
self.outputs.append(
array_to_lod_tensor(x=each_array, table=self.lod_rank_table))
array_to_lod_tensor(x=each_array, table=self.lod_rank_table)
)
def __call__(self, *args, **kwargs):
"""
......@@ -3606,19 +3960,25 @@ class DynamicRNN(object):
ValueError: When :code:`__call__()` is called before :code:`block()` .
"""
if self.status != DynamicRNN.AFTER_RNN:
raise ValueError(("Output of the dynamic RNN can only be visited "
"outside the rnn block."))
raise ValueError(
(
"Output of the dynamic RNN can only be visited "
"outside the rnn block."
)
)
if len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def memory(self,
init=None,
shape=None,
value=0.0,
need_reorder=False,
dtype='float32'):
def memory(
self,
init=None,
shape=None,
value=0.0,
need_reorder=False,
dtype='float32',
):
r"""
Create a memory Variable for DynamicRNN to deliver data across time steps.
It can be initialized by an existing Tensor or a constant Tensor of given
......@@ -3707,11 +4067,16 @@ class DynamicRNN(object):
self._assert_in_rnn_block_('memory')
self._init_zero_idx_()
if shape is not None:
check_type(shape, 'shape', (list, tuple),
'fluid.layers.DynamicRNN.memory()')
check_type(
shape,
'shape',
(list, tuple),
'fluid.layers.DynamicRNN.memory()',
)
if init is not None:
check_type(init, 'init', Variable,
'fluid.layers.DynamicRNN.memory()')
check_type(
init, 'init', Variable, 'fluid.layers.DynamicRNN.memory()'
)
parent_block = self._parent_block_()
init_tensor = init
if need_reorder == True:
......@@ -3719,32 +4084,36 @@ class DynamicRNN(object):
raise ValueError(
'If set need_reorder to True, make sure step_input be '
'invoked before '
'memory(init=init, need_reordered=True, ...).')
'memory(init=init, need_reordered=True, ...).'
)
init_reordered = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_mem_init_reordered'),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=init.dtype)
parent_block.append_op(type='reorder_lod_tensor_by_rank',
inputs={
'X': [init_tensor],
'RankTable': [self.lod_rank_table]
},
outputs={'Out': [init_reordered]})
dtype=init.dtype,
)
parent_block.append_op(
type='reorder_lod_tensor_by_rank',
inputs={
'X': [init_tensor],
'RankTable': [self.lod_rank_table],
},
outputs={'Out': [init_reordered]},
)
init_tensor = init_reordered
mem_array = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_mem_array'),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=init.dtype)
parent_block.append_op(type='write_to_array',
inputs={
'X': init_tensor,
'I': self.zero_idx
},
outputs={'Out': mem_array})
dtype=init.dtype,
)
parent_block.append_op(
type='write_to_array',
inputs={'X': init_tensor, 'I': self.zero_idx},
outputs={'Out': mem_array},
)
retv = array_read(array=mem_array, i=self.step_idx)
retv = shrink_memory(x=retv,
i=self.step_idx,
table=self.lod_rank_table)
retv = shrink_memory(
x=retv, i=self.step_idx, table=self.lod_rank_table
)
self.mem_dict[retv.name] = mem_array
return retv
else:
......@@ -3754,24 +4123,27 @@ class DynamicRNN(object):
)
parent_block = self._parent_block_()
init = parent_block.create_var(
name=unique_name.generate('mem_init'), dtype=dtype)
name=unique_name.generate('mem_init'), dtype=dtype
)
arr, dtype = self.input_array[0]
in0 = parent_block.create_var(name=unique_name.generate('in0'),
dtype=dtype)
parent_block.append_op(type='read_from_array',
inputs={
'X': [arr],
'I': [self.zero_idx]
},
outputs={'Out': [in0]})
parent_block.append_op(type='fill_constant_batch_size_like',
inputs={'Input': [in0]},
outputs={'Out': [init]},
attrs={
'shape': [-1] + shape,
'value': float(value),
'dtype': init.dtype
})
in0 = parent_block.create_var(
name=unique_name.generate('in0'), dtype=dtype
)
parent_block.append_op(
type='read_from_array',
inputs={'X': [arr], 'I': [self.zero_idx]},
outputs={'Out': [in0]},
)
parent_block.append_op(
type='fill_constant_batch_size_like',
inputs={'Input': [in0]},
outputs={'Out': [init]},
attrs={
'shape': [-1] + shape,
'value': float(value),
'dtype': init.dtype,
},
)
return self.memory(init=init)
def update_memory(self, ex_mem, new_mem):
......@@ -3785,7 +4157,7 @@ class DynamicRNN(object):
Returns:
None
Raises:
ValueError: When :code:`update_memory()` is called outside :code:`block()` .
TypeError: When :code:`ex_mem` or :code:`new_mem` is not a Variable.
......@@ -3793,10 +4165,18 @@ class DynamicRNN(object):
ValueError: When :code:`update_memory()` is called before :code:`step_input()` .
"""
self._assert_in_rnn_block_('update_memory')
check_type(ex_mem, 'ex_mem', Variable,
'fluid.layers.DynamicRNN.update_memory()')
check_type(new_mem, 'new_mem', Variable,
'fluid.layers.DynamicRNN.update_memory()')
check_type(
ex_mem,
'ex_mem',
Variable,
'fluid.layers.DynamicRNN.update_memory()',
)
check_type(
new_mem,
'new_mem',
Variable,
'fluid.layers.DynamicRNN.update_memory()',
)
mem_array = self.mem_dict.get(ex_mem.name, None)
if mem_array is None:
......@@ -3823,13 +4203,16 @@ class DynamicRNN(object):
self._assert_in_rnn_block_('output')
parent_block = self._parent_block_()
for each in outputs:
check_type(each, "outputs", Variable,
"fluid.layers.DynamicRNN.output")
check_type(
each, "outputs", Variable, "fluid.layers.DynamicRNN.output"
)
outside_array = parent_block.create_var(
name=unique_name.generate_with_ignorable_key("_".join(
[self.helper.name, "output_array", each.name])),
name=unique_name.generate_with_ignorable_key(
"_".join([self.helper.name, "output_array", each.name])
),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=each.dtype)
dtype=each.dtype,
)
array_write(x=each, i=self.step_idx, array=outside_array)
self.output_array.append(outside_array)
......@@ -3837,16 +4220,19 @@ class DynamicRNN(object):
if self.zero_idx is None:
parent_block = self._parent_block_()
self.zero_idx = parent_block.create_var(
name=unique_name.generate('zero_idx'), dtype='int64')
parent_block.append_op(type='fill_constant',
inputs={},
outputs={'Out': [self.zero_idx]},
attrs={
'shape': [1],
'dtype': self.zero_idx.dtype,
'value': float(0),
'force_cpu': True
})
name=unique_name.generate('zero_idx'), dtype='int64'
)
parent_block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [self.zero_idx]},
attrs={
'shape': [1],
'dtype': self.zero_idx.dtype,
'value': float(0),
'force_cpu': True,
},
)
def _parent_block_(self):
prog = self.helper.main_program
......@@ -3859,7 +4245,8 @@ class DynamicRNN(object):
def _assert_in_rnn_block_(self, method):
if self.status != DynamicRNN.IN_RNN:
raise ValueError(
"{0} can only be invoked inside rnn block.".format(method))
"{0} can only be invoked inside rnn block.".format(method)
)
def switch_case(branch_index, branch_fns, default=None, name=None):
......@@ -3936,44 +4323,71 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
def _check_args(branch_index, branch_fns, default):
check_variable_and_dtype(branch_index, 'branch_index',
['uint8', 'int32', 'int64'], 'switch_case')
check_variable_and_dtype(
branch_index,
'branch_index',
['uint8', 'int32', 'int64'],
'switch_case',
)
if convert_dtype(branch_index.dtype) != "int64":
branch_index = cast(branch_index, "int64")
check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')
branch_fns = branch_fns.items() if isinstance(branch_fns,
dict) else branch_fns
branch_fns = (
branch_fns.items() if isinstance(branch_fns, dict) else branch_fns
)
branch_fns = list(enumerate(branch_fns)) if all(
callable(fn) for fn in branch_fns) else branch_fns
branch_fns = (
list(enumerate(branch_fns))
if all(callable(fn) for fn in branch_fns)
else branch_fns
)
keys_of_fns = []
for index_fn_pair in branch_fns:
if not isinstance(index_fn_pair, tuple):
raise TypeError(
_error_message("The elements' type", "branch_fns",
"switch_case", tuple, type(branch_fns)))
_error_message(
"The elements' type",
"branch_fns",
"switch_case",
tuple,
type(branch_fns),
)
)
if len(index_fn_pair) != 2:
raise TypeError(
_error_message("The tuple's size", "branch_fns",
"switch_case", "2",
str(len(index_fn_pair)) + "-tuple"))
_error_message(
"The tuple's size",
"branch_fns",
"switch_case",
"2",
str(len(index_fn_pair)) + "-tuple",
)
)
key, fn = index_fn_pair
if not isinstance(key, int):
raise TypeError(
_error_message("The key's type", "branch_fns",
"switch_case", int, type(key)))
_error_message(
"The key's type",
"branch_fns",
"switch_case",
int,
type(key),
)
)
if key in keys_of_fns:
raise ValueError(
"The key in 'branch_fns' must be unique, but '{}' appears more than once."
.format(key))
"The key in 'branch_fns' must be unique, but '{}' appears more than once.".format(
key
)
)
else:
keys_of_fns.append(key)
......@@ -3981,7 +4395,12 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
raise TypeError(
_error_message(
"The type of function for key {}".format(key),
"branch_fns", "switch_case", "callable", type(fn)))
"branch_fns",
"switch_case",
"callable",
type(fn),
)
)
if default is None:
default = sorted(branch_fns)[-1][1]
......@@ -4014,7 +4433,7 @@ def reorder_lod_tensor_by_rank(x, rank_table):
Args:
x(${x_type}): ${x_comment}.
rank_table(${rank_table_type}): ${rank_table_comment}.
Returns:
out(${out_type}): ${out_comment}.
......@@ -4032,20 +4451,20 @@ def reorder_lod_tensor_by_rank(x, rank_table):
"""
check_type(x, 'x', (Variable), 'reorder_lod_tensor_by_rank')
check_type(rank_table, 'rank_table', (Variable),
'reorder_lod_tensor_by_rank')
check_type(
rank_table, 'rank_table', (Variable), 'reorder_lod_tensor_by_rank'
)
if rank_table.type != core.VarDesc.VarType.LOD_RANK_TABLE:
raise TypeError("The type of rank_table should be LOD_RANK_TABLE.")
helper = LayerHelper('reorder_lod_tensor_by_rank', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type='reorder_lod_tensor_by_rank',
inputs={
'X': [x],
'RankTable': [rank_table]
},
outputs={'Out': [out]})
helper.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x], 'RankTable': [rank_table]},
outputs={'Out': [out]},
)
return out
......@@ -4084,14 +4503,16 @@ def is_empty(x, name=None):
if _in_legacy_dygraph():
return _legacy_C_ops.is_empty(x)
check_variable_and_dtype(x, 'x', ['float32', 'float64', 'int32', 'int64'],
'is_empty')
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'int32', 'int64'], 'is_empty'
)
check_type(name, "name", (str, type(None)), "is_empty")
helper = LayerHelper("is_empty", **locals())
cond = helper.create_variable_for_type_inference(dtype='bool')
cond.stop_gradient = True
helper.append_op(type='is_empty',
inputs={'X': [x]},
outputs={'Out': [cond]})
helper.append_op(
type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]}
)
return cond
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import warnings
from paddle.fluid.dygraph.dygraph_to_static.program_translator import (
convert_to_static,
)
from paddle.fluid.layers.control_flow import cond
@paddle.jit.to_static
def fun1():
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
if a > b:
b = paddle.to_tensor(3)
else:
b = None
def true_fn():
return [paddle.to_tensor(1), [paddle.to_tensor(2), paddle.to_tensor(3)]]
def false_fn():
return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]]
class TestReturnNoneInIfelse(unittest.TestCase):
def test_dy2static_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fun1()
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message
):
flag = True
break
self.assertTrue(flag)
def test_cond_warning(self):
paddle.enable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
cond(a < b, true_fn, false_fn, return_names=['ret1', 'ret2'])
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message
):
flag = True
break
self.assertTrue(flag)
if __name__ == '__main__':
unittest.main()
......@@ -31,7 +31,6 @@ np.random.seed(123)
class TestCondInputOutput(unittest.TestCase):
def test_return_single_var(self):
"""
pseudocode:
......@@ -59,13 +58,16 @@ class TestCondInputOutput(unittest.TestCase):
out = layers.cond(pred, true_func, false_func)
# out is one tensor
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
ret, = exe.run(main_program, fetch_list=[out.name])
np.testing.assert_allclose(np.asarray(ret),
np.full((3, 2), -1, np.int32),
rtol=1e-05)
(ret,) = exe.run(main_program, fetch_list=[out.name])
np.testing.assert_allclose(
np.asarray(ret), np.full((3, 2), -1, np.int32), rtol=1e-05
)
def test_return_var_tuple(self):
"""
......@@ -80,18 +82,14 @@ class TestCondInputOutput(unittest.TestCase):
paddle.enable_static()
def true_func():
return layers.fill_constant(shape=[1, 2], dtype='int32',
value=1), layers.fill_constant(
shape=[2, 3],
dtype='bool',
value=True)
return layers.fill_constant(
shape=[1, 2], dtype='int32', value=1
), layers.fill_constant(shape=[2, 3], dtype='bool', value=True)
def false_func():
return layers.fill_constant(shape=[3, 4], dtype='float32',
value=3), layers.fill_constant(
shape=[4, 5],
dtype='int64',
value=2)
return layers.fill_constant(
shape=[3, 4], dtype='float32', value=3
), layers.fill_constant(shape=[4, 5], dtype='int64', value=2)
main_program = Program()
startup_program = Program()
......@@ -100,16 +98,19 @@ class TestCondInputOutput(unittest.TestCase):
out = layers.cond(pred, true_func, false_func)
# out is a tuple containing 2 tensors
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
np.testing.assert_allclose(np.asarray(ret[0]),
np.full((1, 2), 1, np.int32),
rtol=1e-05)
np.testing.assert_allclose(np.asarray(ret[1]),
np.full((2, 3), True, bool),
rtol=1e-05)
np.testing.assert_allclose(
np.asarray(ret[0]), np.full((1, 2), 1, np.int32), rtol=1e-05
)
np.testing.assert_allclose(
np.asarray(ret[1]), np.full((2, 3), True, bool), rtol=1e-05
)
def test_pass_and_modify_var(self):
"""
......@@ -137,20 +138,28 @@ class TestCondInputOutput(unittest.TestCase):
with program_guard(main_program, startup_program):
a = layers.fill_constant(shape=[3, 2, 1], dtype='int32', value=7)
i = fluid.data(name="i", shape=[1], dtype='int32')
pred = ((i % 2) == 0)
a = layers.cond(pred, lambda: true_func(a, i),
lambda: false_func(a, i))
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
pred = (i % 2) == 0
a = layers.cond(
pred, lambda: true_func(a, i), lambda: false_func(a, i)
)
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
for feed_i in range(5):
expected_a = 7 * (feed_i + 1) if feed_i % 2 == 0 else 8 - feed_i
ret, = exe.run(main_program,
feed={'i': np.full((1), feed_i, np.int32)},
fetch_list=[a])
np.testing.assert_allclose(np.asarray(ret),
np.full((3, 2, 1), expected_a, np.int32),
rtol=1e-05)
(ret,) = exe.run(
main_program,
feed={'i': np.full((1), feed_i, np.int32)},
fetch_list=[a],
)
np.testing.assert_allclose(
np.asarray(ret),
np.full((3, 2, 1), expected_a, np.int32),
rtol=1e-05,
)
def test_return_none(self):
"""
......@@ -174,12 +183,15 @@ class TestCondInputOutput(unittest.TestCase):
startup_program = Program()
with program_guard(main_program, startup_program):
i = fluid.data(name="i", shape=[1], dtype='int32')
pred = ((i % 2) == 0)
pred = (i % 2) == 0
out1 = layers.cond(pred, true_func, false_func)
out2 = layers.cond(pred, None, false_func)
out3 = layers.cond(pred, true_func, None)
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
for feed_i in range(5):
# Test that output is None is runnable
......@@ -202,17 +214,15 @@ class TestCondInputOutput(unittest.TestCase):
return layers.fill_constant(shape=[2, 7], dtype='int32', value=3)
def func_return_two_tensors():
return layers.fill_constant(shape=[3, 1], dtype='int32',
value=7), layers.fill_constant(
shape=[3, 1],
dtype='int32',
value=8)
return layers.fill_constant(
shape=[3, 1], dtype='int32', value=7
), layers.fill_constant(shape=[3, 1], dtype='int32', value=8)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
i = fluid.data(name="i", shape=[1], dtype='int32')
pred = ((i % 2) == 0)
pred = (i % 2) == 0
with self.assertRaises(TypeError):
out = layers.cond(pred, i, func_return_one_tensor)
......@@ -220,47 +230,57 @@ class TestCondInputOutput(unittest.TestCase):
out = layers.cond(pred, func_return_one_tensor, np.asarray([3]))
with self.assertRaises(Exception) as e:
out = layers.cond(pred, func_return_none,
func_return_one_tensor)
out = layers.cond(
pred, func_return_none, func_return_one_tensor
)
self.assertTrue(
"Incompatible return values of true_fn and false_fn in cond" in
str(e.exception))
"Incompatible return values of true_fn and false_fn in cond"
in str(e.exception)
)
with self.assertRaises(Exception) as e:
out = layers.cond(pred, func_return_two_tensors,
func_return_none)
out = layers.cond(
pred, func_return_two_tensors, func_return_none
)
self.assertTrue(
"Incompatible return values of true_fn and false_fn in cond" in
str(e.exception))
"Incompatible return values of true_fn and false_fn in cond"
in str(e.exception)
)
with self.assertRaises(Exception) as e:
out = layers.cond(pred, func_return_one_tensor,
func_return_two_tensors)
out = layers.cond(
pred, func_return_one_tensor, func_return_two_tensors
)
self.assertTrue(
"true fn returns 1 vars, but false fn returns 2 vars, which is not equals"
in str(e.exception))
in str(e.exception)
)
def test_extremely_simple_net_with_op_in_condition(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
a = fluid.layers.fill_constant(shape=[1],
dtype='float32',
value=1.23)
a = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.23
)
a.stop_gradient = False
b = fluid.layers.fill_constant(shape=[1],
dtype='float32',
value=1.25)
b = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.25
)
b.stop_gradient = False
out = layers.cond(a - b < -1.0, lambda: a, lambda: b)
append_backward(out)
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
ret = exe.run(main_program,
fetch_list=[out, b, a.grad_name, b.grad_name])
ret = exe.run(
main_program, fetch_list=[out, b, a.grad_name, b.grad_name]
)
# Note: fill_constant has loss of precision, you have to assertEqual
# with values that don't lose precision in floating-point numbers.
self.assertEqual(ret[0][0], ret[1][0])
......@@ -269,7 +289,6 @@ class TestCondInputOutput(unittest.TestCase):
class TestCondNestedControlFlow(unittest.TestCase):
def test_cond_inside_cond(self):
"""
pseudocode:
......@@ -277,7 +296,7 @@ class TestCondNestedControlFlow(unittest.TestCase):
a = 2 * i
if i < 5:
if i >= 3:
return a + a
else:
return a - a
else:
......@@ -290,25 +309,37 @@ class TestCondNestedControlFlow(unittest.TestCase):
paddle.enable_static()
def less_than_branch(i, a):
return layers.cond(i >= 3.0, lambda: layers.elementwise_add(a, a),
lambda: layers.elementwise_sub(a, a))
return layers.cond(
i >= 3.0,
lambda: layers.elementwise_add(a, a),
lambda: layers.elementwise_sub(a, a),
)
def greater_equal_branch(i, a):
return layers.cond(i < 8.0, lambda: layers.elementwise_mul(a, a),
lambda: layers.elementwise_div(a, a))
return layers.cond(
i < 8.0,
lambda: layers.elementwise_mul(a, a),
lambda: layers.elementwise_div(a, a),
)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
i = fluid.data(name="i", shape=[1], dtype='float32')
a = 2.0 * i
out = layers.cond(i < 5.0, lambda: less_than_branch(i, a),
lambda: greater_equal_branch(i, a))
out = layers.cond(
i < 5.0,
lambda: less_than_branch(i, a),
lambda: greater_equal_branch(i, a),
)
mean = paddle.mean(out)
append_backward(mean)
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
for feed_i in range(0, 10):
expected_a = 2.0 * feed_i
......@@ -318,9 +349,11 @@ class TestCondNestedControlFlow(unittest.TestCase):
else:
expected_ret = expected_a * expected_a if feed_i < 8 else 1.0
expected_a_grad = 2.0 * expected_a if feed_i < 8 else 0.0
ret = exe.run(main_program,
feed={'i': np.full((1), feed_i, np.float32)},
fetch_list=[out.name, a.grad_name])
ret = exe.run(
main_program,
feed={'i': np.full((1), feed_i, np.float32)},
fetch_list=[out.name, a.grad_name],
)
self.assertEqual(ret[0][0], expected_ret)
self.assertEqual(ret[1][0], expected_a_grad)
......@@ -330,24 +363,34 @@ class TestCondNestedControlFlow(unittest.TestCase):
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
a = fluid.layers.fill_constant(shape=[1],
dtype='float32',
value=1.23)
a = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.23
)
a.stop_gradient = False
b = fluid.layers.fill_constant(shape=[1],
dtype='float32',
value=1.24)
b = fluid.layers.fill_constant(
shape=[1], dtype='float32', value=1.24
)
b.stop_gradient = False
out = fluid.layers.cond(
a < b, lambda: fluid.layers.cond(
a - b < -1.0, lambda: fluid.layers.elementwise_add(a, b),
lambda: fluid.layers.elementwise_mul(a, b)), lambda:
fluid.layers.cond(a == b, lambda: fluid.layers.elementwise_sub(
a, b), lambda: fluid.layers.elementwise_pow(a, b)))
a < b,
lambda: fluid.layers.cond(
a - b < -1.0,
lambda: fluid.layers.elementwise_add(a, b),
lambda: fluid.layers.elementwise_mul(a, b),
),
lambda: fluid.layers.cond(
a == b,
lambda: fluid.layers.elementwise_sub(a, b),
lambda: fluid.layers.elementwise_pow(a, b),
),
)
append_backward(out)
place = fluid.CUDAPlace(
0) if core.is_compiled_with_cuda() else fluid.CPUPlace()
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=[out, a.grad_name, b.grad_name])
# Note: fill_constant has loss of precision, so we assertAlmostEqual.
......@@ -357,7 +400,6 @@ class TestCondNestedControlFlow(unittest.TestCase):
class TestCondBackward(unittest.TestCase):
def backward_value_helper(self, cond_func, use_cuda, use_parallel_exe):
"""
        Helper function that checks the computed backward value is close to the numerical dy/dx
......@@ -381,70 +423,76 @@ class TestCondBackward(unittest.TestCase):
num_devices = 1
if use_parallel_exe:
os.environ['CPU_NUM'] = str(2)
exe = fluid.ParallelExecutor(use_cuda=use_cuda,
main_program=main_program,
loss_name=loss.name)
exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=main_program,
loss_name=loss.name,
)
num_devices = exe.device_count
delta = 0.005
for feed_i in range(0, 10):
feed_img = np.random.random(size=[1, 9]).astype(np.float32)
feed_label = np.random.randint(low=0,
high=10,
size=[1, 1],
dtype=np.int64)
feed_label = np.random.randint(
low=0, high=10, size=[1, 1], dtype=np.int64
)
if use_parallel_exe:
img_grad, loss_value = exe.run(
feed={
'i': np.full((num_devices), feed_i, np.int32),
'image': np.repeat(feed_img, num_devices, axis=0),
'label': np.repeat(feed_label, num_devices, axis=0)
'label': np.repeat(feed_label, num_devices, axis=0),
},
fetch_list=[img.grad_name, loss.name])
fetch_list=[img.grad_name, loss.name],
)
else:
img_grad, loss_value = exe.run(
main_program,
feed={
'i': np.full((1), feed_i, np.int32),
'image': feed_img,
'label': feed_label
'label': feed_label,
},
fetch_list=[img.grad_name, loss.name])
fetch_list=[img.grad_name, loss.name],
)
numerical_grad = np.zeros(shape=[num_devices, 9], dtype=np.float32)
feed_img_delta = np.copy(feed_img)
for j in range(9):
feed_img_delta[0][j] = feed_img[0][j] + delta
if use_parallel_exe:
loss_delta = exe.run(feed={
'i':
np.full((num_devices), feed_i, np.int32),
'image':
np.repeat(feed_img_delta, num_devices, axis=0),
'label':
np.repeat(feed_label, num_devices, axis=0)
},
fetch_list=[loss.name])
multi_device_grad = (loss_delta[0] -
loss_value[0]) / delta / num_devices
loss_delta = exe.run(
feed={
'i': np.full((num_devices), feed_i, np.int32),
'image': np.repeat(
feed_img_delta, num_devices, axis=0
),
'label': np.repeat(feed_label, num_devices, axis=0),
},
fetch_list=[loss.name],
)
multi_device_grad = (
(loss_delta[0] - loss_value[0]) / delta / num_devices
)
for d in range(num_devices):
numerical_grad[d][j] = multi_device_grad[d]
else:
loss_delta = exe.run(main_program,
feed={
'i': np.full((1), feed_i,
np.int32),
'image': feed_img_delta,
'label': feed_label
},
fetch_list=[loss.name])
numerical_grad[0][j] = (loss_delta[0] -
loss_value[0]) / delta
loss_delta = exe.run(
main_program,
feed={
'i': np.full((1), feed_i, np.int32),
'image': feed_img_delta,
'label': feed_label,
},
fetch_list=[loss.name],
)
numerical_grad[0][j] = (
loss_delta[0] - loss_value[0]
) / delta
feed_img_delta[0][j] = feed_img[0][j]
np.testing.assert_allclose(img_grad,
numerical_grad,
rtol=0.05,
atol=0.05)
np.testing.assert_allclose(
img_grad, numerical_grad, rtol=0.05, atol=0.05
)
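
The helper above is a standard forward-difference gradient check: perturb one input element by delta, re-run the loss, and compare (loss_delta - loss) / delta against the analytic gradient. A self-contained sketch of the same idea, with a hypothetical function f and names not tied to Paddle, looks like this:

import numpy as np

def numeric_grad(f, x, delta=0.005):
    """Forward-difference estimate of df/dx, element by element."""
    grad = np.zeros_like(x)
    base = f(x)
    for j in range(x.size):
        perturbed = x.copy()
        perturbed.flat[j] += delta
        grad.flat[j] = (f(perturbed) - base) / delta
    return grad

# Example: f(x) = sum(x**2), whose analytic gradient is 2 * x.
x = np.random.random(9).astype(np.float32)
np.testing.assert_allclose(
    numeric_grad(lambda v: (v**2).sum(), x), 2 * x, rtol=0.05, atol=0.05
)
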
def add_optimizer_helper(self, cond_func, use_cuda, use_parallel_exe):
"""
......@@ -465,43 +513,49 @@ class TestCondBackward(unittest.TestCase):
exe.run(startup_program)
if use_parallel_exe:
os.environ['CPU_NUM'] = str(2)
exe = fluid.ParallelExecutor(use_cuda=use_cuda,
main_program=main_program,
loss_name=loss.name)
exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=main_program,
loss_name=loss.name,
)
num_devices = exe.device_count
for feed_i in range(0, 10):
feed_img = np.random.random(size=[16, 784]).astype(np.float32)
feed_label = np.random.randint(low=0,
high=10,
size=[16, 1],
dtype=np.int64)
feed_label = np.random.randint(
low=0, high=10, size=[16, 1], dtype=np.int64
)
if use_parallel_exe:
exe.run(feed={
'i': np.full((num_devices), feed_i, np.int32),
'image': np.repeat(feed_img, num_devices, axis=0),
'label': np.repeat(feed_label, num_devices, axis=0)
},
fetch_list=[loss.name])
exe.run(
feed={
'i': np.full((num_devices), feed_i, np.int32),
'image': np.repeat(feed_img, num_devices, axis=0),
'label': np.repeat(feed_label, num_devices, axis=0),
},
fetch_list=[loss.name],
)
else:
exe.run(main_program,
feed={
'i': np.full((1), feed_i, np.int32),
'image': feed_img,
'label': feed_label
},
fetch_list=[loss])
exe.run(
main_program,
feed={
'i': np.full((1), feed_i, np.int32),
'image': feed_img,
'label': feed_label,
},
fetch_list=[loss],
)
def test_cond_backward(self):
paddle.enable_static()
def cond_func(i, img, label):
predicate = ((i % 2) == 0)
predicate = (i % 2) == 0
return layers.cond(
predicate,
lambda: simple_fc_net_with_inputs(img, label, class_num=10),
lambda: batchnorm_fc_with_inputs(img, label, class_num=10))
lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
)
for use_parallel_exe in [False, True]:
if use_parallel_exe and os.name == "nt":
......@@ -510,10 +564,12 @@ class TestCondBackward(unittest.TestCase):
)
continue
self.backward_value_helper(cond_func, core.is_compiled_with_cuda(),
use_parallel_exe)
self.add_optimizer_helper(cond_func, core.is_compiled_with_cuda(),
use_parallel_exe)
self.backward_value_helper(
cond_func, core.is_compiled_with_cuda(), use_parallel_exe
)
self.add_optimizer_helper(
cond_func, core.is_compiled_with_cuda(), use_parallel_exe
)
def test_half_nested_cond_backward(self):
paddle.enable_static()
......@@ -522,15 +578,18 @@ class TestCondBackward(unittest.TestCase):
return layers.cond(
(i % 2) == 0,
lambda: simple_fc_net_with_inputs(img, label, class_num=10),
lambda: batchnorm_fc_with_inputs(img, label, class_num=10))
lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
)
def cond_func_simple_net_at_true(i, img, label):
return layers.cond(i < 5, lambda: branch(i, img, label),
lambda: paddle.mean(img))
return layers.cond(
i < 5, lambda: branch(i, img, label), lambda: paddle.mean(img)
)
def cond_func_simple_net_at_false(i, img, label):
return layers.cond(i < 5, lambda: paddle.mean(img),
lambda: branch(i, img, label))
return layers.cond(
i < 5, lambda: paddle.mean(img), lambda: branch(i, img, label)
)
for use_parallel_exe in [False, True]:
if use_parallel_exe and os.name == "nt":
......@@ -539,35 +598,47 @@ class TestCondBackward(unittest.TestCase):
)
continue
self.backward_value_helper(cond_func_simple_net_at_true,
core.is_compiled_with_cuda(),
use_parallel_exe)
self.add_optimizer_helper(cond_func_simple_net_at_true,
core.is_compiled_with_cuda(),
use_parallel_exe)
self.backward_value_helper(cond_func_simple_net_at_false,
core.is_compiled_with_cuda(),
use_parallel_exe)
self.add_optimizer_helper(cond_func_simple_net_at_false,
core.is_compiled_with_cuda(),
use_parallel_exe)
self.backward_value_helper(
cond_func_simple_net_at_true,
core.is_compiled_with_cuda(),
use_parallel_exe,
)
self.add_optimizer_helper(
cond_func_simple_net_at_true,
core.is_compiled_with_cuda(),
use_parallel_exe,
)
self.backward_value_helper(
cond_func_simple_net_at_false,
core.is_compiled_with_cuda(),
use_parallel_exe,
)
self.add_optimizer_helper(
cond_func_simple_net_at_false,
core.is_compiled_with_cuda(),
use_parallel_exe,
)
def test_nested_cond_backward(self):
paddle.enable_static()
def branch(i, img, label, mod_two):
if mod_two:
predicate = ((i % 2) == 0)
predicate = (i % 2) == 0
else:
predicate = ((i % 2) != 0)
predicate = (i % 2) != 0
return layers.cond(
predicate,
lambda: simple_fc_net_with_inputs(img, label, class_num=10),
lambda: batchnorm_fc_with_inputs(img, label, class_num=10))
lambda: batchnorm_fc_with_inputs(img, label, class_num=10),
)
def cond_func(i, img, label):
return layers.cond(i < 5, lambda: branch(i, img, label, True),
lambda: branch(i, img, label, False))
return layers.cond(
i < 5,
lambda: branch(i, img, label, True),
lambda: branch(i, img, label, False),
)
for use_parallel_exe in [False, True]:
if use_parallel_exe and os.name == "nt":
......@@ -575,14 +646,15 @@ class TestCondBackward(unittest.TestCase):
"Skip use_parallel_exe=True in Windows because of flaky test when using PE under old Windows machine"
)
continue
self.backward_value_helper(cond_func, core.is_compiled_with_cuda(),
use_parallel_exe)
self.add_optimizer_helper(cond_func, core.is_compiled_with_cuda(),
use_parallel_exe)
self.backward_value_helper(
cond_func, core.is_compiled_with_cuda(), use_parallel_exe
)
self.add_optimizer_helper(
cond_func, core.is_compiled_with_cuda(), use_parallel_exe
)
class TestCondWithError(unittest.TestCase):
def test_input_type_error(self):
paddle.enable_static()
main_program = framework.Program()
......@@ -606,5 +678,44 @@ class TestCondWithError(unittest.TestCase):
layers.cond(pred, func, func, set())
class TestCondWithDict(unittest.TestCase):
def test_input_with_dict(self):
paddle.enable_static()
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
def true_func():
return {
'1': paddle.full(shape=[3, 2], dtype='int32', fill_value=1),
'2': paddle.full(
shape=[2, 3], dtype='bool', fill_value=True
),
}
def false_func():
return {
'1': paddle.full(
shape=[3, 4], dtype='float32', fill_value=3
),
'2': paddle.full(shape=[4, 5], dtype='int64', fill_value=2),
}
x = paddle.full(shape=[1], dtype='float32', fill_value=0.1)
y = paddle.full(shape=[1], dtype='float32', fill_value=0.23)
pred = paddle.less_than(x=x, y=y, name=None)
ret = paddle.static.nn.cond(pred, true_func, false_func)
self.assertEqual(
ret['1'].shape,
(3, -1),
f"The shape is not correct, expects (3, -1) but gets {ret['1'].shape}.",
)
self.assertEqual(
ret['2'].shape,
(-1, -1),
f"The shape is not correct, expects (-1, -1) but gets {ret['2'].shape}.",
)
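
The -1 dims asserted above come from the shape inference that select_input_with_buildin_type now performs when the two branches disagree. A rough sketch of that rule, limited to same-rank shapes and not Paddle's actual implementation, is:

def merge_branch_shapes(true_shape, false_shape):
    # Dims that match are kept; dims that differ become -1 (dynamic).
    assert len(true_shape) == len(false_shape)
    return tuple(t if t == f else -1 for t, f in zip(true_shape, false_shape))

assert merge_branch_shapes((3, 2), (3, 4)) == (3, -1)   # ret['1']
assert merge_branch_shapes((2, 3), (4, 5)) == (-1, -1)  # ret['2']
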
if __name__ == '__main__':
unittest.main()