Unverified commit 267275d9, authored by sneaxiy, committed by GitHub

Add int16 support for several ops (#39636)

* add int16 support for more ops

* fix XPU CI
Parent 2fe04264
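The hunks below register int16 CPU/CUDA kernels for the compare ops, cumsum, elementwise add/subtract (and their grads), fill_any_like/full_like, gather_nd, reduce_sum, reshape, flatten, unsqueeze, and where_index, and widen the matching Python dtype checks. As a rough illustration of what this enables, here is a minimal usage sketch (it assumes a Paddle build that includes this commit; the tensor values are arbitrary):

import paddle

# int16 tensors can now flow through the ops touched by this commit.
x = paddle.to_tensor([[1, 2], [3, 4]], dtype='int16')

c = paddle.cumsum(x, axis=1)        # cumsum: int16 kernel registered above
f = paddle.full_like(x, 7)          # fill_any_like keeps the int16 dtype
d = x - f                           # elementwise_sub on two int16 tensors
s = paddle.sum(x)                   # reduce_sum over an int16 tensor
r = paddle.reshape(d, [4])          # reshape an int16 tensor
u = paddle.unsqueeze(x, axis=0)     # unsqueeze an int16 tensor
print(c.dtype, f.dtype, d.dtype, s.dtype, r.dtype, u.dtype)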
......@@ -48,6 +48,7 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
REGISTER_OP_CUDA_KERNEL( \
op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<bool>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int16_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
......
......@@ -95,6 +95,9 @@ class CompareOpKernel
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int>, inverse_functor<int>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int16_t>, inverse_functor<int16_t>>, \
::paddle::operators::CompareOpKernel< \
::paddle::platform::dev##DeviceContext, \
functor<int64_t>, inverse_functor<int64_t>>, \
......
......@@ -93,6 +93,7 @@ REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
ops::CumsumGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(cumsum, ops::CumKernel<CPU, ops::CumsumFunctor<float>>,
ops::CumKernel<CPU, ops::CumsumFunctor<double>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int16_t>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int>>,
ops::CumKernel<CPU, ops::CumsumFunctor<int64_t>>);
......
......@@ -320,5 +320,6 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
cumsum, ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::CumCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -96,6 +96,7 @@ REGISTER_OP_CPU_KERNEL(
elementwise_sub,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
......@@ -106,6 +107,7 @@ REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
......@@ -118,6 +120,8 @@ REGISTER_OP_CPU_KERNEL(
float>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
double>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int16_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
......
......@@ -94,6 +94,7 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fill_any_like,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FillAnyLikeKernel<paddle::platform::CUDADeviceContext, float>,
......
......@@ -183,7 +183,9 @@ REGISTER_OPERATOR(gather_nd_grad, ops::GatherNdGradOp,
REGISTER_OP_CPU_KERNEL(gather_nd, ops::GatherNdOpKernel<float>,
ops::GatherNdOpKernel<double>,
ops::GatherNdOpKernel<int64_t>,
-ops::GatherNdOpKernel<int>, ops::GatherNdOpKernel<bool>,
+ops::GatherNdOpKernel<int>,
+ops::GatherNdOpKernel<int16_t>,
+ops::GatherNdOpKernel<bool>,
ops::GatherNdOpKernel<uint8_t>);
REGISTER_OP_CPU_KERNEL(gather_nd_grad, ops::GatherNdGradOpKernel<float>,
......
......@@ -103,6 +103,7 @@ REGISTER_OP_CUDA_KERNEL(gather_nd, ops::GatherNdOpCUDAKernel<CUDA, float>,
ops::GatherNdOpCUDAKernel<CUDA, double>,
ops::GatherNdOpCUDAKernel<CUDA, int64_t>,
ops::GatherNdOpCUDAKernel<CUDA, int>,
ops::GatherNdOpCUDAKernel<CUDA, int16_t>,
ops::GatherNdOpCUDAKernel<CUDA, bool>,
ops::GatherNdOpCUDAKernel<CUDA, plat::float16>);
......
......@@ -116,6 +116,8 @@ REGISTER_OP_CPU_KERNEL(
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int16_t,
ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int, ops::SumFunctor>,
ops::ReduceKernel<paddle::platform::CPUDeviceContext, int64_t,
ops::SumFunctor>,
......
......@@ -20,6 +20,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::ReduceCudaKernel<double, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<paddle::platform::float16, kps::AddFunctor,
kps::IdentityFunctor>,
ops::ReduceCudaKernel<int16_t, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<int64_t, kps::AddFunctor, kps::IdentityFunctor>,
ops::ReduceCudaKernel<paddle::platform::complex<float>, kps::AddFunctor,
......
......@@ -639,10 +639,12 @@ REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp,
ops::ReshapeGradInplaceInferer);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-ops::ReshapeKernel, int, ops::ReshapeKernel,
-int64_t, ops::ReshapeKernel);
+ops::ReshapeKernel, int16_t, ops::ReshapeKernel,
+int, ops::ReshapeKernel, int64_t,
+ops::ReshapeKernel);
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-double, ops::ReshapeGradKernel, int,
+double, ops::ReshapeGradKernel, int16_t,
+ops::ReshapeGradKernel, int,
ops::ReshapeGradKernel, int64_t,
ops::ReshapeGradKernel);
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
......@@ -659,15 +661,15 @@ REGISTER_OPERATOR(reshape2_grad_grad, ops::Reshape2DoubleGradOp,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
-ops::ReshapeKernel, int, ops::ReshapeKernel,
-uint8_t, ops::ReshapeKernel, int64_t,
-ops::ReshapeKernel, plat::float16,
-ops::ReshapeKernel, plat::bfloat16,
-ops::ReshapeKernel);
+ops::ReshapeKernel, int16_t, ops::ReshapeKernel,
+int, ops::ReshapeKernel, uint8_t,
+ops::ReshapeKernel, int64_t, ops::ReshapeKernel,
+plat::float16, ops::ReshapeKernel,
+plat::bfloat16, ops::ReshapeKernel);
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
-double, ops::ReshapeGradKernel, int,
-ops::ReshapeGradKernel, int64_t,
-ops::ReshapeGradKernel, uint8_t,
+double, ops::ReshapeGradKernel, int16_t,
+ops::ReshapeKernel, int, ops::ReshapeGradKernel,
+int64_t, ops::ReshapeGradKernel, uint8_t,
ops::ReshapeGradKernel, plat::float16,
ops::ReshapeGradKernel, plat::bfloat16,
ops::ReshapeGradKernel);
......
......@@ -362,6 +362,7 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
......@@ -377,6 +378,7 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
......@@ -391,6 +393,7 @@ REGISTER_OP_CPU_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, double>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CPUDeviceContext, int64_t>,
......@@ -406,6 +409,7 @@ REGISTER_OP_CPU_KERNEL(
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CPUDeviceContext, int64_t>,
......
......@@ -24,6 +24,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
......@@ -41,6 +42,7 @@ REGISTER_OP_CUDA_KERNEL(
plat::bfloat16>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
......@@ -56,6 +58,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::bfloat16>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, bool>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>,
......@@ -73,6 +76,7 @@ REGISTER_OP_CUDA_KERNEL(
plat::bfloat16>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, bool>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>,
......
......@@ -57,6 +57,7 @@ REGISTER_OP_WITHOUT_GRADIENT(where_index, ops::WhereIndexOp,
ops::WhereIndexOpMaker);
REGISTER_OP_CPU_KERNEL(where_index, ops::CPUWhereIndexKernel<int64_t>,
ops::CPUWhereIndexKernel<int>,
ops::CPUWhereIndexKernel<int16_t>,
ops::CPUWhereIndexKernel<bool>,
ops::CPUWhereIndexKernel<float>,
ops::CPUWhereIndexKernel<double>);
......@@ -158,6 +158,7 @@ class CUDAWhereIndexKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(where_index, ops::CUDAWhereIndexKernel<int64_t>,
ops::CUDAWhereIndexKernel<int>,
ops::CUDAWhereIndexKernel<int16_t>,
ops::CUDAWhereIndexKernel<bool>,
ops::CUDAWhereIndexKernel<float>,
ops::CUDAWhereIndexKernel<double>);
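where_index is the operator behind paddle.nonzero, so with the kernels registered above an int16 condition tensor should be handled directly on both CPU and GPU; a small sketch (again assuming a build that includes this change):

import paddle

cond = paddle.to_tensor([0, 3, 0, 7], dtype='int16')
idx = paddle.nonzero(cond)   # indices of the non-zero int16 entries: [[1], [3]]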
......@@ -132,6 +132,7 @@ PT_REGISTER_KERNEL(add_grad,
pten::AddGradKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::complex<float>,
......@@ -143,6 +144,7 @@ PT_REGISTER_KERNEL(add_double_grad,
pten::AddDoubleGradKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::complex<float>,
......@@ -154,6 +156,7 @@ PT_REGISTER_KERNEL(add_triple_grad,
pten::AddTripleGradKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::complex<float>,
......@@ -165,6 +168,7 @@ PT_REGISTER_KERNEL(subtract_grad,
pten::SubtractGradKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::complex<float>,
......@@ -176,6 +180,7 @@ PT_REGISTER_KERNEL(subtract_double_grad,
pten::SubtractDoubleGradKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::complex<float>,
......
......@@ -95,6 +95,7 @@ PT_REGISTER_KERNEL(full_like,
pten::FullLikeKernel,
float,
double,
int16_t,
int,
int64_t,
bool,
......
......@@ -124,6 +124,7 @@ PT_REGISTER_KERNEL(add_raw,
pten::AddRawKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
......@@ -134,6 +135,7 @@ PT_REGISTER_KERNEL(subtract_raw,
pten::SubtractRawKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
......@@ -167,6 +169,7 @@ PT_REGISTER_KERNEL(sum_raw,
float,
double,
pten::dtype::float16,
int16_t,
int,
int64_t,
complex64,
......
......@@ -56,6 +56,7 @@ PT_REGISTER_KERNEL(flatten,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
......@@ -67,6 +68,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
......@@ -80,6 +82,7 @@ PT_REGISTER_KERNEL(flatten,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
......@@ -92,6 +95,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
double,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
#endif
......@@ -104,6 +108,7 @@ PT_REGISTER_KERNEL(flatten,
float,
pten::dtype::float16,
int8_t,
int16_t,
int,
int64_t) {}
......@@ -114,6 +119,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
float,
pten::dtype::float16,
int8_t,
int16_t,
int,
int64_t) {}
#endif
......@@ -119,6 +119,7 @@ PT_REGISTER_KERNEL(full_like,
pten::FullLikeKernel,
float,
double,
int16_t,
int,
int64_t,
bool,
......
......@@ -101,6 +101,7 @@ PT_REGISTER_KERNEL(add_raw,
pten::AddRawKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
......@@ -112,6 +113,7 @@ PT_REGISTER_KERNEL(subtract_raw,
pten::SubtractRawKernel,
float,
double,
int16_t,
int,
int64_t,
float16,
......@@ -148,6 +150,7 @@ PT_REGISTER_KERNEL(sum_raw,
float,
double,
float16,
int16_t,
int,
int64_t,
complex64,
......
......@@ -92,6 +92,7 @@ PT_REGISTER_KERNEL(sum,
float,
double,
pten::dtype::float16,
int16_t,
int,
int64_t,
complex64,
......@@ -105,6 +106,7 @@ PT_REGISTER_KERNEL(add,
pten::AddKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
......@@ -115,6 +117,7 @@ PT_REGISTER_KERNEL(subtract,
pten::SubtractKernel,
float,
double,
int16_t,
int,
int64_t,
complex64,
......@@ -158,6 +161,7 @@ PT_REGISTER_KERNEL(sum,
float,
double,
pten::dtype::float16,
int16_t,
int,
int64_t,
complex64,
......@@ -170,6 +174,7 @@ PT_REGISTER_KERNEL(add,
pten::AddKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::float16,
......@@ -181,6 +186,7 @@ PT_REGISTER_KERNEL(subtract,
pten::SubtractKernel,
float,
double,
int16_t,
int,
int64_t,
pten::dtype::float16,
......
......@@ -6276,7 +6276,8 @@ def reshape(x, shape, actual_shape=None, act=None, inplace=False, name=None):
return dygraph_utils._append_activation_in_dygraph(out, act)
check_variable_and_dtype(x, 'x', [
-'float16', 'float32', 'float64', 'int32', 'int64', 'bool', 'uint16'
+'float16', 'float32', 'float64', 'int16', 'int32', 'int64', 'bool',
+'uint16'
], 'reshape')
check_type(shape, 'shape', (list, tuple, Variable), 'reshape')
check_type(actual_shape, 'actual_shape', (Variable, type(None)), 'reshape')
......@@ -6456,10 +6457,10 @@ def unsqueeze(input, axes, name=None):
return out
check_type(axes, 'axis/axes', (int, list, tuple, Variable), 'unsqueeze')
-check_variable_and_dtype(
-input, 'input',
-['float16', 'float32', 'float64', 'bool', 'int8', 'int32', 'int64'],
-'unsqueeze')
+check_variable_and_dtype(input, 'input', [
+'float16', 'float32', 'float64', 'bool', 'int8', 'int16', 'int32',
+'int64'
+], 'unsqueeze')
helper = LayerHelper("unsqueeze2", **locals())
inputs = {"X": input}
attrs = {}
......@@ -8539,9 +8540,9 @@ def gather_nd(input, index, name=None):
"""
if in_dygraph_mode():
return _C_ops.gather_nd(input, index)
-check_variable_and_dtype(input, 'input',
-['bool', 'float32', 'float64', 'int32', 'int64'],
-'gather_np')
+check_variable_and_dtype(
+input, 'input',
+['bool', 'float32', 'float64', 'int16', 'int32', 'int64'], 'gather_np')
check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np')
helper = LayerHelper('gather_nd', **locals())
dtype = helper.input_dtype()
......
......@@ -250,12 +250,12 @@ def cast(x, dtype):
return out
check_variable_and_dtype(x, 'x', [
-'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint8',
-'uint16'
+'bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64',
+'uint8', 'uint16'
], 'cast')
check_dtype(dtype, 'dtype', [
-'bool', 'float16', 'float32', 'float64', 'int8', 'int32', 'int64',
-'uint8', 'uint16'
+'bool', 'float16', 'float32', 'float64', 'int8', 'int16', 'int32',
+'int64', 'uint8', 'uint16'
], 'cast')
helper = LayerHelper('cast', **locals())
......
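With int16 added to both dtype lists above, the static-graph checks for cast no longer reject int16, which is why the tests below (which asserted a TypeError for int16) are removed. A minimal sketch of the now-accepted round trip (assuming a build with this change):

import paddle

x = paddle.arange(0, 4, dtype='int32')
x16 = paddle.cast(x, 'int16')       # 'int16' is now in cast's allowed dtype lists
back = paddle.cast(x16, 'float32')  # an int16 input to cast is accepted as well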
......@@ -109,15 +109,6 @@ class TestCastOpError(unittest.TestCase):
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.CPUPlace())
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
-# The input dtype of cast_op must be bool, float16, float32, float64, int32, int64, uint8.
-x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16')
-self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
-def test_dtype_type():
-x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32')
-output = fluid.layers.cast(x=x4, dtype='int16')
-self.assertRaises(TypeError, test_dtype_type)
if __name__ == '__main__':
......
......@@ -81,12 +81,6 @@ class TestFullOpError(unittest.TestCase):
x=input_data,
fill_value=2,
dtype='uint4')
-self.assertRaises(
-TypeError,
-paddle.full_like,
-x=input_data,
-fill_value=2,
-dtype='int16')
if __name__ == "__main__":
......
......@@ -67,15 +67,6 @@ class TestCastOpError(unittest.TestCase):
x1 = fluid.create_lod_tensor(
np.array([[-1]]), [[1]], fluid.XPUPlace(0))
self.assertRaises(TypeError, fluid.layers.cast, x1, 'int32')
-# The input dtype of cast_op must be float32, int32, int64.
-x2 = fluid.layers.data(name='x2', shape=[4], dtype='int16')
-self.assertRaises(TypeError, fluid.layers.cast, x2, 'int32')
-def test_dtype_type():
-x4 = fluid.layers.data(name='x4', shape=[4], dtype='int32')
-output = fluid.layers.cast(x=x4, dtype='int16')
-self.assertRaises(TypeError, test_dtype_type)
if __name__ == '__main__':
......
......@@ -219,10 +219,12 @@ def full_like(x, fill_value, dtype=None, name=None):
helper = LayerHelper("full_like", **locals())
check_variable_and_dtype(
-x, 'x', ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+x, 'x',
+['bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64'],
'full_like')
-check_dtype(dtype, 'dtype',
-['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
+check_dtype(
+dtype, 'dtype',
+['bool', 'float16', 'float32', 'float64', 'int16', 'int32', 'int64'],
'full_like/zeros_like/ones_like')
out = helper.create_variable_for_type_inference(dtype=dtype)
......
......@@ -672,7 +672,8 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
if not in_dygraph_mode():
check_variable_and_dtype(
-x, 'x', ['float32', 'float64', 'int8', 'int32', 'int64', 'uint8'],
+x, 'x',
+['float32', 'float64', 'int8', 'int16', 'int32', 'int64', 'uint8'],
'flatten')
x_dim = len(x.shape)
......
......@@ -885,7 +885,7 @@ def sum(x, axis=None, dtype=None, keepdim=False, name=None):
check_variable_and_dtype(
x, 'x', ['bool', 'float16', 'float32', 'float64',
-'int32', 'int64', 'complex64', 'complex128',
+'int16', 'int32', 'int64', 'complex64', 'complex128',
u'bool', u'float16', u'float32', u'float64',
u'int32', u'int64', u'complex64', u'complex128'], 'sum')
......