Unverified commit 582c0a04, authored by J joejiong, committed by GitHub

add uint8 for reshape op (#28996)

add uint8 for reshape operator
Parent f0e614fe
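This change registers `uint8_t` instantiations of the `reshape`/`reshape2` kernels (forward, grad, and double-grad) for CPU and CUDA, so `paddle.reshape` can dispatch on unsigned 8-bit tensors. Below is a minimal dygraph sketch of the behavior being enabled, modeled on the `test_out_uint8` case added in this commit; the random data and shapes are illustrative only.

```python
import numpy as np
import paddle

# Run eagerly; the test added in this commit exercises the same dygraph path.
paddle.disable_static()

# A uint8 tensor; before this commit several reshape/reshape2 kernels
# (notably the CUDA and grad variants) had no uint8_t registration.
x_np = np.random.randint(0, 256, size=[5, 1, 10]).astype("uint8")
x = paddle.to_tensor(x_np)

out = paddle.reshape(x, shape=[5, 10])

# The result keeps the uint8 dtype and matches numpy's reshape.
assert out.numpy().dtype == np.uint8
np.testing.assert_array_equal(out.numpy(), np.reshape(x_np, [5, 10]))
```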
@@ -627,12 +627,14 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, bool,
ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int,
ops::ReshapeDoubleGradKernel, uint8_t,
ops::ReshapeDoubleGradKernel, int64_t,
ops::ReshapeDoubleGradKernel, bool,
ops::ReshapeDoubleGradKernel);
@@ -640,20 +642,24 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel, plat::float16,
uint8_t, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel, plat::float16,
uint8_t, ops::ReshapeKernel, int64_t,
ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel, bool, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel, bool,
@@ -662,6 +668,7 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int,
ops::ReshapeDoubleGradKernel, uint8_t,
ops::ReshapeDoubleGradKernel, int64_t,
ops::ReshapeDoubleGradKernel, plat::float16,
ops::ReshapeDoubleGradKernel, bool,
......
@@ -20,7 +20,8 @@ import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid import compiler
from paddle.static import Program, program_guard
# situation 1: have shape( list, no tensor), no actual shape(Tensor)
@@ -248,16 +249,17 @@ class TestReshapeOpBool(TestReshapeOp):
class TestReshapeAPI(unittest.TestCase):
def _set_paddle_api(self):
self.fill_constant = paddle.fluid.layers.fill_constant
self.data = paddle.fluid.data
self.data = paddle.static.data
self.reshape = paddle.reshape
self.to_tensor = paddle.to_tensor
def _set_fluid_api(self):
self.fill_constant = fluid.layers.fill_constant
self.data = fluid.data
self.data = paddle.static.data
self.reshape = fluid.layers.reshape
def _test_api(self):
paddle.enable_static()
input = np.random.random([2, 25]).astype("float32")
shape = [2, 5, 5]
main_prog = Program()
@@ -280,7 +282,7 @@ class TestReshapeAPI(unittest.TestCase):
# Situation 4: have shape(Tensor), no actual shape(Tensor)
out_4 = self.reshape(x, shape=actual_shape)
exe = fluid.Executor(place=fluid.CPUPlace())
exe = paddle.static.Executor(place=paddle.CPUPlace())
res_1, res_2, res_3, res_4 = exe.run(
main_prog,
feed={"x": input,
@@ -323,7 +325,7 @@ class TestReshapeAPI(unittest.TestCase):
# Test Input Error
class TestReshapeOpError(unittest.TestCase):
def _set_paddle_api(self):
self.data = paddle.fluid.data
self.data = paddle.static.data
self.reshape = paddle.reshape
def _set_fluid_api(self):
@@ -335,7 +337,7 @@ class TestReshapeOpError(unittest.TestCase):
# The x type of reshape_op must be Variable.
def test_x_type():
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
np.array([[-1]]), [[1]], paddle.CPUPlace())
self.reshape(x1, shape=[1])
self.assertRaises(TypeError, test_x_type)
@@ -395,5 +397,34 @@ class TestReshapeOpError(unittest.TestCase):
self._test_errors()
class API_TestDygraphReshape(unittest.TestCase):
def test_out(self):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.reshape(x=input, shape=[5, 10])
out_np = output.numpy()
expected_out = np.reshape(input_1, newshape=[5, 10])
self.assertTrue(np.allclose(expected_out, out_np))
def test_out_uint8(self):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("uint8")
input = paddle.to_tensor(input_1)
output = paddle.reshape(x=input, shape=[5, 10])
out_np = output.numpy()
expected_out = np.reshape(input_1, newshape=[5, 10])
self.assertTrue(np.allclose(expected_out, out_np))
def test_out_float32(self):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("float32")
input = paddle.to_tensor(input_1)
output = paddle.reshape(x=input, shape=[5, 10])
out_np = output.numpy()
expected_out = np.reshape(input_1, newshape=[5, 10])
self.assertTrue(np.allclose(expected_out, out_np))
if __name__ == "__main__":
unittest.main()
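The test updates also migrate from `fluid.data`, `fluid.Executor`, and `fluid.CPUPlace` to their `paddle.static` / `paddle` equivalents. A hedged static-graph sketch of the same uint8 reshape through that API follows; it is illustrative only, not part of the diff, and assumes `paddle.static.data` accepts a `uint8` dtype as the fluid data layer did.

```python
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # Declare a uint8 placeholder; at run time the uint8_t CPU kernel
    # registered in this commit backs the reshape2 op.
    x = paddle.static.data(name="x", shape=[5, 1, 10], dtype="uint8")
    out = paddle.reshape(x, shape=[5, 10])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)  # no parameters here, but harmless to run

x_np = np.random.randint(0, 256, size=[5, 1, 10]).astype("uint8")
res = exe.run(main_prog, feed={"x": x_np}, fetch_list=[out])[0]
np.testing.assert_array_equal(res, np.reshape(x_np, [5, 10]))
```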