未验证 提交 4d7ad277 编写于 作者: C Chen Weihang 提交者: GitHub

Fix reduce_sum dtype dispatch bug on gpu (#39349)

* fix pten reduce dispatch bug

* add cast beforce reduce

* fix test failed
上级 96964ff8
......@@ -302,4 +302,44 @@ namespace paddle {
} \
}()
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \
SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::BOOL, bool, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::COMPLEX64, \
::paddle::complex64, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::COMPLEX128, \
::paddle::complex128, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, \
SPECIFIED_TYPE1, \
::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE1>::type, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, \
SPECIFIED_TYPE2, \
::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE2>::type, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, \
SPECIFIED_TYPE3, \
::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE3>::type, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
} // namespace paddle
......@@ -220,22 +220,15 @@ void Reduce(const DeviceContext& dev_ctx,
// no need to cast dtype
if (out_dtype == pten::DataType::UNDEFINED || out_dtype == x.dtype()) {
if (out_dtype == pten::DataType::UNDEFINED) {
out_dtype = x.dtype();
}
// do reduce sum
PD_VISIT_ALL_TYPES(
out_dtype, "ReduceKernelImpl", ([&] {
x.dtype(), "ReduceKernelImpl", ([&] {
pten::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
dev_ctx, x, out, dims, keep_dim, reduce_all);
}));
} else {
pten::DenseTensor tmp_tensor = pten::DenseTensor(
pten::make_intrusive<paddle::experimental::SharedStorage>(x.place()),
pten::DenseTensorMeta(out_dtype, x.dims(), x.layout()));
// cast x tensor to out_dtype
pten::CastKernel<T, DeviceContext>(dev_ctx, x, out_dtype, &tmp_tensor);
auto tmp_tensor = pten::Cast<T, DeviceContext>(dev_ctx, x, out_dtype);
// do reduce sum
PD_VISIT_ALL_TYPES(
......
......@@ -43,6 +43,7 @@ namespace cub = hipcub;
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
#include "paddle/pten/core/utils/array.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/kernels/funcs/elementwise_base.h"
#include "paddle/pten/kernels/primitive/kernel_primitives.h"
......@@ -1232,21 +1233,23 @@ void Reduce(const GPUContext& dev_ctx,
gpuStream_t stream = dev_ctx.stream();
if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
auto tmp_tensor = pten::Cast<T>(dev_ctx, x, out_dtype);
PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
pten::DataType::INT32,
pten::DataType::INT64,
pten::DataType::FLOAT16,
out_dtype,
"TensorReduceFunctorImpl",
([&] {
using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
pten::kernels::TensorReduceFunctorImpl<T,
pten::kernels::TensorReduceFunctorImpl<data_t,
data_t,
ReduceOp,
TransformOp<T, MPType>>(
TransformOp<data_t, MPType>>(
dev_ctx,
x,
tmp_tensor,
out,
TransformOp<T, MPType>(reduce_num),
TransformOp<data_t, MPType>(reduce_num),
reduce_dims,
stream);
}));
......
......@@ -740,13 +740,20 @@ class API_TestSumOp(unittest.TestCase):
if np_axis is None:
np_axis = attr_axis
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for place in places:
with fluid.program_guard(fluid.Program(), fluid.Program()):
data = fluid.data("data", shape=shape, dtype=x_dtype)
result_sum = paddle.sum(x=data, axis=attr_axis, dtype=attr_dtype)
result_sum = paddle.sum(x=data,
axis=attr_axis,
dtype=attr_dtype)
exe = fluid.Executor(fluid.CPUPlace())
exe = fluid.Executor(place)
input_data = np.random.rand(*shape).astype(x_dtype)
res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
res, = exe.run(feed={"data": input_data},
fetch_list=[result_sum])
self.assertTrue(
np.allclose(
......@@ -759,10 +766,12 @@ class API_TestSumOp(unittest.TestCase):
self.run_static(shape, "bool", axis, attr_dtype=None)
self.run_static(shape, "bool", axis, attr_dtype="int32")
self.run_static(shape, "bool", axis, attr_dtype="int64")
self.run_static(shape, "bool", axis, attr_dtype="float16")
self.run_static(shape, "int32", axis, attr_dtype=None)
self.run_static(shape, "int32", axis, attr_dtype="int32")
self.run_static(shape, "int32", axis, attr_dtype="int64")
self.run_static(shape, "int32", axis, attr_dtype="float64")
self.run_static(shape, "int64", axis, attr_dtype=None)
self.run_static(shape, "int64", axis, attr_dtype="int64")
......@@ -771,6 +780,7 @@ class API_TestSumOp(unittest.TestCase):
self.run_static(shape, "float32", axis, attr_dtype=None)
self.run_static(shape, "float32", axis, attr_dtype="float32")
self.run_static(shape, "float32", axis, attr_dtype="float64")
self.run_static(shape, "float32", axis, attr_dtype="int64")
self.run_static(shape, "float64", axis, attr_dtype=None)
self.run_static(shape, "float64", axis, attr_dtype="float32")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册