/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include "paddle/fluid/operators/math/maxouting.h"
#include "paddle/fluid/platform/cuda_helper.h"
W
wanghaox 已提交
17 18 19 20 21

namespace paddle {
namespace operators {
namespace math {

/*
 * Forward maxout kernel.
 *
 * Each output element i is the maximum over `groups` input channels at the
 * same spatial position: output channel c at (h, w) reads input channels
 * c * groups + g for g in [0, groups). Tensors are laid out as NCHW.
 *
 * Launch: 1-D grid. The loop advances by blockDim.x * gridDim.x, so any grid
 * size covers all `nthreads` outputs.
 *
 * nthreads: number of output elements.
 * channels: number of input channels.
 * groups:   number of input channels reduced into one output channel. Must be
 *           >= 1 (it is used as a divisor). Presumably it also divides
 *           `channels` — TODO confirm at the call site.
 */
template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, int groups,
                             T* output_data) {
  // Elements per sample in the output, and elements per (H, W) plane.
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    // Offset of group 0 of this output position in the input tensor.
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    // Seed from the first group instead of -FLT_MAX. -FLT_MAX is not the
    // lowest representable double, and the result must equal one of the
    // inputs so the backward kernel can locate the argmax by equality.
    T ele = input_data[data_idx];
    for (int g = 1; g < groups; ++g) {
      T x = input_data[data_idx + g * feat_len];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
  }
}
/*
 * Backward maxout kernel.
 *
 * For each output element i, finds the first group g (lowest index on ties)
 * whose input value equals output_data[i]. It then adds output_grad[i] into
 * input_grad at that input position. If no input matches, the gradient is
 * dropped. Tensors are laid out as NCHW, using the same index mapping as the
 * forward kernel.
 *
 * Launch: 1-D grid. The loop advances by blockDim.x * gridDim.x, so any grid
 * size covers all `nthreads` outputs.
 *
 * input_grad is accumulated into (+=), not overwritten.
 * NOTE(review): presumably the caller zero-fills it first — confirm.
 * Each input element is assumed to map to a single output element (groups
 * divides channels), so no two iterations update the same slot — confirm.
 */
template <typename T>
__global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
                                 const T* output_data, const T* output_grad,
                                 T* input_grad, const int channels,
                                 const int input_height, const int input_width,
                                 int groups) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    const T out_val = output_data[i];
    int max_index = -1;
    for (int g = 0; g < groups; ++g) {
      if (input_data[data_idx + g * feat_len] == out_val) {
        max_index = data_idx + g * feat_len;
        break;
      }
    }
    if (max_index != -1) {
      // Index by the loop variable i. The thread's starting index would reuse
      // the first element's gradient on every grid-stride iteration.
      input_grad[max_index] += output_grad[i];
    }
  }
}
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutFunctor<platform::CUDADeviceContext, T> {
 public:
  // Computes the maxout of `input` into `output` on `context`'s stream.
  // Each output channel is the element-wise max over `groups` input channels.
  // `output`'s dims must already be set; its numel() sizes the launch.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = output->numel();
    // Nothing to compute. A launch with zero blocks would also be an invalid
    // configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, input_channels, input_height, input_width, groups,
        output_data);
  }
};
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Routes each element of `output_grad` to the input position that produced
  // the corresponding forward maximum. Runs on `context`'s stream.
  // The kernel accumulates (+=) into `input_grad`.
  // NOTE(review): presumably the caller zero-fills input_grad before this call
  // — confirm.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = output.numel();
    // Nothing to propagate. A launch with zero blocks would also be an invalid
    // configuration.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, groups);
  }
};

// Explicit instantiations for the CUDA device context: forward functor first,
// then its gradient, each for float and double.
template class MaxOutFunctor<platform::CUDADeviceContext, float>;
template class MaxOutFunctor<platform::CUDADeviceContext, double>;

template class MaxOutGradFunctor<platform::CUDADeviceContext, float>;
template class MaxOutGradFunctor<platform::CUDADeviceContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle