Unverified commit 1a532d51, authored by joejiong, committed by GitHub

add uint8 support for squeeze operator (#28734)

Adds uint8 kernel registrations for the squeeze/squeeze2 and unsqueeze/unsqueeze2 operators (forward and grad, on both CPU and CUDA), and migrates the affected unit tests from the legacy fluid API to the paddle 2.0 API.
Parent commit: 9066828b
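For context, a minimal sketch of what this change enables (dygraph mode, paddle 2.0 API; it mirrors the new test cases in the diff below — the shapes and values here are illustrative, not taken from the commit):

import numpy as np
import paddle

paddle.disable_static()

# uint8 tensors can now flow through squeeze/unsqueeze.
x_np = np.random.randint(0, 256, size=[5, 1, 10]).astype("uint8")
x = paddle.to_tensor(x_np)

squeezed = paddle.squeeze(x, axis=[1])         # shape [5, 10]
restored = paddle.unsqueeze(squeezed, axis=1)  # back to [5, 1, 10]

assert np.array_equal(restored.numpy(), x_np)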
@@ -337,6 +337,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::SqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
@@ -345,6 +346,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::SqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
@@ -352,6 +354,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, double>,
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, bool>,
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Squeeze2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
@@ -360,5 +363,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Squeeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -23,6 +23,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::SqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
@@ -32,6 +33,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::SqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
@@ -41,6 +43,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, bool>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::Squeeze2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     squeeze2_grad,
@@ -50,4 +53,5 @@ REGISTER_OP_CUDA_KERNEL(
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::Squeeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
File mode changed from 100644 to 100755
@@ -362,6 +362,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
@@ -370,6 +371,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
@@ -377,6 +379,7 @@ REGISTER_OP_CPU_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
@@ -385,5 +388,6 @@ REGISTER_OP_CPU_KERNEL(
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
@@ -23,6 +23,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
@@ -34,6 +35,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     unsqueeze2,
@@ -42,6 +44,7 @@ REGISTER_OP_CUDA_KERNEL(
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
@@ -52,5 +55,6 @@ REGISTER_OP_CUDA_KERNEL(
                               plat::float16>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
     ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
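A note on the C++ changes above: Paddle registers one kernel instantiation per (device context, dtype) pair, so enabling a new dtype is a one-line addition at every registration site — forward and grad kernels, for both the v1 (squeeze/unsqueeze) and v2 (squeeze2/unsqueeze2) ops, on CPU and CUDA — which is why the same uint8_t line recurs throughout the diff. A condensed, hypothetical sketch of the pattern (example_op and ExampleKernel are placeholders, not code from this commit):

// Each REGISTER_OP_*_KERNEL call names the op, then lists one kernel
// template instantiation per supported dtype.
REGISTER_OP_CPU_KERNEL(
    example_op,
    ops::ExampleKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ExampleKernel<paddle::platform::CPUDeviceContext, uint8_t>);  // newly supported dtype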
@@ -13,12 +13,13 @@
 # limitations under the License.

 from __future__ import print_function

 import unittest
 import numpy as np
 from op_test import OpTest
+import paddle

+paddle.enable_static()
......
@@ -13,13 +13,15 @@
 # limitations under the License.

 from __future__ import print_function

 import unittest
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 from paddle.fluid import compiler, Program, program_guard
-import paddle
 from op_test import OpTest
+
+paddle.enable_static()
@@ -81,27 +83,30 @@ class TestSqueezeOp4(TestSqueezeOp):

 class TestSqueezeOpError(unittest.TestCase):
     def test_errors(self):
+        paddle.enable_static()
         with program_guard(Program(), Program()):
             # The input type of softmax_op must be Variable.
             x1 = fluid.create_lod_tensor(
-                np.array([[-1]]), [[1]], fluid.CPUPlace())
-            self.assertRaises(TypeError, fluid.layers.squeeze, x1)
+                np.array([[-1]]), [[1]], paddle.CPUPlace())
+            self.assertRaises(TypeError, paddle.squeeze, x1)
             # The input axes of squeeze must be list.
-            x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32")
-            self.assertRaises(TypeError, fluid.layers.squeeze, x2, axes=0)
+            x2 = paddle.static.data(name='x2', shape=[4], dtype="int32")
+            self.assertRaises(TypeError, paddle.squeeze, x2, axes=0)
             # The input dtype of squeeze not support float16.
-            x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16")
-            self.assertRaises(TypeError, fluid.layers.squeeze, x3, axes=0)
+            x3 = paddle.static.data(name='x3', shape=[4], dtype="float16")
+            self.assertRaises(TypeError, paddle.squeeze, x3, axes=0)


 class API_TestSqueeze(unittest.TestCase):
     def test_out(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            data1 = fluid.layers.data(
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            data1 = paddle.static.data(
                 'data1', shape=[-1, 1, 10], dtype='float64')
             result_squeeze = paddle.squeeze(data1, axis=[1])
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
+            place = paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
             input1 = np.random.random([5, 1, 10]).astype('float64')
             result, = exe.run(feed={"data1": input1},
                               fetch_list=[result_squeeze])
@@ -111,31 +116,49 @@ class API_TestSqueeze(unittest.TestCase):

 class API_TestDygraphSqueeze(unittest.TestCase):
     def test_out(self):
-        with fluid.dygraph.guard():
-            input_1 = np.random.random([5, 1, 10]).astype("int32")
-            input = fluid.dygraph.to_variable(input_1)
-            output = paddle.squeeze(input, axis=[1])
-            out_np = output.numpy()
-            expected_out = np.squeeze(input_1, axis=1)
-            self.assertTrue(np.allclose(expected_out, out_np))
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input = paddle.to_tensor(input_1)
+        output = paddle.squeeze(input, axis=[1])
+        out_np = output.numpy()
+        expected_out = np.squeeze(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))
+
+    def test_out_int8(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int8")
+        input = paddle.to_tensor(input_1)
+        output = paddle.squeeze(input, axis=[1])
+        out_np = output.numpy()
+        expected_out = np.squeeze(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))
+
+    def test_out_uint8(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("uint8")
+        input = paddle.to_tensor(input_1)
+        output = paddle.squeeze(input, axis=[1])
+        out_np = output.numpy()
+        expected_out = np.squeeze(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))

     def test_axis_not_list(self):
-        with fluid.dygraph.guard():
-            input_1 = np.random.random([5, 1, 10]).astype("int32")
-            input = fluid.dygraph.to_variable(input_1)
-            output = paddle.squeeze(input, axis=1)
-            out_np = output.numpy()
-            expected_out = np.squeeze(input_1, axis=1)
-            self.assertTrue(np.allclose(expected_out, out_np))
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input = paddle.to_tensor(input_1)
+        output = paddle.squeeze(input, axis=1)
+        out_np = output.numpy()
+        expected_out = np.squeeze(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))

     def test_dimension_not_1(self):
-        with fluid.dygraph.guard():
-            input_1 = np.random.random([5, 1, 10]).astype("int32")
-            input = fluid.dygraph.to_variable(input_1)
-            output = paddle.squeeze(input, axis=(1, 2))
-            out_np = output.numpy()
-            expected_out = np.squeeze(input_1, axis=1)
-            self.assertTrue(np.allclose(expected_out, out_np))
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input = paddle.to_tensor(input_1)
+        output = paddle.squeeze(input, axis=(1, 2))
+        out_np = output.numpy()
+        expected_out = np.squeeze(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))


 if __name__ == "__main__":
......
@@ -13,12 +13,14 @@
 # limitations under the License.

 from __future__ import print_function

 import unittest
 import numpy as np
+import paddle
 import paddle.fluid as fluid
 from op_test import OpTest
-import paddle
+
+paddle.enable_static()
@@ -208,24 +210,24 @@ class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):

 class TestUnsqueezeAPI(unittest.TestCase):
     def test_api(self):
         input = np.random.random([3, 2, 5]).astype("float64")
-        x = fluid.data(name='x', shape=[3, 2, 5], dtype="float64")
+        x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64")
         positive_3_int32 = fluid.layers.fill_constant([1], "int32", 3)
         positive_1_int64 = fluid.layers.fill_constant([1], "int64", 1)
-        axes_tensor_int32 = fluid.data(
+        axes_tensor_int32 = paddle.static.data(
             name='axes_tensor_int32', shape=[3], dtype="int32")
-        axes_tensor_int64 = fluid.data(
+        axes_tensor_int64 = paddle.static.data(
             name='axes_tensor_int64', shape=[3], dtype="int64")

-        out_1 = fluid.layers.unsqueeze(x, axes=[3, 1, 1])
-        out_2 = fluid.layers.unsqueeze(
-            x, axes=[positive_3_int32, positive_1_int64, 1])
-        out_3 = fluid.layers.unsqueeze(x, axes=axes_tensor_int32)
-        out_4 = fluid.layers.unsqueeze(x, axes=3)
-        out_5 = fluid.layers.unsqueeze(x, axes=axes_tensor_int64)
+        out_1 = paddle.unsqueeze(x, axis=[3, 1, 1])
+        out_2 = paddle.unsqueeze(
+            x, axis=[positive_3_int32, positive_1_int64, 1])
+        out_3 = paddle.unsqueeze(x, axis=axes_tensor_int32)
+        out_4 = paddle.unsqueeze(x, axis=3)
+        out_5 = paddle.unsqueeze(x, axis=axes_tensor_int64)

-        exe = fluid.Executor(place=fluid.CPUPlace())
+        exe = paddle.static.Executor(place=paddle.CPUPlace())
         res_1, res_2, res_3, res_4, res_5 = exe.run(
-            fluid.default_main_program(),
+            paddle.static.default_main_program(),
             feed={
                 "x": input,
                 "axes_tensor_int32": np.array([3, 1, 1]).astype("int32"),
@@ -241,8 +243,8 @@ class TestUnsqueezeAPI(unittest.TestCase):

     def test_error(self):
         def test_axes_type():
-            x2 = fluid.data(name="x2", shape=[2, 25], dtype="int32")
-            fluid.layers.unsqueeze(x2, axes=2.1)
+            x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32")
+            paddle.unsqueeze(x2, axis=2.1)

         self.assertRaises(TypeError, test_axes_type)
......
@@ -13,12 +13,14 @@
 # limitations under the License.

 from __future__ import print_function

 import unittest
 import numpy as np
 import paddle
 import paddle.fluid as fluid
 from op_test import OpTest
+
+paddle.enable_static()
@@ -80,11 +82,13 @@ class TestUnsqueezeOp4(TestUnsqueezeOp):

 class API_TestUnsqueeze(unittest.TestCase):
     def test_out(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            data1 = fluid.layers.data('data1', shape=[-1, 10], dtype='float64')
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            data1 = paddle.static.data('data1', shape=[-1, 10], dtype='float64')
             result_squeeze = paddle.unsqueeze(data1, axis=[1])
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
+            place = paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
             input1 = np.random.random([5, 1, 10]).astype('float64')
             input = np.squeeze(input1, axis=1)
             result, = exe.run(feed={"data1": input},
@@ -94,10 +98,12 @@ class API_TestUnsqueeze(unittest.TestCase):

 class TestUnsqueezeOpError(unittest.TestCase):
     def test_errors(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
             # The type of axis in split_op should be int or Variable.
             def test_axes_type():
-                x6 = fluid.layers.data(
+                x6 = paddle.static.data(
                     shape=[-1, 10], dtype='float16', name='x3')
                 paddle.unsqueeze(x6, axis=3.2)
@@ -106,12 +112,14 @@ class TestUnsqueezeOpError(unittest.TestCase):

 class API_TestUnsqueeze2(unittest.TestCase):
     def test_out(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            data1 = fluid.data('data1', shape=[-1, 10], dtype='float64')
-            data2 = fluid.data('data2', shape=[1], dtype='int32')
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            data1 = paddle.static.data('data1', shape=[-1, 10], dtype='float64')
+            data2 = paddle.static.data('data2', shape=[1], dtype='int32')
             result_squeeze = paddle.unsqueeze(data1, axis=data2)
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
+            place = paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
             input1 = np.random.random([5, 1, 10]).astype('float64')
             input2 = np.array([1]).astype('int32')
             input = np.squeeze(input1, axis=1)
@@ -123,12 +131,14 @@ class API_TestUnsqueeze2(unittest.TestCase):

 class API_TestUnsqueeze3(unittest.TestCase):
     def test_out(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            data1 = fluid.data('data1', shape=[-1, 10], dtype='float64')
-            data2 = fluid.data('data2', shape=[1], dtype='int32')
+        paddle.enable_static()
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            data1 = paddle.static.data('data1', shape=[-1, 10], dtype='float64')
+            data2 = paddle.static.data('data2', shape=[1], dtype='int32')
             result_squeeze = paddle.unsqueeze(data1, axis=[data2, 3])
-            place = fluid.CPUPlace()
-            exe = fluid.Executor(place)
+            place = paddle.CPUPlace()
+            exe = paddle.static.Executor(place)
             input1 = np.random.random([5, 1, 10, 1]).astype('float64')
             input2 = np.array([1]).astype('int32')
             input = np.squeeze(input1)
@@ -141,55 +151,102 @@ class API_TestUnsqueeze3(unittest.TestCase):

 class API_TestDyUnsqueeze(unittest.TestCase):
     def test_out(self):
-        with fluid.dygraph.guard():
-            input_1 = np.random.random([5, 1, 10]).astype("int32")
-            input1 = np.expand_dims(input_1, axis=1)
-            input = fluid.dygraph.to_variable(input_1)
-            output = paddle.unsqueeze(input, axis=[1])
-            out_np = output.numpy()
-            self.assertTrue(np.array_equal(input1, out_np))
-            self.assertEqual(input1.shape, out_np.shape)
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input1 = np.expand_dims(input_1, axis=1)
+        input = paddle.to_tensor(input_1)
+        output = paddle.unsqueeze(input, axis=[1])
+        out_np = output.numpy()
+        self.assertTrue(np.array_equal(input1, out_np))
+        self.assertEqual(input1.shape, out_np.shape)


 class API_TestDyUnsqueeze2(unittest.TestCase):
     def test_out(self):
-        with fluid.dygraph.guard():
-            input1 = np.random.random([5, 10]).astype("int32")
-            out1 = np.expand_dims(input1, axis=1)
-            input = fluid.dygraph.to_variable(input1)
-            output = paddle.unsqueeze(input, axis=1)
-            out_np = output.numpy()
-            self.assertTrue(np.array_equal(out1, out_np))
-            self.assertEqual(out1.shape, out_np.shape)
+        paddle.disable_static()
+        input1 = np.random.random([5, 10]).astype("int32")
+        out1 = np.expand_dims(input1, axis=1)
+        input = paddle.to_tensor(input1)
+        output = paddle.unsqueeze(input, axis=1)
+        out_np = output.numpy()
+        self.assertTrue(np.array_equal(out1, out_np))
+        self.assertEqual(out1.shape, out_np.shape)


 class API_TestDyUnsqueezeAxisTensor(unittest.TestCase):
     def test_out(self):
-        with fluid.dygraph.guard():
-            input1 = np.random.random([5, 10]).astype("int32")
-            out1 = np.expand_dims(input1, axis=1)
-            out1 = np.expand_dims(out1, axis=2)
-            input = fluid.dygraph.to_variable(input1)
-            output = paddle.unsqueeze(input, axis=paddle.to_tensor([1, 2]))
-            out_np = output.numpy()
-            self.assertTrue(np.array_equal(out1, out_np))
-            self.assertEqual(out1.shape, out_np.shape)
+        paddle.disable_static()
+        input1 = np.random.random([5, 10]).astype("int32")
+        out1 = np.expand_dims(input1, axis=1)
+        out1 = np.expand_dims(out1, axis=2)
+        input = paddle.to_tensor(input1)
+        output = paddle.unsqueeze(input, axis=paddle.to_tensor([1, 2]))
+        out_np = output.numpy()
+        self.assertTrue(np.array_equal(out1, out_np))
+        self.assertEqual(out1.shape, out_np.shape)


 class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase):
     def test_out(self):
-        with fluid.dygraph.guard():
-            input1 = np.random.random([5, 10]).astype("int32")
-            # Actually, expand_dims supports tuple since version 1.18.0
-            out1 = np.expand_dims(input1, axis=1)
-            out1 = np.expand_dims(out1, axis=2)
-            input = fluid.dygraph.to_variable(input1)
-            output = paddle.unsqueeze(
-                fluid.dygraph.to_variable(input1),
-                axis=[paddle.to_tensor([1]), paddle.to_tensor([2])])
-            out_np = output.numpy()
-            self.assertTrue(np.array_equal(out1, out_np))
-            self.assertEqual(out1.shape, out_np.shape)
+        paddle.disable_static()
+        input1 = np.random.random([5, 10]).astype("int32")
+        # Actually, expand_dims supports tuple since version 1.18.0
+        out1 = np.expand_dims(input1, axis=1)
+        out1 = np.expand_dims(out1, axis=2)
+        input = paddle.to_tensor(input1)
+        output = paddle.unsqueeze(
+            paddle.to_tensor(input1),
+            axis=[paddle.to_tensor([1]), paddle.to_tensor([2])])
+        out_np = output.numpy()
+        self.assertTrue(np.array_equal(out1, out_np))
+        self.assertEqual(out1.shape, out_np.shape)


+class API_TestDygraphUnSqueeze(unittest.TestCase):
+    def test_out(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input = paddle.to_tensor(input_1)
+        output = paddle.unsqueeze(input, axis=[1])
+        out_np = output.numpy()
+        expected_out = np.expand_dims(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))
+
+    def test_out_int8(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int8")
+        input = paddle.to_tensor(input_1)
+        output = paddle.unsqueeze(input, axis=[1])
+        out_np = output.numpy()
+        expected_out = np.expand_dims(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))
+
+    def test_out_uint8(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("uint8")
+        input = paddle.to_tensor(input_1)
+        output = paddle.unsqueeze(input, axis=1)
+        out_np = output.numpy()
+        expected_out = np.expand_dims(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))
+
+    def test_axis_not_list(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input = paddle.to_tensor(input_1)
+        output = paddle.unsqueeze(input, axis=1)
+        out_np = output.numpy()
+        expected_out = np.expand_dims(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))
+
+    def test_dimension_not_1(self):
+        paddle.disable_static()
+        input_1 = np.random.random([5, 1, 10]).astype("int32")
+        input = paddle.to_tensor(input_1)
+        output = paddle.unsqueeze(input, axis=(1, 2))
+        out_np = output.numpy()
+        expected_out = np.expand_dims(input_1, axis=1)
+        self.assertTrue(np.allclose(expected_out, out_np))


 if __name__ == "__main__":
......
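To exercise the updated tests locally, an invocation along these lines would be typical (the file names and unittests directory are assumed from Paddle's usual test layout, not stated in this diff):

# Hypothetical paths; run from the repository root.
python python/paddle/fluid/tests/unittests/test_squeeze_op.py
python python/paddle/fluid/tests/unittests/test_unsqueeze_op.py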