// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/abs_kernel.h"

#include <algorithm>
#include <type_traits>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace phi {

// Device functor computing the element-wise absolute value |x|.
//
// The second template parameter is a SFINAE hook: exactly one of the partial
// specializations below is enabled for a given element type T. The primary
// template is intentionally left undefined so unsupported types fail to
// compile.
template <typename T, typename Enable = void>
struct CudaAbsFunctor;

// Complex element types: returns |x| as phi::dtype::Real<T>.
// NOTE(review): phi::funcs::Complex<T, Real<T>> presumably resolves to void
// only when T differs from its real type (i.e. T is complex) -- confirm in
// complex_functors.h.
template <typename T>
struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::dtype::Real<T>>> {
  __device__ __forceinline__ phi::dtype::Real<T> operator()(const T x) const {
    return abs(x);
  }
};

// bfloat16: a real type, but dispatched through unqualified `abs` rather than
// `std::abs`.
// NOTE(review): presumably this picks up Paddle's bfloat16 overload because
// no std::abs overload for bfloat16 is available here -- confirm.
template <typename T>
struct CudaAbsFunctor<
    T,
    std::enable_if_t<std::is_same<T, phi::dtype::Real<T>>::value &&
                     std::is_same<T, phi::dtype::bfloat16>::value>> {
  __device__ __forceinline__ T operator()(const T x) const { return abs(x); }
};

// All other real element types (T equals its own real type): std::abs.
// For signed integers, std::abs of the most negative value is undefined
// behavior, exactly as in host code.
template <typename T>
struct CudaAbsFunctor<
    T,
    std::enable_if_t<std::is_same<T, phi::dtype::Real<T>>::value &&
                     !std::is_same<T, phi::dtype::bfloat16>::value>> {
  __device__ __forceinline__ T operator()(const T x) const {
    return std::abs(x);
  }
};

// Computes out = |x| element-wise on the GPU.
//
// The output is allocated with element type phi::dtype::Real<T>, so for
// complex T the result is real-valued; for real T it has the same type as x.
//
// @param ctx  device context used for the output allocation and the launch.
// @param x    input tensor.
// @param out  output tensor; its storage is allocated by this function.
template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  ctx.template Alloc<phi::dtype::Real<T>>(out);

  // Zero-size input: the (empty) output is already allocated and there is
  // nothing to compute, so skip the element-wise launch entirely instead of
  // relying on ElementwiseKernel to cope with an empty grid.
  if (x.numel() == 0) {
    return;
  }

  std::vector<const DenseTensor*> ins = {&x};
  std::vector<DenseTensor*> outs = {out};
  auto functor = CudaAbsFunctor<T>();

  // The template argument is the output element type (Real<T>), which differs
  // from the input type T for complex inputs.
  funcs::ElementwiseKernel<phi::dtype::Real<T>>(ctx, ins, &outs, functor);
}

}  // namespace phi
// GPU registration of the `abs` kernel for real and complex element types.
// The output's data type is overridden to the "real" counterpart of the
// kernel key's dtype.
// NOTE(review): ToReal presumably maps complex<float>/complex<double> to
// float/double and leaves real dtypes unchanged, matching the Real<T> output
// that AbsKernel allocates -- confirm.
PD_REGISTER_KERNEL(abs,
                   GPU,
                   ALL_LAYOUT,
                   phi::AbsKernel,
                   float,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {
  const auto real_output_dtype = phi::dtype::ToReal(kernel_key.dtype());
  kernel->OutputAt(0).SetDataType(real_output_dtype);
}