Unverified commit 4d7ad277, authored by Chen Weihang, committed by GitHub

Fix reduce_sum dtype dispatch bug on gpu (#39349)

* fix pten reduce dispatch bug

* add cast before reduce

* fix failing test
Parent 96964ff8
@@ -302,4 +302,44 @@ namespace paddle {
    } \
  }()
#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \
SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::BOOL, bool, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::COMPLEX64, \
::paddle::complex64, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, \
::paddle::DataType::COMPLEX128, \
::paddle::complex128, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, \
SPECIFIED_TYPE1, \
::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE1>::type, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, \
SPECIFIED_TYPE2, \
::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE2>::type, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE( \
NAME, \
SPECIFIED_TYPE3, \
::paddle::experimental::DataTypeToCppType<SPECIFIED_TYPE3>::type, \
__VA_ARGS__) \
default: \
PD_THROW("function " #NAME " is not implemented for data type `", \
__dtype__, \
"`"); \
} \
}()
} // namespace paddle
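
The new `PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES` macro above follows the same pattern as the existing `PD_VISIT_*` helpers: it switches on a runtime `DataType` and, in each case, defines the compile-time alias `data_t` before invoking the caller's lambda, so the same lambda body is compiled once per covered type and any uncovered dtype falls into the `PD_THROW` default. A minimal, self-contained sketch of that dispatch pattern (the macro names and the simplified expansion below are illustrative, not Paddle's actual definitions):

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

enum class DataType { FLOAT32, FLOAT64, INT64 };

// Hypothetical, simplified stand-in for PD_PRIVATE_CASE_TYPE: the caller's
// lambda (__VA_ARGS__) is pasted into the case body, so `data_t` resolves to
// a different C++ type in every branch.
#define VISIT_CASE(enum_value, cpp_type, ...) \
  case DataType::enum_value: {                \
    using data_t = cpp_type;                  \
    __VA_ARGS__();                            \
    break;                                    \
  }

// Hypothetical, simplified stand-in for a PD_VISIT_* macro: dtypes without a
// case fall through to the default branch and raise, mirroring PD_THROW.
#define VISIT_TYPES(dtype, ...)                                   \
  [&] {                                                           \
    switch (dtype) {                                              \
      VISIT_CASE(FLOAT32, float, __VA_ARGS__)                     \
      VISIT_CASE(FLOAT64, double, __VA_ARGS__)                    \
      VISIT_CASE(INT64, int64_t, __VA_ARGS__)                     \
      default:                                                    \
        throw std::runtime_error("dtype not covered by visitor"); \
    }                                                             \
  }()

int main() {
  DataType out_dtype = DataType::FLOAT64;
  // The same lambda is instantiated with data_t = float, double and int64_t;
  // only the branch matching out_dtype runs. Prints "element size: 8".
  VISIT_TYPES(out_dtype, [&] {
    std::cout << "element size: " << sizeof(data_t) << "\n";
  });
  return 0;
}
```

Judging by the macro names, out_dtype values such as bool and float16 previously hit this default branch in the narrower `PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES` dispatch used by the GPU reduce below, which is the dispatch failure this commit widens the macro to cover.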
@@ -220,22 +220,15 @@ void Reduce(const DeviceContext& dev_ctx,
   // no need to cast dtype
   if (out_dtype == pten::DataType::UNDEFINED || out_dtype == x.dtype()) {
-    if (out_dtype == pten::DataType::UNDEFINED) {
-      out_dtype = x.dtype();
-    }
     // do reduce sum
     PD_VISIT_ALL_TYPES(
-        out_dtype, "ReduceKernelImpl", ([&] {
+        x.dtype(), "ReduceKernelImpl", ([&] {
           pten::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
               dev_ctx, x, out, dims, keep_dim, reduce_all);
         }));
   } else {
-    pten::DenseTensor tmp_tensor = pten::DenseTensor(
-        pten::make_intrusive<paddle::experimental::SharedStorage>(x.place()),
-        pten::DenseTensorMeta(out_dtype, x.dims(), x.layout()));
     // cast x tensor to out_dtype
-    pten::CastKernel<T, DeviceContext>(dev_ctx, x, out_dtype, &tmp_tensor);
+    auto tmp_tensor = pten::Cast<T, DeviceContext>(dev_ctx, x, out_dtype);
     // do reduce sum
     PD_VISIT_ALL_TYPES(
......
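
Casting before the reduction (rather than reducing in the input dtype and casting the result) matters for the values themselves, not only for kernel dispatch: the updated test below compares against `np.sum(input_data.astype(attr_dtype), ...)`, so summing a bool tensor with an integer or floating out_dtype is expected to yield a count of true elements, which only falls out if the cast happens first. A standalone sketch of the difference, with plain std:: containers standing in for tensors (purely illustrative):

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<bool> x = {true, true, true, false};

  // Cast first, then reduce in the output dtype: the result is the count of
  // true elements, matching np.sum(x.astype(int64)).
  std::vector<int64_t> casted(x.begin(), x.end());
  int64_t cast_then_reduce =
      std::accumulate(casted.begin(), casted.end(), int64_t{0});

  // Reduce in the input dtype first: with bool elements the reduction
  // collapses to 0 or 1 (logical OR), and casting afterwards cannot recover
  // the count.
  bool reduced = std::accumulate(x.begin(), x.end(), false,
                                 [](bool a, bool b) { return a || b; });
  int64_t reduce_then_cast = static_cast<int64_t>(reduced);

  std::cout << cast_then_reduce << " vs " << reduce_then_cast << "\n";  // 3 vs 1
  return 0;
}
```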
@@ -43,6 +43,7 @@ namespace cub = hipcub;
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/enforce.h"
 #include "paddle/pten/core/utils/array.h"
+#include "paddle/pten/kernels/cast_kernel.h"
 #include "paddle/pten/kernels/funcs/elementwise_base.h"
 #include "paddle/pten/kernels/primitive/kernel_primitives.h"
@@ -1232,21 +1233,23 @@ void Reduce(const GPUContext& dev_ctx,
   gpuStream_t stream = dev_ctx.stream();
   if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) {
-    PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES(
+    auto tmp_tensor = pten::Cast<T>(dev_ctx, x, out_dtype);
+    PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES(
         pten::DataType::INT32,
         pten::DataType::INT64,
+        pten::DataType::FLOAT16,
         out_dtype,
         "TensorReduceFunctorImpl",
         ([&] {
           using MPType = typename kps::details::MPTypeTrait<data_t>::Type;
-          pten::kernels::TensorReduceFunctorImpl<T,
+          pten::kernels::TensorReduceFunctorImpl<data_t,
                                                  data_t,
                                                  ReduceOp,
-                                                 TransformOp<T, MPType>>(
+                                                 TransformOp<data_t, MPType>>(
               dev_ctx,
-              x,
+              tmp_tensor,
               out,
-              TransformOp<T, MPType>(reduce_num),
+              TransformOp<data_t, MPType>(reduce_num),
               reduce_dims,
               stream);
         }));
......
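
The `using MPType = typename kps::details::MPTypeTrait<data_t>::Type;` line keeps its role after the change: for low-precision element types the transform promotes each value to a wider accumulation type (conventionally float for float16) and only the final result is written back as `data_t`, because accumulating many values directly in a narrow type loses precision or stalls. A standalone sketch of why the accumulator width matters, with float standing in for the narrow dtype and double for its wider counterpart (illustrative only, not Paddle code):

```cpp
#include <cstdio>

int main() {
  // Summing 20 million ones: a float accumulator stalls at 2^24 (16777216)
  // because further +1 steps round away, while accumulating in a wider type
  // and converting once at the end keeps the exact total (20000000).
  float narrow_acc = 0.0f;
  double wide_acc = 0.0;
  for (int i = 0; i < 20000000; ++i) {
    narrow_acc += 1.0f;
    wide_acc += 1.0;
  }
  std::printf("narrow: %.0f  wide-then-cast: %.0f\n", narrow_acc,
              static_cast<float>(wide_acc));
  return 0;
}
```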
@@ -740,17 +740,24 @@ class API_TestSumOp(unittest.TestCase):
         if np_axis is None:
             np_axis = attr_axis

-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            data = fluid.data("data", shape=shape, dtype=x_dtype)
-            result_sum = paddle.sum(x=data, axis=attr_axis, dtype=attr_dtype)
-
-            exe = fluid.Executor(fluid.CPUPlace())
-            input_data = np.random.rand(*shape).astype(x_dtype)
-            res, = exe.run(feed={"data": input_data}, fetch_list=[result_sum])
-
-        self.assertTrue(
-            np.allclose(
-                res, np.sum(input_data.astype(attr_dtype), axis=np_axis)))
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for place in places:
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data = fluid.data("data", shape=shape, dtype=x_dtype)
+                result_sum = paddle.sum(x=data,
+                                        axis=attr_axis,
+                                        dtype=attr_dtype)
+
+                exe = fluid.Executor(place)
+                input_data = np.random.rand(*shape).astype(x_dtype)
+                res, = exe.run(feed={"data": input_data},
+                               fetch_list=[result_sum])
+                self.assertTrue(
+                    np.allclose(
+                        res, np.sum(input_data.astype(attr_dtype), axis=np_axis)))

     def test_static(self):
         shape = [10, 10]
@@ -759,10 +766,12 @@ class API_TestSumOp(unittest.TestCase):
         self.run_static(shape, "bool", axis, attr_dtype=None)
         self.run_static(shape, "bool", axis, attr_dtype="int32")
         self.run_static(shape, "bool", axis, attr_dtype="int64")
+        self.run_static(shape, "bool", axis, attr_dtype="float16")

         self.run_static(shape, "int32", axis, attr_dtype=None)
         self.run_static(shape, "int32", axis, attr_dtype="int32")
         self.run_static(shape, "int32", axis, attr_dtype="int64")
+        self.run_static(shape, "int32", axis, attr_dtype="float64")

         self.run_static(shape, "int64", axis, attr_dtype=None)
         self.run_static(shape, "int64", axis, attr_dtype="int64")
@@ -771,6 +780,7 @@ class API_TestSumOp(unittest.TestCase):
         self.run_static(shape, "float32", axis, attr_dtype=None)
         self.run_static(shape, "float32", axis, attr_dtype="float32")
         self.run_static(shape, "float32", axis, attr_dtype="float64")
+        self.run_static(shape, "float32", axis, attr_dtype="int64")

         self.run_static(shape, "float64", axis, attr_dtype=None)
         self.run_static(shape, "float64", axis, attr_dtype="float32")
......