// prelu_op.cu: CUDA kernels for the prelu and prelu_grad operators.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/operators/prelu_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

// Short alias for the framework tensor type used throughout this file.
typedef framework::Tensor Tensor;

#define CUDA_NUM_THREADS 1024

// Number of CUDA_NUM_THREADS-wide blocks needed so the launched grid has at
// least N threads, i.e. ceil(N / CUDA_NUM_THREADS) for non-negative N.
inline static int PADDLE_GET_BLOCKS(const int N) {
  const int rounded_up = N + (CUDA_NUM_THREADS - 1);
  return rounded_up / CUDA_NUM_THREADS;
}

// Forward PReLU kernel for CUDA.
//
// Reads X and Alpha, allocates Out on the current place, and dispatches on the
// "mode" attribute to one of the math::Prelu*DirectCUDAFunctor launchers:
//   "channel" - channel-wise alpha; the channel axis is the last dim when
//               data_format == "NHWC", otherwise dim 1.
//   "element" - element-wise alpha; the functor receives dim[0] as the batch
//               size.
//   otherwise - a single scalar alpha.
template <typename DeviceContext, typename T>
class CUDAPReluKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* alpha = context.Input<Tensor>("Alpha");
    auto* out = context.Output<Tensor>("Out");

    const T* x_ptr = x->data<T>();
    T* o_ptr = out->mutable_data<T>(context.GetPlace());
    const T* alpha_ptr = alpha->data<T>();

    auto& mode = context.Attr<std::string>("mode");
    auto& data_format = context.Attr<std::string>("data_format");

    // numel() is 64-bit; keep it that way instead of narrowing to int, which
    // silently truncates the count for tensors with more than INT_MAX elements.
    const int64_t numel = x->numel();
    auto dim = x->dims();
    const int x_rank = dim.size();

    // Only index dims that exist for this rank.
    if (x_rank >= 2) {
      VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
              << x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;
    } else {
      VLOG(4) << "rank:" << x_rank << ", numel:" << numel;
    }

    // Nothing to compute for an empty input. Returning here avoids handing a
    // zero-sized problem (and possibly a zero dim[0] batch size) to the
    // functors below.
    if (numel == 0) return;

    auto stream = context.cuda_device_context().stream();
    if (mode == "channel") {
      const bool channel_last = data_format == "NHWC";
      const size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
      math::PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
      prelu_channel_wise(stream, x_ptr, alpha_ptr, o_ptr, dim[0], channel,
                         channel_last, numel);
    } else if (mode == "element") {
      math::PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
      prelu_element_wise(stream, x_ptr, alpha_ptr, o_ptr, dim[0], numel);
    } else {
      math::PreluScalarDirectCUDAFunctor<T> prelu_scalar;
      prelu_scalar(stream, x_ptr, alpha_ptr, o_ptr, numel);
    }
  }
};

// Selects how PReluOpGradKernel maps an input element index to its alpha.
enum PRELU_MODE {
  Element = 0,       // alpha[index % spatial_size]
  ChannelFirst = 1,  // alpha[(index / plane_size) % channel], channel at dim 1
  ChannelLast = 2,   // alpha[index % channel], channel is the last dim
  Scalar = 3         // alpha[0] for every element
};

// Per-element PReLU backward kernel.
//
// For each of the `numel` input elements (iterated via CUDA_KERNEL_LOOP), it
// picks the alpha value that applies to the element according to `mode` and
// writes:
//   dx[i]     = x[i] > 0 ? dy[i] : alpha * dy[i]   (skipped if dx_ptr is null)
//   dalpha[i] = x[i] > 0 ? 0     : x[i] * dy[i]    (skipped if dalpha_ptr is
//                                                    null)
// dalpha_ptr holds one value per input element, not per alpha element.
template <typename T>
__global__ void PReluOpGradKernel(const T* x_ptr, const T* alpha_ptr,
                                  const T* dy_ptr, T* dx_ptr, T* dalpha_ptr,
                                  size_t channel_num, size_t plane_size,
                                  size_t spatial_size, size_t numel,
                                  PRELU_MODE mode) {
  CUDA_KERNEL_LOOP(index, numel) {
    size_t alpha_index;
    switch (mode) {
      case Element:
        alpha_index = index % spatial_size;
        break;
      case ChannelFirst:
        alpha_index = (index / plane_size) % channel_num;
        break;
      case ChannelLast:
        alpha_index = index % channel_num;
        break;
      default:
        alpha_index = 0;
        break;
    }
    const T scale = alpha_ptr[alpha_index];
    const T x = x_ptr[index];
    const T dy = dy_ptr[index];
    const T zero = static_cast<T>(0);
    const bool positive = x > zero;
    if (dx_ptr != nullptr) {
      dx_ptr[index] = positive ? dy : scale * dy;
    }
    if (dalpha_ptr != nullptr) {
      dalpha_ptr[index] = positive ? zero : x * dy;
    }
  }
}

