diff --git a/paddle/phi/kernels/xpu/cast_kernel.cc b/paddle/phi/kernels/xpu/cast_kernel.cc index 74e2a622dba865fc45b9af38b870e467f2d84cae..c5fd2d02e3360dcbfc84bab06615038a344958ad 100644 --- a/paddle/phi/kernels/xpu/cast_kernel.cc +++ b/paddle/phi/kernels/xpu/cast_kernel.cc @@ -15,90 +15,66 @@ #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/backends/xpu/xpu_header.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { +template +void CastXPUKernelImpl(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + using XPUInT = typename XPUTypeTrait::Type; + using XPUOutT = typename XPUTypeTrait::Type; + + const auto* in_data = x.data(); + auto* out_data = dev_ctx.template Alloc(out); + auto numel = x.numel(); + + if (numel == 0) { + return; + } + + int r = xpu::cast(dev_ctx.x_context(), + reinterpret_cast(in_data), + reinterpret_cast(out_data), + numel); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); +} template void CastKernel(const Context& dev_ctx, const DenseTensor& x, DataType out_dtype, DenseTensor* out) { - using XPUInTDType = typename XPUTypeTrait::Type; - using XPUTypeFP16 = typename XPUTypeTrait::Type; - - auto* in_data = x.data(); - auto numel = x.numel(); - - int r = -1; switch (out_dtype) { - case phi::DataType::FLOAT32: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::INT32: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::FLOAT16: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - reinterpret_cast( - dev_ctx.template Alloc(out)), - numel); + case DataType::FLOAT32: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::INT64: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::FLOAT16: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::INT32: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::INT64: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::BOOL: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::BOOL: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::INT8: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::INT8: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::UINT8: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::UINT8: + CastXPUKernelImpl(dev_ctx, x, out); break; - case phi::DataType::FLOAT64: - r = xpu::cast( - dev_ctx.x_context(), - reinterpret_cast(in_data), - dev_ctx.template Alloc(out), - numel); + case DataType::FLOAT64: + CastXPUKernelImpl(dev_ctx, x, out); break; default: PADDLE_THROW(phi::errors::Unavailable( "Not supported cast %d -> %d", x.dtype(), out_dtype)); } - - PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); } } // namespace phi diff --git a/paddle/phi/kernels/xpu/gather_kernel.cc b/paddle/phi/kernels/xpu/gather_kernel.cc index 658999c5289d3e871044120b775f2fb67866eb75..bcbd859cd4d690c0f60bbb926b55a67205d3cfab 100644 --- a/paddle/phi/kernels/xpu/gather_kernel.cc +++ b/paddle/phi/kernels/xpu/gather_kernel.cc @@ -29,7 +29,7 @@ void GatherKernel(const Context& dev_ctx, const auto& index_type = index.dtype(); dev_ctx.template Alloc(out); - if (x.numel() == 0) return; + if (x.numel() == 0 || index.numel() == 0) return; const auto index_dims = index.dims(); if (index_dims.size() == 2) { diff --git a/paddle/phi/kernels/xpu/scale_kernel.cc b/paddle/phi/kernels/xpu/scale_kernel.cc index cc8bdf6165e295238600353b2c25dbe654e18181..4d84e1860621d266e127a8a68b24609b619914d6 100644 --- a/paddle/phi/kernels/xpu/scale_kernel.cc +++ b/paddle/phi/kernels/xpu/scale_kernel.cc @@ -15,10 +15,6 @@ #include "paddle/phi/kernels/scale_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" -#include "paddle/phi/backends/xpu/xpu_context.h" -#include "paddle/phi/common/data_type.h" -#include "paddle/phi/common/float16.h" -#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { @@ -39,6 +35,9 @@ void ScaleKernel(const Context& dev_ctx, " expected %s, but got %s.", x.dims().to_str().c_str(), out->dims().to_str().c_str())); + if (x.numel() == 0 || !x.IsInitialized()) { + return; + } using XPUType = typename XPUTypeTrait::Type; int r = xpu::scale(dev_ctx.x_context(), reinterpret_cast(x.data()), diff --git a/test/xpu/test_cast_op_xpu.py b/test/xpu/test_cast_op_xpu.py index baf814e08de8a1a09d27161cdc2e01a5ab2d8263..fbfe6e979e7f4e91ab04453b32490d12370dd3f3 100644 --- a/test/xpu/test_cast_op_xpu.py +++ b/test/xpu/test_cast_op_xpu.py @@ -99,6 +99,18 @@ class TestCastOpError(unittest.TestCase): self.assertRaises(TypeError, paddle.cast, x1, 'int32') +class TestCastOpEmpty(unittest.TestCase): + def test_cast_op_empty(self): + if paddle.is_compiled_with_xpu(): + paddle.set_device('xpu') + paddle.disable_static() + data = paddle.ones([0, 10], dtype='float32') + out = paddle.cast(data, 'int32') + self.assertEqual(out.shape, data.shape) + self.assertEqual(out.dtype, paddle.int32) + paddle.enable_static() + + if __name__ == '__main__': paddle.enable_static() unittest.main() diff --git a/test/xpu/test_gather_op_xpu.py b/test/xpu/test_gather_op_xpu.py index 0d132e7185e6435e5c78c1849ca4ad113d5af5ac..bef509edae799f7b643c48b268b8157f3081e9db 100644 --- a/test/xpu/test_gather_op_xpu.py +++ b/test/xpu/test_gather_op_xpu.py @@ -112,6 +112,18 @@ class XPUTestGather(XPUOpTestWrapper): self.index_type = np.int64 +class TestGatherOpEmpty(unittest.TestCase): + def test_gather_empty_index(self): + if paddle.is_compiled_with_xpu(): + paddle.set_device('xpu') + paddle.disable_static() + data = paddle.ones([10], dtype='int32') + index = paddle.ones([], dtype='int32') + out = paddle.gather(data, index) + self.assertEqual(out.shape, index.shape) + paddle.enable_static() + + support_types = get_xpu_op_support_types('gather') for stype in support_types: create_test_class(globals(), XPUTestGather, stype) diff --git a/test/xpu/test_scale_op_xpu.py b/test/xpu/test_scale_op_xpu.py index fbc3b7f8208569799d1971c36bbcd78540ae7658..8ad7800fe2311e0874c3a8eaf84e7440207629cc 100644 --- a/test/xpu/test_scale_op_xpu.py +++ b/test/xpu/test_scale_op_xpu.py @@ -136,6 +136,17 @@ class TestScaleInplaceApiDygraph(TestScaleApiDygraph): return x.scale_(scale, bias) +class TestScaleOpZeroNumelVariable(unittest.TestCase): + def test_check_zero_numel_xpu(self): + if paddle.is_compiled_with_xpu(): + paddle.disable_static() + paddle.set_device('xpu') + data = paddle.ones([0, 1]) + out = paddle.scale(data, 2) + self.assertEqual(out.shape, data.shape) + paddle.enable_static() + + support_types = get_xpu_op_support_types('scale') for stype in support_types: create_test_class(globals(), XPUTestScaleOp, stype)