From d5268a6e0ebe77d25af677df9274031f21a08237 Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 28 Sep 2021 10:42:29 +0800 Subject: [PATCH] fix bug of reduce_sum when src_dtype != dst_dtype and reduce_num == 1 (#36123) --- paddle/fluid/operators/reduce_ops/reduce_op.cu.h | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 4760270caa3..28b6ebc2433 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -34,6 +34,7 @@ namespace cub = hipcub; #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/fast_divmod.h" @@ -705,8 +706,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, if (config.reduce_num == 1) { auto out_dims = y->dims(); - framework::TensorCopy(x, y->place(), y); - y->Resize(out_dims); + if (x.type() == y->type()) { + framework::TensorCopy(x, y->place(), y); + y->Resize(out_dims); + } else { + auto* dev_ctx = static_cast<platform::CUDADeviceContext*>( + paddle::platform::DeviceContextPool::Instance().Get(x.place())); + framework::VisitDataType( + static_cast<framework::proto::VarType::Type>(y->type()), + CastOpFunctor<platform::CUDADeviceContext, Tx>(&x, y, *dev_ctx)); + } return; } -- GitLab