/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * Max-unpooling forward kernel (scatter).
 *
 * Each loop iteration handles one flat element i of the input. It
 * decomposes i into a batch index and a channel index. It then writes
 * input_data[i] into output_data at position indices_data[i] within that
 * (batch, channel) plane of size output_height * output_width.
 * Output positions that no index refers to are not written by this kernel.
 *
 * Launch layout: 1-D grid of 1-D blocks. The loop advances by
 * blockDim.x * gridDim.x, so every i in [0, nthreads) is visited for any
 * grid size.
 *
 * nthreads      number of input elements to process; the caller in this file
 *               passes batch * channels * input_height * input_width.
 * input_data    input values, laid out batch-major, then channel, then plane.
 * indices_data  for each input element, the flat position inside the
 *               output_height x output_width plane; stored with element
 *               type T and converted to int here.
 * channels      channel count, used for both input and output strides.
 * output_data   destination; only indexed positions are written.
 */
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
                                  const T* input_data,
                                  const T* indices_data,
                                  const int input_height,
                                  const int input_width,
                                  const int channels,
                                  T* output_data,
                                  const int output_height,
                                  const int output_width) {
  const int csize = input_height * input_width;
  const int bsize = csize * channels;
  const int out_csize = output_height * output_width;
  const int out_bsize = out_csize * channels;

  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    const int bidx = i / bsize;
    const int cidx = (i % bsize) / csize;
    const int out_offset = bidx * out_bsize + cidx * out_csize;
    // Indices are stored with the tensor's element type; convert explicitly.
    const int out_index = static_cast<int>(indices_data[i]);
    // Both bounds matter: a negative or too-large index would write outside
    // the (bidx, cidx) output plane.
    PADDLE_ASSERT(out_index >= 0 && out_index < out_csize);
    output_data[out_offset + out_index] = input_data[i];
  }
}
/*
 * Max-unpooling backward kernel (gather).
 *
 * Each loop iteration handles one flat element i of the input gradient. It
 * decomposes i into a batch index and a channel index. It then copies
 * output_grad at position indices_data[i] of that (batch, channel) plane
 * (size output_height * output_width) into input_grad[i].
 * Every input_grad[i] with i in [0, nthreads) is written.
 *
 * Launch layout: 1-D grid of 1-D blocks. The loop advances by
 * blockDim.x * gridDim.x, so any grid size covers all nthreads elements.
 *
 * input_data and output_data are accepted but not read by this kernel.
 * indices_data holds flat plane positions stored with element type T.
 */
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                      const T* input_data,
                                      const T* indices_data,
                                      const int input_height,
                                      const int input_width,
                                      const int channels,
                                      const T* output_data,
                                      const T* output_grad,
                                      const int output_height,
                                      const int output_width,
                                      T* input_grad) {
  const int csize = input_height * input_width;
  const int bsize = csize * channels;
  const int out_csize = output_height * output_width;
  const int out_bsize = out_csize * channels;

  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    const int bidx = i / bsize;
    const int cidx = (i % bsize) / csize;
    const int out_offset = bidx * out_bsize + cidx * out_csize;
    // Indices are stored with the tensor's element type; convert explicitly.
    const int out_index = static_cast<int>(indices_data[i]);
    // Both bounds matter: a negative or too-large index would read outside
    // the (bidx, cidx) output-gradient plane.
    PADDLE_ASSERT(out_index >= 0 && out_index < out_csize);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}
/*
 * All tensors are in NCHW format.
 *
 * GPU max-unpooling forward pass: allocates `output` on the context's place.
 * It then launches KernelUnpool2dMax on the context's CUDA stream, which
 * scatters every element of `input` to the position recorded in `indices`.
 * Output positions not referenced by any index are not written here.
 * NOTE(review): presumably the calling op zero-fills `output`; confirm.
 * The launch is asynchronous with respect to the host.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor * output) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const T* indices_data = indices.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * input_height * input_width;
    // An empty tensor would produce a zero-sized grid. That is an invalid
    // launch configuration whose error surfaces at a later, unrelated CUDA
    // call, so skip the launch entirely.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_height, output_width);
  }
};
/*
 * All tensors are in NCHW format.
 *
 * GPU max-unpooling backward pass: allocates `input_grad` on the context's
 * place. It then launches KernelUnpool2dMaxGrad on the context's CUDA
 * stream, which gathers, for every input element, the `output_grad` value
 * at the position recorded in `indices`.
 * `input` and `output` supply shapes and are forwarded to the kernel. Their
 * data pointers are not read by the kernel.
 * The launch is asynchronous with respect to the host.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor * input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const T* indices_data = indices.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    const int nthreads =
        batch_size * output_channels * input_height * input_width;
    // An empty tensor would produce a zero-sized grid. That is an invalid
    // launch configuration whose error surfaces at a later, unrelated CUDA
    // call, so skip the launch entirely.
    if (nthreads <= 0) return;

    const int kThreadsPerBlock = 1024;
    const int blocks = (nthreads + kThreadsPerBlock - 1) / kThreadsPerBlock;
    dim3 threads(kThreadsPerBlock, 1);
    dim3 grid(blocks, 1);

    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data,
                              output_height, output_width,
                              input_grad_data);
  }
};

// Explicit instantiations of the GPU unpooling functors for the floating-point
// element types compiled into this translation unit.
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle