Unverified · Commit 423dda37 authored by lijin23, committed by GitHub

[XPU][PHI Kernels] fix errors when numel is zero for xpu (#54010)

* fix empty bugs for xpu

* fix empty bugs for xpu
Parent 664a2753
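Context for the fix (a summary, not part of the commit): each patched kernel previously forwarded zero-element tensors straight into XDNN calls such as xpu::cast and xpu::scale, which apparently reject empty inputs; the kernels now return early instead. A minimal sketch of the calls this commit makes safe, assuming a Paddle build with XPU support (shapes taken from the unit tests further below):

    import paddle

    # Each call exercises one of the patched kernels with a zero-element
    # (or scalar-index) input; per the commit title, these used to raise
    # errors on XPU.
    if paddle.is_compiled_with_xpu():
        paddle.set_device('xpu')
        paddle.disable_static()
        paddle.cast(paddle.ones([0, 10], dtype='float32'), 'int32')  # cast
        paddle.gather(
            paddle.ones([10], dtype='int32'), paddle.ones([], dtype='int32')
        )  # gather
        paddle.scale(paddle.ones([0, 1]), 2)  # scale
        paddle.enable_static()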
@@ -15,90 +15,66 @@
 #include "paddle/phi/kernels/cast_kernel.h"

 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/backends/xpu/xpu_header.h"
 #include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"

 namespace phi {

+template <typename InT, typename OutT, typename Context>
+void CastXPUKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       DenseTensor* out) {
+  using XPUInT = typename XPUTypeTrait<InT>::Type;
+  using XPUOutT = typename XPUTypeTrait<OutT>::Type;
+
+  const auto* in_data = x.data<InT>();
+  auto* out_data = dev_ctx.template Alloc<OutT>(out);
+  auto numel = x.numel();
+
+  if (numel == 0) {
+    return;
+  }
+
+  int r = xpu::cast<XPUInT, XPUOutT>(dev_ctx.x_context(),
+                                     reinterpret_cast<const XPUInT*>(in_data),
+                                     reinterpret_cast<XPUOutT*>(out_data),
+                                     numel);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+}
+
 template <typename T, typename Context>
 void CastKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 DataType out_dtype,
                 DenseTensor* out) {
-  using XPUInTDType = typename XPUTypeTrait<T>::Type;
-  using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
-
-  auto* in_data = x.data<T>();
-  auto numel = x.numel();
-
-  int r = -1;
   switch (out_dtype) {
-    case phi::DataType::FLOAT32:
-      r = xpu::cast<XPUInTDType, float>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<float>(out),
-          numel);
+    case DataType::INT32:
+      CastXPUKernelImpl<T, int, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::FLOAT16:
-      r = xpu::cast<XPUInTDType, XPUTypeFP16>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          reinterpret_cast<XPUTypeFP16*>(
-              dev_ctx.template Alloc<phi::dtype::float16>(out)),
-          numel);
+    case DataType::FLOAT32:
+      CastXPUKernelImpl<T, float, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::INT64:
-      r = xpu::cast<XPUInTDType, int64_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<int64_t>(out),
-          numel);
+    case DataType::FLOAT16:
+      CastXPUKernelImpl<T, dtype::float16, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::INT32:
-      r = xpu::cast<XPUInTDType, int32_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<int>(out),
-          numel);
+    case DataType::INT64:
+      CastXPUKernelImpl<T, int64_t, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::BOOL:
-      r = xpu::cast<XPUInTDType, bool>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<bool>(out),
-          numel);
+    case DataType::BOOL:
+      CastXPUKernelImpl<T, bool, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::INT8:
-      r = xpu::cast<XPUInTDType, int8_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<int8_t>(out),
-          numel);
+    case DataType::INT8:
+      CastXPUKernelImpl<T, int8_t, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::UINT8:
-      r = xpu::cast<XPUInTDType, uint8_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<uint8_t>(out),
-          numel);
+    case DataType::UINT8:
+      CastXPUKernelImpl<T, uint8_t, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::FLOAT64:
-      r = xpu::cast<XPUInTDType, double>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<double>(out),
-          numel);
+    case DataType::FLOAT64:
+      CastXPUKernelImpl<T, double, Context>(dev_ctx, x, out);
       break;
     default:
       PADDLE_THROW(phi::errors::Unavailable(
           "Not supported cast %d -> %d", x.dtype(), out_dtype));
   }
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
 }
 }  // namespace phi
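Not part of the commit, but worth noting the shape of the refactor: every dtype branch now funnels through CastXPUKernelImpl, which allocates the output before the zero-numel early return, so an empty result still gets its shape and dtype materialized. A hedged spot-check over the dtypes the switch dispatches on (assumes an XPU build; dtype list copied from the diff above):

    import paddle

    # After this change, casting a zero-element tensor should succeed for
    # every dtype the kernel handles, not just the int32 case in the test.
    if paddle.is_compiled_with_xpu():
        paddle.set_device('xpu')
        paddle.disable_static()
        empty = paddle.ones([0, 10], dtype='float32')
        for target in ['float16', 'float32', 'float64',
                       'int8', 'uint8', 'int32', 'int64', 'bool']:
            out = paddle.cast(empty, target)
            assert out.shape == empty.shape  # metadata survives the early return
        paddle.enable_static()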
@@ -29,7 +29,7 @@ void GatherKernel(const Context& dev_ctx,
   const auto& index_type = index.dtype();

   dev_ctx.template Alloc<T>(out);
-  if (x.numel() == 0) return;
+  if (x.numel() == 0 || index.numel() == 0) return;

   const auto index_dims = index.dims();
   if (index_dims.size() == 2) {
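The widened guard also covers an index tensor with zero elements (the unit test further below uses a 0-D index instead). A hedged sketch of the numel-0 index path, with an illustrative shape (assumes an XPU build):

    import paddle

    if paddle.is_compiled_with_xpu():
        paddle.set_device('xpu')
        paddle.disable_static()
        x = paddle.ones([10, 3], dtype='float32')
        index = paddle.to_tensor([], dtype='int64')  # shape [0], numel == 0
        out = paddle.gather(x, index)  # early-returns; expected shape [0, 3]
        paddle.enable_static()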
@@ -15,10 +15,6 @@
 #include "paddle/phi/kernels/scale_kernel.h"

 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
-#include "paddle/phi/common/data_type.h"
-#include "paddle/phi/common/float16.h"
-#include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/kernel_registry.h"

 namespace phi {
@@ -39,6 +35,9 @@ void ScaleKernel(const Context& dev_ctx,
                             " expected %s, but got %s.",
                             x.dims().to_str().c_str(),
                             out->dims().to_str().c_str()));
+  if (x.numel() == 0 || !x.IsInitialized()) {
+    return;
+  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   int r = xpu::scale(dev_ctx.x_context(),
                      reinterpret_cast<const XPUType*>(x.data<T>()),
@@ -99,6 +99,18 @@ class TestCastOpError(unittest.TestCase):
         self.assertRaises(TypeError, paddle.cast, x1, 'int32')


+class TestCastOpEmpty(unittest.TestCase):
+    def test_cast_op_empty(self):
+        if paddle.is_compiled_with_xpu():
+            paddle.set_device('xpu')
+            paddle.disable_static()
+            data = paddle.ones([0, 10], dtype='float32')
+            out = paddle.cast(data, 'int32')
+            self.assertEqual(out.shape, data.shape)
+            self.assertEqual(out.dtype, paddle.int32)
+            paddle.enable_static()
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
@@ -112,6 +112,18 @@ class XPUTestGather(XPUOpTestWrapper):
             self.index_type = np.int64


+class TestGatherOpEmpty(unittest.TestCase):
+    def test_gather_empty_index(self):
+        if paddle.is_compiled_with_xpu():
+            paddle.set_device('xpu')
+            paddle.disable_static()
+            data = paddle.ones([10], dtype='int32')
+            index = paddle.ones([], dtype='int32')
+            out = paddle.gather(data, index)
+            self.assertEqual(out.shape, index.shape)
+            paddle.enable_static()
+
+
 support_types = get_xpu_op_support_types('gather')
 for stype in support_types:
     create_test_class(globals(), XPUTestGather, stype)
@@ -136,6 +136,17 @@ class TestScaleInplaceApiDygraph(TestScaleApiDygraph):
         return x.scale_(scale, bias)


+class TestScaleOpZeroNumelVariable(unittest.TestCase):
+    def test_check_zero_numel_xpu(self):
+        if paddle.is_compiled_with_xpu():
+            paddle.disable_static()
+            paddle.set_device('xpu')
+            data = paddle.ones([0, 1])
+            out = paddle.scale(data, 2)
+            self.assertEqual(out.shape, data.shape)
+            paddle.enable_static()
+
+
 support_types = get_xpu_op_support_types('scale')
 for stype in support_types:
     create_test_class(globals(), XPUTestScaleOp, stype)