未验证 提交 a8879148 编写于 作者: Z zyfncg 提交者: GitHub

fix performance problem caused by Conj (#38939)

上级 050aa6fe
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
......@@ -23,7 +24,13 @@ namespace pten {
template <typename T, typename Context>
void ConjKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
// If T is complex
template <typename T,
typename Context,
std::enable_if_t<
std::is_same<T, paddle::platform::complex<float>>::value ||
std::is_same<T, paddle::platform::complex<double>>::value,
bool> = true>
DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
auto out_meta = UnchangedInferMeta(x.meta());
auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
......@@ -31,4 +38,15 @@ DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return dense_out;
}
// If T is not complex
template <typename T,
typename Context,
std::enable_if_t<
!std::is_same<T, paddle::platform::complex<float>>::value &&
!std::is_same<T, paddle::platform::complex<double>>::value,
bool> = true>
DenseTensor Conj(const Context& dev_ctx, const DenseTensor& x) {
return x;
}
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册