From a88791481484ab6a61540a737336d79c65d021dc Mon Sep 17 00:00:00 2001
From: zyfncg
Date: Sat, 15 Jan 2022 12:39:49 +0800
Subject: [PATCH] fix performance problem caused by Conj (#38939)

---
 paddle/pten/kernels/complex_kernel.h | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/paddle/pten/kernels/complex_kernel.h b/paddle/pten/kernels/complex_kernel.h
index b6074f117ea..d12fc730fef 100644
--- a/paddle/pten/kernels/complex_kernel.h
+++ b/paddle/pten/kernels/complex_kernel.h
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include "paddle/fluid/platform/complex.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/infermeta/unary.h"
 #include "paddle/pten/kernels/empty_kernel.h"
@@ -23,7 +24,13 @@ namespace pten {
 template <typename T, typename Context>
 void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
 
-template <typename T, typename Context>
+// If T is complex
+template <typename T,
+          typename Context,
+          std::enable_if_t<
+              std::is_same<T, paddle::platform::complex<float>>::value ||
+                  std::is_same<T, paddle::platform::complex<double>>::value,
+              bool> = true>
 DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
   auto out_meta = UnchangedInferMeta(x.meta());
   auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
@@ -31,4 +38,15 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
   return dense_out;
 }
 
+// If T is not complex
+template <typename T,
+          typename Context,
+          std::enable_if_t<
+              !std::is_same<T, paddle::platform::complex<float>>::value &&
+                  !std::is_same<T, paddle::platform::complex<double>>::value,
+              bool> = true>
+DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
+  return x;
+}
+
 }  // namespace pten
-- 
GitLab