// Host-side launcher for PReluOpGradKernel.
//
// It derives the kernel's indexing parameters from `input_dims`:
//   spatial_size - elements per batch item (numel / dims[0]); used by Element.
//   plane_size   - elements per channel plane (numel / dims[0] / dims[1]);
//                  used by ChannelFirst.
//   channel      - channel count: the last dim for ChannelLast, otherwise
//                  dims[1]; used by ChannelFirst and ChannelLast.
// dx or dalpha may be null, in which case the kernel skips that gradient.
// The launch is enqueued on `stream` and is not synchronized here.
template <typename T>
class PreluOpGradFunctor {
 public:
  void operator()(gpuStream_t stream, const T* x, const T* alpha, const T* dy,
                  T* dx, T* dalpha, const framework::DDim& input_dims,
                  PRELU_MODE mode) {
    const int rank = input_dims.size();
    size_t numel = 1;
    for (int i = 0; i < rank; ++i) {
      numel *= input_dims[i];
    }
    // An empty input has no gradient to compute. Returning early also avoids
    // dividing by a zero dims[0]/dims[1] below and launching a 0-block grid,
    // which CUDA rejects as an invalid launch configuration.
    if (numel == 0) return;

    // Guard every dims[] read by the rank instead of indexing past it.
    const size_t batch = rank >= 1 ? static_cast<size_t>(input_dims[0]) : 1;
    const size_t spatial_size = numel / batch;
    const size_t plane_size =
        rank >= 2 ? spatial_size / static_cast<size_t>(input_dims[1]) : 1;

    size_t channel = 1;  // not read by the Element and Scalar modes
    if (mode == ChannelLast) {
      if (rank >= 1) channel = input_dims[rank - 1];
    } else if (rank >= 2) {
      channel = input_dims[1];
    }

    PReluOpGradKernel<T><<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0,
                           stream>>>(x, alpha, dy, dx, dalpha, channel,
                                     plane_size, spatial_size, numel, mode);
  }
};

// Backward PReLU kernel for CUDA.
//
// It computes X@GRAD and/or Alpha@GRAD. Each output is optional, and the
// kernel returns immediately if neither is requested. The alpha gradient is
// first produced per input element in a temporary tensor shaped like X. It is
// then summed with TensorReduceImpl over every axis that Alpha does not keep:
//   "channel" - all axes except the channel axis (dim 1, or the last dim when
//               data_format == "NHWC");
//   "element" - only the batch axis (dim 0);
//   otherwise - all axes (scalar alpha).
template <typename DeviceContext, typename T>
class CUDAPReluGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* alpha = context.Input<Tensor>("Alpha");
    auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* dalpha = context.Output<Tensor>(framework::GradVarName("Alpha"));

    const T* x_ptr = x->data<T>();
    const T* alpha_ptr = alpha->data<T>();
    const T* dy_ptr = dy->data<T>();
    T* dx_ptr = dx ? dx->mutable_data<T>(context.GetPlace()) : nullptr;
    T* dalpha_ptr =
        dalpha ? dalpha->mutable_data<T>(context.GetPlace()) : nullptr;

    if (!dx && !dalpha) return;

    auto& mode = context.Attr<std::string>("mode");
    auto& data_format = context.Attr<std::string>("data_format");

    auto dim = x->dims();
    const int x_rank = dim.size();
    auto stream = context.cuda_device_context().stream();

    // The kernel writes one alpha-gradient value per input element, so stage
    // them in a temporary tensor with X's dims and reduce into dalpha below.
    T* dalpha_tmp_ptr = nullptr;
    Tensor dalpha_tmp;
    if (dalpha_ptr != nullptr) {
      auto& dev_ctx = context.template device_context<DeviceContext>();
      dalpha_tmp = context.AllocateTmpTensor<T, DeviceContext>(dim, dev_ctx);
      dalpha_tmp_ptr = dalpha_tmp.mutable_data<T>(context.GetPlace());
    }

    PRELU_MODE m;
    if (mode == "element") {
      m = Element;
    } else if (mode == "channel") {
      m = data_format == "NHWC" ? ChannelLast : ChannelFirst;
    } else {
      m = Scalar;
    }

    PreluOpGradFunctor<T> prelu_grad;
    prelu_grad(stream, x_ptr, alpha_ptr, dy_ptr, dx_ptr, dalpha_tmp_ptr, dim,
               m);

    if (dalpha_tmp_ptr == nullptr) return;

    // Reduce over every axis Alpha does not keep (see the class comment).
    std::vector<int> reduce_dims;
    for (int i = 0; i < x_rank; ++i) {
      if (m == ChannelFirst && i == 1) continue;
      if (m == ChannelLast && i == x_rank - 1) continue;
      if (m == Element && i != 0) continue;
      reduce_dims.push_back(i);
    }

    TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
        context.cuda_device_context(), dalpha_tmp, dalpha,
        kps::IdentityFunctor<T>(), reduce_dims, stream);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// CUDA kernels for the prelu forward and backward ops, one instantiation per
// supported element type (float, float16, double).
REGISTER_OP_CUDA_KERNEL(
    prelu, ops::CUDAPReluKernel<plat::CUDADeviceContext, float>,
    ops::CUDAPReluKernel<plat::CUDADeviceContext, plat::float16>,
    ops::CUDAPReluKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    prelu_grad, ops::CUDAPReluGradKernel<plat::CUDADeviceContext, float>,
    ops::CUDAPReluGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::CUDAPReluGradKernel<plat::CUDADeviceContext, double>);