diff --git a/paddle/pten/kernels/complex_kernel.h b/paddle/pten/kernels/complex_kernel.h index b6074f117ea14ad6b970081f3cde1e21798f3b14..d12fc730fef871f39e43b2b2277860dc1ee95019 100644 --- a/paddle/pten/kernels/complex_kernel.h +++ b/paddle/pten/kernels/complex_kernel.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/fluid/platform/complex.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/infermeta/unary.h" #include "paddle/pten/kernels/empty_kernel.h" @@ -23,7 +24,13 @@ namespace pten { template void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out); -template +// If T is complex +template >::value || + std::is_same>::value, + bool> = true> DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { auto out_meta = UnchangedInferMeta(x.meta()); auto dense_out = pten::Empty(dev_ctx, std::move(out_meta)); @@ -31,4 +38,15 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { return dense_out; } +// If T is not complex +template >::value && + !std::is_same>::value, + bool> = true> +DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) { + return x; +} + } // namespace pten