unpooling.cu 5.7 KB
Newer Older
S
sweetsky0901 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

S
sweetsky0901 已提交
15
#include "paddle/operators/math/unpooling.h"
S
sweetsky0901 已提交
16 17 18 19 20 21 22 23 24
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

/*
 * Max-unpooling forward kernel.
 *
 * Scatters each input element into the output at the position recorded in
 * `indices_data`. Index i walks the input elements in flat NCHW order;
 * i / (input_height * input_width) selects the (batch, channel) plane, and
 * indices_data[i] is an offset inside that plane's
 * (output_height x output_width) block of `output_data`.
 *
 * nthreads: number of input elements to process; must not exceed the length
 *           of input_data / indices_data.
 * Launch:   any 1-D grid. The loop strides by blockDim.x * gridDim.x, so every
 *           index in [0, nthreads) is handled regardless of grid size.
 * Output positions not named by any index are not written here; presumably
 * the caller zero-fills output_data beforehand — TODO confirm.
 */
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads,
                                  const T* input_data,
                                  const int* indices_data,
                                  const int input_height,
                                  const int input_width,
                                  T* output_data,
                                  const int output_height,
                                  const int output_width) {
  const int in_plane_size = input_height * input_width;
  const int out_plane_size = output_height * output_width;
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += stride) {
    const int out_offset = i / in_plane_size * out_plane_size;
    const int out_index = indices_data[i];
    // Reject negative as well as too-large indices: either would write
    // outside the current output plane.
    PADDLE_ASSERT(out_index >= 0 && out_index < out_plane_size);
    output_data[out_offset + out_index] = input_data[i];
  }
}
/*
 * Max-unpooling backward kernel.
 *
 * Gathers each input element's gradient from the output-gradient position
 * recorded in `indices_data`:
 *   input_grad[i] = output_grad[plane(i) * H_out * W_out + indices_data[i]]
 * with plane(i) = i / (input_height * input_width).
 *
 * nthreads:   number of input elements; must not exceed the length of
 *             indices_data / input_grad.
 * input_data, output_data: accepted for interface symmetry with the forward
 *             pass; this kernel does not read them.
 * Launch:     any 1-D grid. The loop strides by blockDim.x * gridDim.x, so
 *             every index in [0, nthreads) is handled regardless of grid size.
 */
template <typename T>
__global__ void KernelUnpool2dMaxGrad(const int nthreads,
                                      const T* input_data,
                                      const int* indices_data,
                                      const int input_height,
                                      const int input_width,
                                      const T* output_data,
                                      const T* output_grad,
                                      const int output_height,
                                      const int output_width,
                                      T* input_grad) {
  const int in_plane_size = input_height * input_width;
  const int out_plane_size = output_height * output_width;
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += stride) {
    const int out_offset = i / in_plane_size * out_plane_size;
    const int out_index = indices_data[i];
    // Reject negative as well as too-large indices: either would read
    // outside the current output-gradient plane.
    PADDLE_ASSERT(out_index >= 0 && out_index < out_plane_size);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}
/*
 * GPU specialization of the max-unpooling forward functor.
 * All tensors are in NCHW format.
 *
 * input:   pooled values, dims (N, C, H_in, W_in).
 * indices: int offsets into each (H_out x W_out) output plane, one per input
 *          element (same layout as input — assumed, TODO confirm).
 * output:  destination; its dims must already be set by the caller. Only the
 *          positions named by `indices` are written, so it is presumably
 *          zero-filled by the caller beforehand — TODO confirm.
 *
 * The kernel is enqueued on the context's CUDA stream; this call does not
 * synchronize.
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor * output) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // One logical thread per INPUT element: the kernel reads input_data[i]
    // and indices_data[i], so sizing the loop by output->numel() would read
    // past the end of both buffers whenever the output is larger than the
    // input.
    const int nthreads = input.numel();
    if (nthreads == 0) return;  // a zero-block launch is an invalid config
    const int threads_per_block = 1024;
    const int blocks = (nthreads + threads_per_block - 1) / threads_per_block;

    KernelUnpool2dMax<T><<<blocks, threads_per_block, 0,
        static_cast<const platform::CUDADeviceContext&>(context).stream()>>>(
        nthreads, input_data, indices_data, input_height, input_width,
        output_data, output_height, output_width);
  }
};
/*
 * GPU specialization of the max-unpooling backward functor.
 * All tensors are in NCHW format.
 *
 * input:       forward-pass input, dims (N, C, H_in, W_in).
 * indices:     int offsets into each (H_out x W_out) plane, one per input
 *              element (same layout as input — assumed, TODO confirm).
 * input_grad:  receives the input gradient; its first input.numel() elements
 *              are all written. Assumed already resized to input's dims by
 *              the caller — TODO confirm.
 * output:      forward-pass output; its dims are read here and its data is
 *              forwarded to the kernel, which does not read it.
 * output_grad: output gradient, indexed with the same (H_out x W_out) plane
 *              layout as output.
 *
 * The kernel is enqueued on the context's CUDA stream; this call does not
 * synchronize.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  framework::Tensor * input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One logical thread per INPUT element: the kernel reads indices_data[i]
    // and writes input_grad_data[i], so sizing the loop by output.numel()
    // would run past both buffers whenever the output is larger than the
    // input.
    const int nthreads = input.numel();
    if (nthreads == 0) return;  // a zero-block launch is an invalid config
    const int threads_per_block = 1024;
    const int blocks = (nthreads + threads_per_block - 1) / threads_per_block;

    KernelUnpool2dMaxGrad<T><<<blocks, threads_per_block, 0,
        static_cast<const platform::CUDADeviceContext&>(context).stream()>>>(
        nthreads, input_data, indices_data, input_height, input_width,
        output_data, output_grad_data, output_height, output_width,
        input_grad_data);
  }
};

// Explicit instantiations: emit the GPU forward and backward unpooling
// functors for the supported element types in this translation unit.
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle