Unverified Commit 7c903ae7 authored by yunyaoXYY and committed by GitHub

[Clean fluid] Clean ones, reverse, save, save_combine, load_combine, has_inf, zeros_like and ones_like (#48424)

* Clean fluid ones

* clean ones_like

* clean zeros_like

* clean save,save_combine,load_combine

* clean reverse

* clean has_inf

* clean reverse tests
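
For reference, a minimal migration sketch of the paddle 2.x calls that replace the removed fluid APIs. The `paddle.save`/`paddle.load` pairing for the deleted `save`/`save_combine`/`load_combine` ops is an assumption; this commit only removes the fluid wrappers and does not name replacements.

```python
# Hypothetical fluid -> paddle migration sketch for the APIs removed here.
import paddle

x = paddle.ones(shape=[2, 4], dtype='float32')  # was fluid.layers.ones
z = paddle.zeros_like(x)                        # was fluid.layers.zeros_like
o = paddle.ones_like(x)                         # was fluid.layers.ones_like
r = paddle.reverse(x, axis=[0])                 # was fluid.layers.reverse

# paddle.isinf is elementwise; .any() reduces it to the single bool
# that fluid.layers.has_inf used to return.
any_inf = paddle.isinf(x).any()                 # was fluid.layers.has_inf

# Assumed replacements for the removed save/save_combine/load_combine ops:
paddle.save(x, 'x.pdtensor')
loaded = paddle.load('x.pdtensor')
```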
Parent 2de881aa
......@@ -407,15 +407,15 @@ def basic_gru(
)
if bidirectional:
bw_input = layers.reverse(input, axis=[0])
bw_input = paddle.reverse(input, axis=[0])
bw_mask = None
if mask:
bw_mask = layers.reverse(mask, axis=[0])
bw_mask = paddle.reverse(mask, axis=[0])
bw_rnn_out, bw_last_hidden = get_single_direction_output(
bw_input, bw_unit_list, bw_mask, direc_index=1
)
bw_rnn_out = layers.reverse(bw_rnn_out, axis=[0])
bw_rnn_out = paddle.reverse(bw_rnn_out, axis=[0])
rnn_out = layers.concat([fw_rnn_out, bw_rnn_out], axis=2)
last_hidden = layers.concat([fw_last_hidden, bw_last_hidden], axis=1)
......@@ -718,15 +718,15 @@ def basic_lstm(
)
if bidirectional:
bw_input = layers.reverse(input, axis=[0])
bw_input = paddle.reverse(input, axis=[0])
bw_mask = None
if mask:
bw_mask = layers.reverse(mask, axis=[0])
bw_mask = paddle.reverse(mask, axis=[0])
bw_rnn_out, bw_last_hidden, bw_last_cell = get_single_direction_output(
bw_input, bw_unit_list, bw_mask, direc_index=1
)
bw_rnn_out = layers.reverse(bw_rnn_out, axis=[0])
bw_rnn_out = paddle.reverse(bw_rnn_out, axis=[0])
rnn_out = layers.concat([fw_rnn_out, bw_rnn_out], axis=2)
last_hidden = layers.concat([fw_last_hidden, bw_last_hidden], axis=1)
......
......@@ -659,9 +659,9 @@ class MultivariateNormalDiag(Distribution):
def _det(self, value):
batch_shape = list(value.shape)
one_all = tensor.ones(shape=batch_shape, dtype=self.loc.dtype)
one_all = paddle.ones(shape=batch_shape, dtype=self.loc.dtype)
one_diag = tensor.diag(
tensor.ones(shape=[batch_shape[0]], dtype=self.loc.dtype)
paddle.ones(shape=[batch_shape[0]], dtype=self.loc.dtype)
)
det_diag = paddle.prod(value + one_all - one_diag)
......@@ -670,9 +670,9 @@ class MultivariateNormalDiag(Distribution):
def _inv(self, value):
batch_shape = list(value.shape)
one_all = tensor.ones(shape=batch_shape, dtype=self.loc.dtype)
one_all = paddle.ones(shape=batch_shape, dtype=self.loc.dtype)
one_diag = tensor.diag(
tensor.ones(shape=[batch_shape[0]], dtype=self.loc.dtype)
paddle.ones(shape=[batch_shape[0]], dtype=self.loc.dtype)
)
inv_diag = paddle.pow(value, (one_all - 2 * one_diag))
......
......@@ -595,9 +595,9 @@ def _rnn_dynamic_graph(
mask = paddle.transpose(mask, [1, 0])
if is_reverse:
inputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), inputs)
inputs = map_structure(lambda x: paddle.reverse(x, axis=[0]), inputs)
mask = (
tensor.reverse(mask, axis=[0])
paddle.reverse(mask, axis=[0])
if sequence_length is not None
else None
)
......@@ -626,7 +626,7 @@ def _rnn_dynamic_graph(
if is_reverse:
final_outputs = map_structure(
lambda x: tensor.reverse(x, axis=time_step_index), final_outputs
lambda x: paddle.reverse(x, axis=time_step_index), final_outputs
)
final_states = new_states
......@@ -681,8 +681,8 @@ def _rnn_static_graph(
)
mask = paddle.transpose(mask, [1, 0])
if is_reverse:
inputs = map_structure(lambda x: tensor.reverse(x, axis=[0]), inputs)
mask = tensor.reverse(mask, axis=[0]) if sequence_length else None
inputs = map_structure(lambda x: paddle.reverse(x, axis=[0]), inputs)
mask = paddle.reverse(mask, axis=[0]) if sequence_length else None
# StaticRNN
rnn = control_flow.StaticRNN()
......@@ -711,7 +711,7 @@ def _rnn_static_graph(
if is_reverse:
final_outputs = map_structure(
lambda x: tensor.reverse(x, axis=[0]), final_outputs
lambda x: paddle.reverse(x, axis=[0]), final_outputs
)
if not time_major:
......@@ -1251,7 +1251,7 @@ class BeamSearchDecoder(Decoder):
value=False,
force_cpu=True,
)
init_lengths = tensor.zeros_like(init_inputs)
init_lengths = paddle.zeros_like(init_inputs)
init_inputs = (
self.embedding_fn(init_inputs) if self.embedding_fn else init_inputs
)
......@@ -1482,7 +1482,7 @@ def _dynamic_decode_imperative(
initial_finished,
)
cond = paddle.logical_not((nn.reduce_all(initial_finished)))
sequence_lengths = tensor.cast(tensor.zeros_like(initial_finished), "int64")
sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64")
outputs = None
step_idx = 0
......@@ -1596,7 +1596,7 @@ def _dynamic_decode_declarative(
)
while_op = control_flow.While(cond, is_test=is_test)
sequence_lengths = tensor.cast(tensor.zeros_like(initial_finished), "int64")
sequence_lengths = tensor.cast(paddle.zeros_like(initial_finished), "int64")
sequence_lengths.stop_gradient = True
if is_test:
......
......@@ -60,13 +60,8 @@ __all__ = [
'argmin',
'argmax',
'argsort',
'ones',
'zeros',
'reverse',
'has_inf',
'linspace',
'zeros_like',
'ones_like',
'diag',
]
......@@ -1324,35 +1319,6 @@ def argsort(input, axis=-1, descending=False, name=None):
return out, ids
def ones(shape, dtype, force_cpu=False):
"""
The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 1.
Its :attr:`stop_gradient` will be set to True to stop gradient computation.
Parameters:
shape(tuple|list|Tensor): Shape of output Tensor, the data type of shape is int32 or int64.
dtype (np.dtype|str): Data type of output Tensor, it supports
bool, float16, float32, float64, int32 and int64.
force_cpu (bool, optional): Whether to force the output Tensor to be stored in CPU memory.
If :attr:`force_cpu` is False, the output Tensor is stored on the running device.
Default: False.
Returns:
Tensor: A tensor of data type :attr:`dtype` with shape :attr:`shape` and all elements set to 1.
Examples:
.. code-block:: python
import paddle.fluid as fluid
data0 = fluid.layers.ones(shape=[2, 4], dtype='float32') # [[1., 1., 1., 1.], [1., 1., 1., 1.]]
# shape is a Tensor
shape = fluid.layers.fill_constant(shape=[2], dtype='int32', value=2)
data1 = fluid.layers.ones(shape=shape, dtype='int32') #[[1, 1], [1, 1]]
"""
return fill_constant(value=1.0, **locals())
def zeros(shape, dtype, force_cpu=False, name=None):
"""
The OP creates a tensor of specified :attr:`shape` and :attr:`dtype`, and fills it with 0.
......@@ -1384,190 +1350,6 @@ def zeros(shape, dtype, force_cpu=False, name=None):
return fill_constant(value=0.0, **locals())
def reverse(x, axis):
"""
:alias_main: paddle.reverse
:alias: paddle.reverse,paddle.tensor.reverse,paddle.tensor.manipulation.reverse
:old_api: paddle.fluid.layers.reverse
The OP reverses the tensor :attr:`x` along the given :attr:`axis`.
.. code-block:: text
Case 1:
Given a LoDTensor:
x = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
axis = [0, 1]
Then:
output = [[8, 7, 6], [5, 4, 3], [2, 1, 0]]
Case 2:
Given a LoDTensorArray:
x = {[[0, 1], [2, 3]],
[[4, 5, 6]],
[[7],[8], [9]]}
axis = 0
Then:
output = {[[7],[8], [9]],
[[4, 5, 6]],
[[0, 1], [2, 3]]}
Parameters:
x (Variable): A tensor or LoDTensorArray to be reversed, its data type supports bool, float32, float64, int32, int64 and uint8.
If input is a LoDTensorArray, returns a new reversed LoDTensorArray without changing the internal order of each inner tensor.
axis (int|tuple|list): A dimension or a set of dimensions of :attr:`x` to reverse. Must be
in the range [-rank( :attr:`x` ), rank( :attr:`x` )). If it is a tuple or a list, reversing
is applied to each axis in the tuple or list. If the input is a LoDTensorArray, axis must be 0,
or a list [0] or tuple (0,) with shape [1].
Returns:
Variable: The reversed tensor with the same shape and data type as :attr:`x`.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import numpy as np
data = fluid.layers.assign(np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]], dtype='float32')) # [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]]
result1 = fluid.layers.reverse(data, 0) # [[6., 7., 8.], [3., 4., 5.], [0., 1., 2.]]
result2 = fluid.layers.reverse(data, [0, 1]) # [[8., 7., 6.], [5., 4., 3.], [2., 1., 0.]]
# example of LoDTensorArray
data1 = fluid.layers.assign(np.array([[0, 1, 2]], dtype='float32'))
data2 = fluid.layers.assign(np.array([[3, 4, 5]], dtype='float32'))
tensor_array = fluid.layers.create_array(dtype='float32')
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
fluid.layers.array_write(data1, i, tensor_array)
fluid.layers.array_write(data2, i+1, tensor_array)
reversed_tensor_array = fluid.layers.reverse(tensor_array, 0) # {[[3, 4, 5]], [[0, 1, 2]]}
"""
check_variable_and_dtype(
x, 'x', ('float32', 'float64', 'int32', 'int64', 'uint8'), 'reverse'
)
check_type(axis, 'axis', (int, tuple, list, Variable), 'reverse')
if isinstance(axis, int):
axis = [axis]
if in_dygraph_mode():
return _C_ops.reverse(x, axis)
helper = LayerHelper("reverse", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='reverse',
inputs={'X': x},
outputs={'Out': [out]},
attrs={'axis': axis},
)
return out
def save(x, file_path, overwrite=True):
"""
Saves a variable as a file.
Args:
x(variable): The Tensor/LoDTensor to be saved.
file_path(str): The file path where the variable will be saved.
overwrite(bool): Whether to overwrite the given file if it already
exists. If set to False and the file already exists, a runtime
error will be raised.
"""
helper = LayerHelper("save", **locals())
helper.append_op(
type="save",
inputs={"input": x},
outputs={},
args={"file_path": file_path, "overwrite": overwrite},
)
def save_combine(x, file_path, overwrite=True):
"""
Saves a list of variables into a single file.
Args:
x(list): A list of Tensor/LoDTensor variables to be saved together in
a single file.
file_path(str): The file path where variables will be saved.
overwrite(bool): Whether to overwrite the given file if it already
exists. If set to False and the file already exists, a runtime
error will be raised.
Returns:
There is no return value.
Examples:
.. code-block:: python
import paddle.fluid as fluid
v1 = fluid.layers.data(name="data",
shape=(4, 6),
dtype="float32")
v2 = fluid.layers.data(name="data",
shape=(6, 8, 4),
dtype="float32")
normed = fluid.layers.save_combine([v1, v2], file_path="output")
"""
helper = LayerHelper("save_combine", **locals())
helper.append_op(
type="save_combine",
inputs={"input": x},
outputs={},
args={"file_path": file_path, "overwrite": overwrite},
)
def load_combine(out, file_path):
"""
Loads a list of variables from a single file.
Args:
out(list): The list of variables to be read from the disk file.
file_path(str): The path of the disk file.
"""
helper = LayerHelper("load_combine", **locals())
helper.append_op(
type="load_combine",
inputs={},
output={"Out": out},
args={"file_path": file_path},
)
def has_inf(x):
"""
Test whether any element of x is infinite.
Args:
x (Tensor): The Tensor to be checked.
Returns:
Tensor: A tensor holding a single bool value that indicates whether x contains any infinite elements.
Examples:
.. code-block:: python
import paddle
data = paddle.randn(shape=[4, 32, 32], dtype="float32")
res = paddle.fluid.layers.has_inf(data)
# [False]
"""
if _non_static_mode():
return _legacy_C_ops.isinf(x)
check_type(x, 'x', (Variable), 'has_inf')
helper = LayerHelper("isinf", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="isinf", inputs={"X": x}, outputs={"Out": out})
return out
def linspace(start, stop, num, dtype=None, name=None):
r"""
This OP return fixed number of evenly spaced values within a given interval.
......@@ -1683,55 +1465,6 @@ def linspace(start, stop, num, dtype=None, name=None):
return out
def zeros_like(x, out=None):
"""
This OP creates a tensor of zeros with the same shape and dtype
as `x`.
Args:
x(Variable): The input tensor which specifies shape and dtype, the
input data dtype could be bool, float32, float64, int32, int64.
out(Variable, optional): If is :attr:`None` , the op will create the
variable as output, the data type and shape of this variable will
be same as input :attr:`x`. If is a tensor, the data type and shape
need to be same as input :attr:`x`. The default value is :attr:`None` .
Returns:
Variable: An N-D tensor whose element values depend on the input
data type: if the input data type is bool, the output values are
False; otherwise they are zero. The output shape is the same as the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data(name='x', dtype='float32', shape=[3])
data = fluid.layers.zeros_like(x) # [0.0, 0.0, 0.0]
"""
check_variable_and_dtype(
x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'zeros_like'
)
helper = LayerHelper("zeros_like", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
check_variable_and_dtype(
out,
"out",
['bool', 'float32', 'float64', 'int32', 'int64'],
'zeros_like',
)
helper.append_op(
type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 0, "dtype": x.dtype},
outputs={'Out': [out]},
)
out.stop_gradient = True
return out
@deprecated(since="2.0.0", update_to="paddle.diag")
def diag(diagonal):
r"""
......@@ -1783,49 +1516,3 @@ def diag(diagonal):
out.stop_gradient = True
return out
def ones_like(x, out=None):
"""
**ones_like**
This function creates a tensor of ones with the same shape and dtype
as `x`.
Args:
x(Variable): The input tensor which specifies shape and dtype.
out(Variable): The output tensor.
Returns:
out(Variable): The tensor variable storing the output.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', dtype='float32', shape=[3], append_batch_size=False)
data = fluid.layers.ones_like(x) # [1.0, 1.0, 1.0]
"""
check_variable_and_dtype(
x, "x", ['bool', 'float32', 'float64', 'int32', 'int64'], 'ones_like'
)
helper = LayerHelper("ones_like", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
else:
check_variable_and_dtype(
out,
"out",
['bool', 'float32', 'float64', 'int32', 'int64'],
'ones_like',
)
helper.append_op(
type='fill_any_like',
inputs={'X': [x]},
attrs={'value': 1.0},
outputs={'Out': [out]},
)
return out
......@@ -183,7 +183,7 @@ def test_list_pop_in_for_loop(x, iter_num):
a.append(x + i)
b.append(x * 2)
one = fluid.layers.ones(shape=[1], dtype="int32")
one = paddle.ones(shape=[1], dtype="int32")
for i in range(one.numpy()[0]):
item = a.pop()
return a[0], item, b[1]
......
......@@ -53,7 +53,7 @@ class TestWhileOp(unittest.TestCase):
array_len = layers.fill_constant(shape=[1], dtype='int32', value=5)
array_len = layers.cast(array_len, 'int64')
array_len.stop_gradient = True
cond = layers.ones(shape=[1], dtype='int32')
cond = paddle.ones(shape=[1], dtype='int32')
cond = layers.cast(cond, 'bool')
j = layers.fill_constant(shape=[1], dtype='int32', value=1)
j = layers.cast(j, 'int64')
......@@ -62,7 +62,7 @@ class TestWhileOp(unittest.TestCase):
array_len2 = layers.cast(array_len2, 'int64')
array_len2.stop_gradient = True
cond2 = paddle.logical_or(x=j, y=array_len2)
cond2 = layers.ones(shape=[1], dtype='int32')
cond2 = paddle.ones(shape=[1], dtype='int32')
cond2 = layers.cast(cond2, 'bool')
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
......
......@@ -209,7 +209,7 @@ class TestCloneWithStopGradientInSubBlock(unittest.TestCase):
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
img = fluid.layers.data(name='image', shape=[784])
true = fluid.layers.ones(shape=[1], dtype="float32")
true = paddle.ones(shape=[1], dtype="float32")
hidden1 = fluid.layers.fc(input=img, size=200, act='relu')
hidden1.stop_gradient = True
......@@ -250,7 +250,7 @@ class TestCloneWithRaise(unittest.TestCase):
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
img = fluid.layers.data(name='image', shape=[784])
true = fluid.layers.ones(shape=[1], dtype="float32")
true = paddle.ones(shape=[1], dtype="float32")
hidden1 = fluid.layers.fc(input=img, size=200, act='relu')
hidden1.stop_gradient = True
......
......@@ -13,7 +13,7 @@
# limitations under the License.
import unittest
import paddle
import numpy as np
from op_test import OpTest, skip_check_grad_ci
......@@ -72,8 +72,8 @@ class TestExecutorReturnTensorNotOverOverwritingWithLayers(unittest.TestCase):
pass
def calc_add_out(self, place=None, parallel=None):
x = fluid.layers.ones(shape=[3, 3], dtype='float32')
y = fluid.layers.ones(shape=[3, 3], dtype='float32')
x = paddle.ones(shape=[3, 3], dtype='float32')
y = paddle.ones(shape=[3, 3], dtype='float32')
out = fluid.layers.elementwise_add(x=x, y=y)
program = fluid.default_main_program()
if parallel:
......@@ -85,8 +85,8 @@ class TestExecutorReturnTensorNotOverOverwritingWithLayers(unittest.TestCase):
return out
def calc_sub_out(self, place=None, parallel=None):
x = fluid.layers.ones(shape=[2, 2], dtype='float32')
y = fluid.layers.ones(shape=[2, 2], dtype='float32')
x = paddle.ones(shape=[2, 2], dtype='float32')
y = paddle.ones(shape=[2, 2], dtype='float32')
out = fluid.layers.elementwise_sub(x=x, y=y)
program = fluid.default_main_program()
if parallel:
......
......@@ -16,8 +16,6 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid.framework import convert_np_dtype_to_dtype_
......@@ -47,36 +45,5 @@ class TestFillZerosLike2OpFp64(TestFillZerosLike2Op):
self.dtype = np.float64
class TestZerosError(unittest.TestCase):
def test_errors(self):
def test_zeros_like_type_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
fluid.layers.zeros_like([10], dtype="float")
self.assertRaises(TypeError, test_zeros_like_type_error)
def test_zeros_like_dtype_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float16")
fluid.layers.zeros_like(data, dtype="float32")
self.assertRaises(TypeError, test_zeros_like_dtype_error)
def test_zeros_like_out_type_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
fluid.layers.zeros_like(data, dtype="float32", out=[10])
self.assertRaises(TypeError, test_zeros_like_out_type_error)
def test_zeros_like_out_dtype_error():
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data(name="data", shape=[10], dtype="float32")
out = fluid.data(name="out", shape=[10], dtype="float16")
fluid.layers.zeros_like(data, dtype="float32", out=out)
self.assertRaises(TypeError, test_zeros_like_out_dtype_error)
if __name__ == "__main__":
unittest.main()
......@@ -636,7 +636,7 @@ class TestDygraphDoubleGradVisitedUniq(TestCase):
class TestRaiseNoDoubleGradOp(TestCase):
def raise_no_grad_op(self):
with fluid.dygraph.guard():
x = fluid.layers.ones(shape=[2, 3, 2, 2], dtype='float32')
x = paddle.ones(shape=[2, 3, 2, 2], dtype='float32')
x.stop_gradient = False
y = paddle.static.nn.group_norm(x, groups=1)
......
......@@ -16,9 +16,6 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
......@@ -107,22 +104,5 @@ class TestFP16Isfinite(TestIsfinite):
self.dtype = np.float16
class BadInputTest(unittest.TestCase):
def test_error(self):
with fluid.program_guard(fluid.Program()):
def test_has_inf_bad_x():
data = [1, 2, 3]
result = fluid.layers.has_inf(data)
self.assertRaises(TypeError, test_has_inf_bad_x)
with fluid.dygraph.guard():
data = paddle.zeros([2, 3])
result = paddle.fluid.layers.has_inf(data)
expect_value = np.array([False])
self.assertEqual((result.numpy() == expect_value).all(), True)
if __name__ == '__main__':
unittest.main()
......@@ -234,7 +234,7 @@ class TestMathOpPatches(unittest.TestCase):
a = fluid.layers.data(name="a", shape=[1], dtype='float32')
b = fluid.layers.data(name="b", shape=[1], dtype='float32')
one = fluid.layers.ones(shape=[1], dtype='int32')
one = paddle.ones(shape=[1], dtype='int32')
zero = fluid.layers.zeros(shape=[1], dtype='int32')
cond = one == zero
c = fluid.layers.cond(cond, lambda: a + b, lambda: a - b)
......
......@@ -384,7 +384,7 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
def func_test_np_left_mul(self):
with fluid.dygraph.guard():
t = np.sqrt(2.0 * np.pi)
x = fluid.layers.ones((2, 2), dtype="float32")
x = paddle.ones((2, 2), dtype="float32")
y = t * x
np.testing.assert_allclose(
......
......@@ -64,28 +64,6 @@ class TestOnesLikeAPI(unittest.TestCase):
self.assertEqual((outs[i] == np.ones(shape, dtype)).all(), True)
class TestOnesLikeImpeartive(unittest.TestCase):
def test_out(self):
shape = [3, 4]
place = (
fluid.CUDAPlace(0)
if core.is_compiled_with_cuda()
else fluid.CPUPlace()
)
paddle.disable_static(place)
x = paddle.to_tensor(np.ones(shape))
for dtype in [np.bool_, np.float32, np.float64, np.int32, np.int64]:
out = ones_like(x, dtype)
self.assertEqual((out.numpy() == np.ones(shape, dtype)).all(), True)
out = paddle.tensor.ones_like(x)
self.assertEqual((out.numpy() == np.ones(shape, dtype)).all(), True)
out = paddle.tensor.creation.ones_like(x)
self.assertEqual((out.numpy() == np.ones(shape, dtype)).all(), True)
paddle.enable_static()
class TestOnesAPI(unittest.TestCase):
def test_api(self):
shape = [3, 4]
......
......@@ -17,7 +17,7 @@ import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import numpy as np
class ApiOnesTest(unittest.TestCase):
......@@ -48,7 +48,7 @@ class ApiOnesTest(unittest.TestCase):
def test_fluid_ones(self):
with paddle.static.program_guard(paddle.static.Program()):
ones = fluid.layers.ones(shape=[10], dtype="int64")
ones = paddle.ones(shape=[10], dtype="int64")
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
(result,) = exe.run(fetch_list=[ones])
......@@ -72,13 +72,13 @@ class ApiOnesZerosError(unittest.TestCase):
def test_error3():
with paddle.static.program_guard(paddle.static.Program()):
ones = fluid.layers.ones(shape=10, dtype="int64")
ones = paddle.ones(shape=10, dtype="int64")
self.assertRaises(TypeError, test_error3)
def test_error4():
with paddle.static.program_guard(paddle.static.Program()):
ones = fluid.layers.ones(shape=[10], dtype="int8")
ones = paddle.ones(shape=[10], dtype="int8")
self.assertRaises(TypeError, test_error4)
......
......@@ -12,20 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import gradient_checker
import numpy as np
from decorator_helper import prog_scope
from op_test import OpTest
from test_attribute_var import UnittestBase
import paddle
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid import core
from paddle.fluid.framework import Program, program_guard
class TestReverseOp(OpTest):
......@@ -36,7 +28,7 @@ class TestReverseOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = "reverse"
self.python_api = fluid.layers.reverse
self.python_api = paddle.reverse
self.inputs = {"X": self.x}
self.attrs = {'axis': self.axis}
out = self.x
......@@ -99,241 +91,6 @@ class TestCase3_neg(TestReverseOp):
self.axis = [-1, -2]
class TestCase4(unittest.TestCase):
def test_error(self):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
label = fluid.layers.data(
name="label", shape=[1, 1, 1, 1, 1, 1, 1, 1], dtype="int64"
)
rev = fluid.layers.reverse(label, axis=[-1, -2])
def _run_program():
x = np.random.random(size=(10, 1, 1, 1, 1, 1, 1)).astype('int64')
exe.run(train_program, feed={"label": x})
self.assertRaises(IndexError, _run_program)
class TestReverseLoDTensorArray(unittest.TestCase):
def setUp(self):
self.shapes = [[5, 25], [5, 20], [5, 5]]
self.place = (
fluid.CUDAPlace(0)
if fluid.is_compiled_with_cuda()
else fluid.CPUPlace()
)
self.exe = fluid.Executor(self.place)
def run_program(self, arr_len, axis=0):
main_program = fluid.Program()
with fluid.program_guard(main_program):
inputs, inputs_data = [], []
for i in range(arr_len):
x = fluid.data("x%s" % i, self.shapes[i], dtype='float32')
x.stop_gradient = False
inputs.append(x)
inputs_data.append(
np.random.random(self.shapes[i]).astype('float32')
)
tensor_array = fluid.layers.create_array(dtype='float32')
for i in range(arr_len):
idx = fluid.layers.array_length(tensor_array)
fluid.layers.array_write(inputs[i], idx, tensor_array)
reverse_array = fluid.layers.reverse(tensor_array, axis=axis)
output, _ = fluid.layers.tensor_array_to_tensor(reverse_array)
loss = paddle.sum(output)
fluid.backward.append_backward(loss)
input_grads = list(
map(
main_program.global_block().var,
[x.name + "@GRAD" for x in inputs],
)
)
feed_dict = dict(zip([x.name for x in inputs], inputs_data))
res = self.exe.run(
main_program,
feed=feed_dict,
fetch_list=input_grads + [output.name],
)
return np.hstack(inputs_data[::-1]), res
def test_case1(self):
gt, res = self.run_program(arr_len=3)
self.check_output(gt, res)
# test with tuple type of axis
gt, res = self.run_program(arr_len=3, axis=(0,))
self.check_output(gt, res)
def test_case2(self):
gt, res = self.run_program(arr_len=1)
self.check_output(gt, res)
# test with list type of axis
gt, res = self.run_program(arr_len=1, axis=[0])
self.check_output(gt, res)
def check_output(self, gt, res):
arr_len = len(res) - 1
reversed_array = res[-1]
# check output
np.testing.assert_array_equal(gt, reversed_array)
# check grad
for i in range(arr_len):
np.testing.assert_array_equal(res[i], np.ones_like(res[i]))
def test_raise_error(self):
# The len(axis) should be 1 if input(X) is a LoDTensorArray
with self.assertRaises(Exception):
self.run_program(arr_len=3, axis=[0, 1])
# The value of axis should be 0 if input(X) is a LoDTensorArray
with self.assertRaises(Exception):
self.run_program(arr_len=3, axis=1)
class TestReverseAxisTensor(UnittestBase):
def init_info(self):
self.shapes = [[2, 3, 4]]
self.save_path = os.path.join(self.temp_dir.name, self.path_prefix())
def test_static(self):
main_prog = Program()
starup_prog = Program()
with program_guard(main_prog, starup_prog):
fc = paddle.nn.Linear(4, 10)
x = paddle.randn([2, 3, 4])
x.stop_gradient = False
feat = fc(x) # [2,3,10]
out = self.call_func(feat)
sgd = paddle.optimizer.SGD()
sgd.minimize(paddle.mean(out))
self.assertTrue(self.var_prefix() in str(main_prog))
exe = paddle.static.Executor()
exe.run(starup_prog)
res = exe.run(fetch_list=[feat, out])
gt = res[0][::-1, :, ::-1]
np.testing.assert_allclose(res[1], gt)
paddle.static.save_inference_model(
self.save_path, [x], [feat, out], exe
)
# Test for Inference Predictor
infer_outs = self.infer_prog()
gt = infer_outs[0][::-1, :, ::-1]
np.testing.assert_allclose(infer_outs[1], gt)
def path_prefix(self):
return 'reverse_tensor'
def var_prefix(self):
return "Var["
def call_func(self, x):
# axes is a Variable
axes = paddle.assign([0, 2])
out = paddle.fluid.layers.reverse(x, axes)
return out
class TestReverseAxisListTensor(TestReverseAxisTensor):
def path_prefix(self):
return 'reverse_tensors'
def var_prefix(self):
return "Vars["
def call_func(self, x):
# axes is a List[Variable]
axes = [paddle.assign([0]), paddle.assign([2])]
out = paddle.fluid.layers.reverse(x, axes)
# check attrs
axis_attrs = (
paddle.static.default_main_program()
.block(0)
.ops[-1]
.all_attrs()["axis"]
)
self.assertTrue(axis_attrs[0].name, axes[0].name)
self.assertTrue(axis_attrs[1].name, axes[1].name)
return out
class TestReverseDoubleGradCheck(unittest.TestCase):
def reverse_wrapper(self, x):
return fluid.layers.reverse(x[0], [0, 1])
@prog_scope()
def func(self, place):
# the shape of the input variable should be clearly specified and not include -1.
eps = 0.005
dtype = np.float64
data = layers.data('data', [3, 4], False, dtype)
data.persistable = True
out = fluid.layers.reverse(data, [0, 1])
data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
gradient_checker.double_grad_check(
[data], out, x_init=[data_arr], place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.double_grad_check_for_dygraph(
self.reverse_wrapper, [data], out, x_init=[data_arr], place=place
)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestReverseTripleGradCheck(unittest.TestCase):
def reverse_wrapper(self, x):
return fluid.layers.reverse(x[0], [0, 1])
@prog_scope()
def func(self, place):
# the shape of the input variable should be clearly specified and not include -1.
eps = 0.005
dtype = np.float32
data = layers.data('data', [2, 3], False, dtype)
data.persistable = True
out = fluid.layers.reverse(data, [0, 1])
data_arr = np.random.uniform(-1, 1, data.shape).astype(dtype)
gradient_checker.triple_grad_check(
[data], out, x_init=[data_arr], place=place, eps=eps
)
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
gradient_checker.triple_grad_check_for_dygraph(
self.reverse_wrapper, [data], out, x_init=[data_arr], place=place
)
def test_grad(self):
paddle.enable_static()
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -83,7 +83,7 @@ class TestZerosLikeImpeartive(unittest.TestCase):
self.assertEqual(
(out.numpy() == np.zeros(shape, dtype)).all(), True
)
out = paddle.tensor.zeros_like(x)
out = paddle.zeros_like(x)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True)
out = paddle.tensor.creation.zeros_like(x)
self.assertEqual((out.numpy() == np.zeros(shape, dtype)).all(), True)
......