unpooling.cu 6.3 KB
Newer Older
S
sweetsky0901 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

S
sweetsky0901 已提交
15
#include "paddle/operators/math/unpooling.h"
S
sweetsky0901 已提交
16 17 18 19 20 21
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

S
sweetsky0901 已提交
22
/*
 * Max-unpooling forward kernel: scatters every element of `input_data` into
 * `output_data` at the in-plane position recorded in `indices_data`.
 *
 * Layout (implied by the stride arithmetic below): `input_data` and
 * `indices_data` are contiguous NCHW buffers of `nthreads` elements with
 * planes of input_height * input_width; `output_data` is contiguous NCHW with
 * planes of output_height * output_width and the same `channels`.
 * indices_data[i] is an offset inside the output plane of the same
 * (batch, channel) as input element i, and must lie in
 * [0, output_height * output_width).
 *
 * Launch: 1-D grid of 1-D blocks. Each thread starts at its global thread id
 * and advances by the total number of launched threads, so any grid size
 * covers all `nthreads` elements.
 *
 * Output positions not referenced by any index are not written here; the
 * caller presumably zero-fills `output_data` first (TODO confirm). If two
 * input elements carry the same index, they race on the same output slot and
 * which value is stored is unspecified.
 */
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
                                  const T* input_data,
                                  const int* indices_data,
                                  const int input_height,
                                  const int input_width,
                                  const int channels,
                                  T* output_data,
                                  const int output_height,
                                  const int output_width) {
  // Element counts of one sample / one channel plane, input and output side.
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Recover the (batch, channel) pair that flat input element i belongs to.
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    // First element of the matching (batch, channel) plane in the output.
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    // The index must stay inside its own plane. A negative value would write
    // into the preceding plane (or before the buffer), so check both ends.
    PADDLE_ASSERT(out_index >= 0 && out_index < out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}
S
sweetsky0901 已提交
48
/*
 * Max-unpooling backward kernel: for every input element i, gathers the
 * output gradient at the in-plane position indices_data[i] into input_grad[i].
 *
 * Layout matches KernelUnpool2dMax:
 * - `indices_data` and `input_grad` are contiguous NCHW buffers of `nthreads`
 *   elements with planes of input_height * input_width.
 * - `output_grad` is contiguous NCHW with planes of
 *   output_height * output_width.
 * - indices_data[i] must lie in [0, output_height * output_width).
 *
 * `input_data` and `output_data` are accepted but not read by this kernel.
 *
 * Launch: 1-D grid of 1-D blocks. Each thread starts at its global thread id
 * and advances by the total number of launched threads. Every input_grad[i]
 * for i < nthreads is written exactly once, so no pre-zeroing is required.
 */
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                      const T* input_data,
                                      const int* indices_data,
                                      const int input_height,
                                      const int input_width,
                                      const int channels,
                                      const T* output_data,
                                      const T* output_grad,
                                      const int output_height,
                                      const int output_width,
                                      T* input_grad) {
  // Element counts of one sample / one channel plane, input and output side.
  int in_n_stride = input_height * input_width * channels;
  int in_c_stride = input_height * input_width;
  int out_n_stride = output_height * output_width * channels;
  int out_c_stride = output_height * output_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Recover the (batch, channel) pair that flat input element i belongs to.
    int bidx = i / in_n_stride;
    int boffset = i % in_n_stride;
    int cidx = boffset / in_c_stride;
    // First element of the matching (batch, channel) plane in output_grad.
    int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    int out_index = indices_data[i];
    // Reject indices outside the plane on both ends. A negative value would
    // read from the preceding plane (or before the buffer).
    PADDLE_ASSERT(out_index >= 0 && out_index < out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  /*
   * Max-unpooling forward on the GPU: scatters `input` into `output` at the
   * per-plane positions given by `indices` (int32).
   *
   * - input:   [N, C, H_in, W_in]
   * - indices: same element count as `input`; one in-plane offset per element
   * - output:  [N, C, H_out, W_out]; allocated on context.GetPlace()
   *
   * Positions of `output` not referenced by `indices` are not written here.
   * The caller presumably zero-fills `output` beforehand (TODO confirm).
   *
   * The kernel is enqueued on the context's CUDA stream and is asynchronous
   * with respect to the host.
   */
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor* output) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    // Nothing to scatter. Launching with a zero-sized grid is an invalid
    // configuration, so return after the output has been allocated.
    if (input.numel() == 0) {
      return;
    }
    const int threads = 1024;
    const int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMax<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_height, output_width);
  }
};
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  /*
   * Max-unpooling backward on the GPU: input_grad[i] receives the element of
   * `output_grad` at the in-plane position indices[i] of the same
   * (batch, channel).
   *
   * - input / input_grad:   [N, C, H_in, W_in]; input_grad is allocated on
   *                         context.GetPlace()
   * - indices:              int32, same element count as `input`
   * - output / output_grad: [N, C, H_out, W_out]
   *
   * `input` and `output` only supply shapes and are forwarded to the kernel.
   * The kernel itself reads neither of their data buffers.
   *
   * The kernel is enqueued on the context's CUDA stream and is asynchronous
   * with respect to the host.
   */
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    // Nothing to gather. Launching with a zero-sized grid is an invalid
    // configuration, so return after input_grad has been allocated.
    if (input.numel() == 0) {
      return;
    }
    const int threads = 1024;
    const int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool2dMaxGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(input.numel(), input_data, indices_data,
                              input_height, input_width, output_channels,
                              output_data, output_grad_data,
                              output_height, output_width,
                              input_grad_data);
  }
};

S
sweetsky0901 已提交
141 142
// Explicit instantiations of the GPU specializations defined above, so that
// other translation units can link against them: forward functor first, then
// the gradient functor, each for float and double.
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;

template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
S
sweetsky0901 已提交
146 147 148 149

}  // namespace math
}  // namespace operators
}  // namespace paddle