diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h index 4760270caa3c6d7bef36a467e41987ef62ef109b..28b6ebc2433224ac6743b35df5e16c6f4b9402cd 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h @@ -34,6 +34,7 @@ namespace cub = hipcub; #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/cast_op.h" #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/fast_divmod.h" @@ -705,8 +706,16 @@ void TensorReduceFunctorImpl(const framework::Tensor& x, framework::Tensor* y, if (config.reduce_num == 1) { auto out_dims = y->dims(); - framework::TensorCopy(x, y->place(), y); - y->Resize(out_dims); + if (x.type() == y->type()) { + framework::TensorCopy(x, y->place(), y); + y->Resize(out_dims); + } else { + auto* dev_ctx = static_cast( + paddle::platform::DeviceContextPool::Instance().Get(x.place())); + framework::VisitDataType( + static_cast(y->type()), + CastOpFunctor(&x, y, *dev_ctx)); + } return; }