Unverified commit 423dda37, authored by lijin23, committed by GitHub

[XPU][PHI Kernels] fix errors when numel is zero for xpu (#54010)

* fix empty bugs for xpu

* fix empty bugs for xpu
Parent 664a2753
XPU cast kernel:

@@ -15,90 +15,66 @@
 #include "paddle/phi/kernels/cast_kernel.h"

 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/backends/xpu/xpu_context.h"
-#include "paddle/phi/backends/xpu/xpu_header.h"
-#include "paddle/phi/common/float16.h"
-#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_registry.h"

 namespace phi {

+template <typename InT, typename OutT, typename Context>
+void CastXPUKernelImpl(const Context& dev_ctx,
+                       const DenseTensor& x,
+                       DenseTensor* out) {
+  using XPUInT = typename XPUTypeTrait<InT>::Type;
+  using XPUOutT = typename XPUTypeTrait<OutT>::Type;
+
+  const auto* in_data = x.data<InT>();
+  auto* out_data = dev_ctx.template Alloc<OutT>(out);
+  auto numel = x.numel();
+
+  if (numel == 0) {
+    return;
+  }
+
+  int r = xpu::cast<XPUInT, XPUOutT>(dev_ctx.x_context(),
+                                     reinterpret_cast<const XPUInT*>(in_data),
+                                     reinterpret_cast<XPUOutT*>(out_data),
+                                     numel);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+}
+
 template <typename T, typename Context>
 void CastKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 DataType out_dtype,
                 DenseTensor* out) {
-  using XPUInTDType = typename XPUTypeTrait<T>::Type;
-  using XPUTypeFP16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
-
-  auto* in_data = x.data<T>();
-  auto numel = x.numel();
-
-  int r = -1;
   switch (out_dtype) {
-    case phi::DataType::FLOAT32:
-      r = xpu::cast<XPUInTDType, float>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<float>(out),
-          numel);
+    case DataType::INT32:
+      CastXPUKernelImpl<T, int, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::FLOAT16:
-      r = xpu::cast<XPUInTDType, XPUTypeFP16>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          reinterpret_cast<XPUTypeFP16*>(
-              dev_ctx.template Alloc<phi::dtype::float16>(out)),
-          numel);
+    case DataType::FLOAT32:
+      CastXPUKernelImpl<T, float, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::INT64:
-      r = xpu::cast<XPUInTDType, int64_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<int64_t>(out),
-          numel);
+    case DataType::FLOAT16:
+      CastXPUKernelImpl<T, dtype::float16, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::INT32:
-      r = xpu::cast<XPUInTDType, int32_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<int>(out),
-          numel);
+    case DataType::INT64:
+      CastXPUKernelImpl<T, int64_t, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::BOOL:
-      r = xpu::cast<XPUInTDType, bool>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<bool>(out),
-          numel);
+    case DataType::BOOL:
+      CastXPUKernelImpl<T, bool, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::INT8:
-      r = xpu::cast<XPUInTDType, int8_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<int8_t>(out),
-          numel);
+    case DataType::INT8:
+      CastXPUKernelImpl<T, int8_t, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::UINT8:
-      r = xpu::cast<XPUInTDType, uint8_t>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<uint8_t>(out),
-          numel);
+    case DataType::UINT8:
+      CastXPUKernelImpl<T, uint8_t, Context>(dev_ctx, x, out);
       break;
-    case phi::DataType::FLOAT64:
-      r = xpu::cast<XPUInTDType, double>(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUInTDType*>(in_data),
-          dev_ctx.template Alloc<double>(out),
-          numel);
+    case DataType::FLOAT64:
+      CastXPUKernelImpl<T, double, Context>(dev_ctx, x, out);
       break;
     default:
       PADDLE_THROW(phi::errors::Unavailable(
           "Not supported cast %d -> %d", x.dtype(), out_dtype));
   }
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
 }

 }  // namespace phi
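
The refactor above folds eight near-identical xpu::cast call sites into the templated CastXPUKernelImpl, which allocates the output before checking numel (so an empty result still carries the right shape and dtype metadata) and returns without calling into XDNN when there is nothing to cast. A minimal dygraph sketch of the dtype coverage, assuming a Paddle build with XPU support (this mirrors, but is not part of, the new unit test below):

import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device('xpu')
    paddle.disable_static()
    x = paddle.ones([0, 4], dtype='float32')  # numel == 0
    # One target dtype per case in the kernel's switch.
    for dst in ['int32', 'float32', 'float16', 'int64',
                'bool', 'int8', 'uint8', 'float64']:
        y = paddle.cast(x, dst)  # hits the numel == 0 early return
        assert y.shape == [0, 4]
    paddle.enable_static()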
......
XPU gather kernel:

@@ -29,7 +29,7 @@ void GatherKernel(const Context& dev_ctx,
   const auto& index_type = index.dtype();

   dev_ctx.template Alloc<T>(out);
-  if (x.numel() == 0) return;
+  if (x.numel() == 0 || index.numel() == 0) return;

   const auto index_dims = index.dims();
   if (index_dims.size() == 2) {
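
Note the guard's new second clause: the old check only caught an empty x, so gathering from a non-empty input with a zero-element index still reached the XDNN kernel. A sketch of the newly covered case, assuming an XPU build (the shape assertion reflects gather along the default axis 0):

import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device('xpu')
    paddle.disable_static()
    x = paddle.ones([10], dtype='int32')
    index = paddle.zeros([0], dtype='int64')  # empty index, numel == 0
    out = paddle.gather(x, index)  # early return; no device call is made
    assert out.shape == [0]
    paddle.enable_static()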
......
XPU scale kernel:

@@ -15,10 +15,6 @@
 #include "paddle/phi/kernels/scale_kernel.h"

 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/backends/xpu/xpu_context.h"
-#include "paddle/phi/common/data_type.h"
-#include "paddle/phi/common/float16.h"
-#include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/kernel_registry.h"

 namespace phi {
@@ -39,6 +35,9 @@ void ScaleKernel(const Context& dev_ctx,
                        " expected %s, but got %s.",
                        x.dims().to_str().c_str(),
                        out->dims().to_str().c_str()));
+  if (x.numel() == 0 || !x.IsInitialized()) {
+    return;
+  }
   using XPUType = typename XPUTypeTrait<T>::Type;
   int r = xpu::scale(dev_ctx.x_context(),
                      reinterpret_cast<const XPUType*>(x.data<T>()),
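
ScaleKernel keeps its existing dims check and then returns early when the input has zero elements or its buffer was never initialized; in both cases there is nothing for xpu::scale to do, and launching it anyway is what previously could error. The user-visible behavior, assuming an XPU build, is simply an empty output of the same shape:

import paddle

if paddle.is_compiled_with_xpu():
    paddle.set_device('xpu')
    paddle.disable_static()
    x = paddle.ones([0, 1])  # numel == 0
    y = paddle.scale(x, scale=2.0, bias=0.0)  # skips the XDNN call entirely
    assert y.shape == [0, 1]
    paddle.enable_static()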
......
Cast unit test (XPU):

@@ -99,6 +99,18 @@ class TestCastOpError(unittest.TestCase):
         self.assertRaises(TypeError, paddle.cast, x1, 'int32')


+class TestCastOpEmpty(unittest.TestCase):
+    def test_cast_op_empty(self):
+        if paddle.is_compiled_with_xpu():
+            paddle.set_device('xpu')
+            paddle.disable_static()
+            data = paddle.ones([0, 10], dtype='float32')
+            out = paddle.cast(data, 'int32')
+            self.assertEqual(out.shape, data.shape)
+            self.assertEqual(out.dtype, paddle.int32)
+            paddle.enable_static()
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
Gather unit test (XPU):

@@ -112,6 +112,18 @@ class XPUTestGather(XPUOpTestWrapper):
         self.index_type = np.int64


+class TestGatherOpEmpty(unittest.TestCase):
+    def test_gather_empty_index(self):
+        if paddle.is_compiled_with_xpu():
+            paddle.set_device('xpu')
+            paddle.disable_static()
+            data = paddle.ones([10], dtype='int32')
+            index = paddle.ones([], dtype='int32')
+            out = paddle.gather(data, index)
+            self.assertEqual(out.shape, index.shape)
+            paddle.enable_static()
+
+
 support_types = get_xpu_op_support_types('gather')
 for stype in support_types:
     create_test_class(globals(), XPUTestGather, stype)
......
Scale unit test (XPU):

@@ -136,6 +136,17 @@ class TestScaleInplaceApiDygraph(TestScaleApiDygraph):
         return x.scale_(scale, bias)


+class TestScaleOpZeroNumelVariable(unittest.TestCase):
+    def test_check_zero_numel_xpu(self):
+        if paddle.is_compiled_with_xpu():
+            paddle.disable_static()
+            paddle.set_device('xpu')
+            data = paddle.ones([0, 1])
+            out = paddle.scale(data, 2)
+            self.assertEqual(out.shape, data.shape)
+            paddle.enable_static()
+
+
 support_types = get_xpu_op_support_types('scale')
 for stype in support_types:
     create_test_class(globals(), XPUTestScaleOp, stype)
......