From e0fd3bbfb154d6732da92f232f32b6501186c510 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Tue, 21 Dec 2021 14:52:00 +0800 Subject: [PATCH] [pten] fix when out_dtype is same with x.dtype and still transform type error (#38285) * fix when out_dtype is same with x.dtype and still transform type error * fix spell error --- paddle/pten/kernels/hybird/cuda/reduce/reduce.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/pten/kernels/hybird/cuda/reduce/reduce.h b/paddle/pten/kernels/hybird/cuda/reduce/reduce.h index c88965e6def..f55d483de14 100644 --- a/paddle/pten/kernels/hybird/cuda/reduce/reduce.h +++ b/paddle/pten/kernels/hybird/cuda/reduce/reduce.h @@ -61,7 +61,7 @@ void Reduce(const CUDAContext& dev_ctx, gpuStream_t stream = dev_ctx.stream(); - if (out_dtype != pten::DataType::UNDEFINED) { + if (out_dtype != pten::DataType::UNDEFINED && out_dtype != x.dtype()) { PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES( out_dtype, "TensorReduceFunctorImpl", ([&] { pten::detail::TensorReduceFunctorImpl( -- GitLab