// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/prelu_kernel.h"

#include "glog/logging.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/gpu/prelu_funcs.h"

namespace phi {

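// PReLU: out = x when x > 0, otherwise alpha * x. The shape of `alpha`
// depends on `mode`: one weight per channel ("channel"), one weight per
// element of a sample ("element"), or a single shared scalar otherwise.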
template <typename T, typename Context>
void PReluKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const DenseTensor& alpha,
                 const std::string& data_format,
                 const std::string& mode,
                 DenseTensor* out) {
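  // Allocate output storage up front; every branch below writes into it.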
  dev_ctx.template Alloc<T>(out);
  const T* x_ptr = x.data<T>();
  const T* alpha_ptr = alpha.data<T>();

  // Use a 64-bit element count so very large tensors do not overflow.
  int64_t numel = x.numel();
  auto dim = x.dims();
  auto x_rank = dim.size();

  VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
          << x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel
          << ", mode:" << mode << ", format:" << data_format;

  if (mode == "channel") {
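    // One alpha weight per channel; the channel axis is the last dim for
    // NHWC ("channel last") and dim 1 for NCHW.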
    bool channel_last = data_format == "NHWC";
    size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
    if (channel_last) {
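      // NHWC: the channel axis is innermost in memory.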
      auto func = PReluChannelLastWiseCUDAFunctor<T>(x_ptr, alpha_ptr, channel);
      phi::IndexKernel<T, PReluChannelLastWiseCUDAFunctor<T>>(
          dev_ctx, out, func);
    } else {
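      // NCHW: each (batch, channel) pair owns a contiguous plane of
      // `plane_size` elements.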
      size_t plane_size = numel / dim[0] / channel;
      auto func = PReluChannelFirstWiseCUDAFunctor<T>(
          x_ptr, alpha_ptr, numel, channel, plane_size);
      phi::IndexKernel<T, PReluChannelFirstWiseCUDAFunctor<T>>(
          dev_ctx, out, func);
    }
  } else if (mode == "element") {
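    // One alpha weight per element of a single sample (spatial_size =
    // numel / batch); every sample in the batch shares the same weights.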
    size_t spatial_size = numel / dim[0];
    auto func =
        PreluElementWiseDirectCUDAFunctor<T>(x_ptr, alpha_ptr, spatial_size);
    phi::IndexKernel<T, PreluElementWiseDirectCUDAFunctor<T>>(
        dev_ctx, out, func);
  } else {
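    // Scalar mode ("all"): a single alpha value is broadcast over every
    // element, so the generic elementwise kernel suffices.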
    std::vector<const DenseTensor*> ins = {&x};
    std::vector<DenseTensor*> outs = {out};
    auto func = PreluScalarDirectCUDAFunctor<T>(alpha_ptr);
    phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, func);
  }
}

}  // namespace phi

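// Register the GPU kernel for float, float16 and double element types.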
PD_REGISTER_KERNEL(prelu,
                   GPU,
                   ALL_LAYOUT,
                   phi::PReluKernel,
                   float,
                   phi::dtype::float16,
                   double) {}