未验证 提交 5ccaaab8 编写于 作者: M mapingshuo 提交者: GitHub

reshape support bool, test=develop (#27944)

上级 8d7908f3
......@@ -621,15 +621,18 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int8_t, ops::ReshapeKernel,
uint8_t, ops::ReshapeKernel, int,
ops::ReshapeKernel, int64_t, ops::ReshapeKernel);
ops::ReshapeKernel, int64_t, ops::ReshapeKernel,
bool, ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, bool,
ops::ReshapeGradKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, double,
ops::ReshapeDoubleGradKernel, int,
ops::ReshapeDoubleGradKernel, int64_t,
ops::ReshapeDoubleGradKernel, bool,
ops::ReshapeDoubleGradKernel);
#ifdef PADDLE_WITH_CUDA
......@@ -641,15 +644,17 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel);
ops::ReshapeKernel, bool, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel, bool,
ops::ReshapeGradKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
......@@ -657,6 +662,7 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
ops::ReshapeDoubleGradKernel, int,
ops::ReshapeDoubleGradKernel, int64_t,
ops::ReshapeDoubleGradKernel, plat::float16,
ops::ReshapeDoubleGradKernel, bool,
ops::ReshapeDoubleGradKernel);
#endif
......@@ -664,10 +670,11 @@ REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad_grad, float,
REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
ops::ReshapeKernel, int, ops::ReshapeKernel,
int64_t, ops::ReshapeKernel, plat::float16,
ops::ReshapeKernel);
ops::ReshapeKernel, bool, ops::ReshapeKernel);
REGISTER_OP_XPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
double, ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel);
ops::ReshapeGradKernel,
bool ops::ReshapeGradKernel);
#endif
......@@ -6119,7 +6119,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'], 'reshape')
x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64',
'bool'], 'reshape')
check_type(shape, 'shape', (list, tuple, Variable), 'reshape')
check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape')
......
......@@ -226,6 +226,24 @@ class TestReshapeUint8Op(TestReshapeInt8Op):
self.dtype = np.uint8
class TestReshapeOpBool(TestReshapeOp):
def setUp(self):
self.init_data()
self.op_type = "reshape2"
self.inputs = {
"X": np.random.choice(
[True, False], size=self.ori_shape)
}
self.attrs = {"shape": self.new_shape}
self.outputs = {
"Out": self.inputs["X"].reshape(self.infered_shape),
'XShape': np.random.random(self.ori_shape).astype("float32")
}
def test_check_grad(self):
pass
# Test python API
class TestReshapeAPI(unittest.TestCase):
def _set_paddle_api(self):
......@@ -324,7 +342,7 @@ class TestReshapeOpError(unittest.TestCase):
# The x dtype of reshape_op must be float16, float32, float64, int32 or int64.
def test_x_dtype():
x2 = self.data(name="x2", shape=[2, 25], dtype="bool")
x2 = self.data(name="x2", shape=[2, 25], dtype="int8")
self.reshape(x2, shape=[2, 5, 5])
self.assertRaises(TypeError, test_x_dtype)
......
......@@ -1353,7 +1353,7 @@ def reshape(x, shape, name=None):
the corresponding dimension of x.
Args:
x(Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32`` or ``int64``.
x(Tensor): An N-D Tensor. The data type is ``float32``, ``float64``, ``int32``, ``int64`` or ``bool``
shape(list|tuple|Tensor): Define the target shape. At most one dimension of the target shape can be -1.
The data type is ``int32`` . If ``shape`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``shape`` is an Tensor, it should be an 1-D Tensor .
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册