/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/unpooling.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {
/*
 * Scatter kernel for 2-D max unpooling, NCHW layout.
 *
 * Input element i (flattened over N, C, H_in, W_in) belongs to image n and
 * channel c. It is written into output plane (n, c) at the flat spatial
 * position indices_data[i]:
 *   output_data[n * C * H_out * W_out + c * H_out * W_out + indices_data[i]]
 *       = input_data[i]
 * Output positions that no index names are not written. The caller must
 * therefore initialise output_data beforehand (presumably zero-filled; confirm
 * at the call site).
 *
 * Launch: 1-D grid of 1-D blocks. The loop strides by blockDim.x * gridDim.x,
 * so any grid size covers all nthreads elements.
 *
 * nthreads      total number of input elements (N * C * H_in * W_in).
 * indices_data  one flat plane index per input element; each must lie in
 *               [0, output_height * output_width).
 * channels      channel count, shared by input and output.
 */
template <typename T>
__global__ void KernelUnpool2dMax(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_height, const int input_width, const int channels,
    T* output_data, const int output_height, const int output_width) {
  // Element counts of one image (n stride) and one channel plane (c stride).
  const int in_n_stride = input_height * input_width * channels;
  const int in_c_stride = input_height * input_width;
  const int out_n_stride = output_height * output_width * channels;
  const int out_c_stride = output_height * output_width;
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    const int bidx = i / in_n_stride;
    const int cidx = (i % in_n_stride) / in_c_stride;
    const int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    const int out_index = indices_data[i];
    // Check both bounds: a negative index would otherwise write into the
    // previous channel plane (or before the start of output_data).
    PADDLE_ASSERT(out_index >= 0 && out_index < out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}
/*
 * Gradient of the 2-D max-unpooling scatter, NCHW layout.
 *
 * For every input element i (flattened over N, C, H_in, W_in) in image n and
 * channel c, this kernel gathers the output gradient from the position
 * indices_data[i] of output plane (n, c):
 *   input_grad[i] =
 *       output_grad[n * C * H_out * W_out + c * H_out * W_out + indices_data[i]]
 * Each input_grad element in [0, nthreads) is written exactly once.
 *
 * input_data and output_data are not read by this kernel; they are part of the
 * launch signature only.
 *
 * Launch: 1-D grid of 1-D blocks. The loop strides by blockDim.x * gridDim.x,
 * so any grid size covers all nthreads elements.
 *
 * nthreads      total number of input elements (N * C * H_in * W_in).
 * indices_data  one flat plane index per input element; each must lie in
 *               [0, output_height * output_width).
 */
template <typename T>
__global__ void KernelUnpool2dMaxGrad(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_height, const int input_width, const int channels,
    const T* output_data, const T* output_grad, const int output_height,
    const int output_width, T* input_grad) {
  // Element counts of one image (n stride) and one channel plane (c stride).
  const int in_n_stride = input_height * input_width * channels;
  const int in_c_stride = input_height * input_width;
  const int out_n_stride = output_height * output_width * channels;
  const int out_c_stride = output_height * output_width;
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    const int bidx = i / in_n_stride;
    const int cidx = (i % in_n_stride) / in_c_stride;
    const int out_offset = bidx * out_n_stride + cidx * out_c_stride;
    const int out_index = indices_data[i];
    // Check both bounds: a negative index would otherwise read from the
    // previous channel plane (or before the start of output_grad).
    PADDLE_ASSERT(out_index >= 0 && out_index < out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}
/*
 * All tensors are in NCHW format.
 *
 * GPU max unpooling: scatters every element of `input` into `output` at the
 * per-(n, c)-plane position named by the matching element of `indices`.
 *
 * context  must actually be a platform::CUDADeviceContext; the kernel is
 *          enqueued on its stream.
 * input    (N, C, H_in, W_in) values to scatter.
 * indices  int tensor with one entry per input element; each is a flat
 *          position in [0, H_out * W_out).
 * output   (N, C, H_out, W_out); its memory is obtained with mutable_data on
 *          context's place. Positions not referenced by indices are not
 *          written, so the caller is expected to have initialised output
 *          (presumably zero-filled; confirm at the call site).
 */
template <typename T>
class Unpool2dMaxFunctor<platform::GPUPlace, T> {
 public:
  void operator()(
    const platform::DeviceContext& context, const framework::Tensor& input,
    const framework::Tensor& indices, framework::Tensor* output) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    T* output_data = output->mutable_data<T>(context.GetPlace());
    // The kernel takes an int element count; make the narrowing explicit.
    const int nthreads = static_cast<int>(input.numel());
    // A zero-block grid is an invalid launch configuration; nothing to scatter.
    if (nthreads == 0) return;
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const int threads = 1024;
    const int grid = (nthreads + threads - 1) / threads;
    // static_cast is the valid base-to-derived reference conversion;
    // reinterpret_cast would only work by accident of object layout.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);
    KernelUnpool2dMax<T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, indices_data, input_height, input_width,
        output_channels, output_data, output_height, output_width);
  }
};
/*
 * All tensors are in NCHW format.
 *
 * GPU gradient of max unpooling: for every input element, copies the output
 * gradient found at the position named by the matching element of `indices`
 * into `input_grad`.
 *
 * context      must actually be a platform::CUDADeviceContext; the kernel is
 *              enqueued on its stream.
 * input        (N, C, H_in, W_in) forward input; supplies the spatial dims and
 *              element count.
 * indices      int tensor with one entry per input element; each is a flat
 *              position in [0, H_out * W_out).
 * output       (N, C, H_out, W_out) forward output; supplies channel and
 *              output spatial dims.
 * output_grad  gradient w.r.t. output; indexed with output's layout.
 * input_grad   gradient w.r.t. input; its memory is obtained with mutable_data
 *              on context's place, and every element is written.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    // The kernel takes an int element count; make the narrowing explicit.
    const int nthreads = static_cast<int>(input.numel());
    // A zero-block grid is an invalid launch configuration; nothing to gather.
    if (nthreads == 0) return;
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    const int threads = 1024;
    const int grid = (nthreads + threads - 1) / threads;
    // static_cast is the valid base-to-derived reference conversion;
    // reinterpret_cast would only work by accident of object layout.
    const auto& cuda_context =
        static_cast<const platform::CUDADeviceContext&>(context);
    KernelUnpool2dMaxGrad<T><<<grid, threads, 0, cuda_context.stream()>>>(
        nthreads, input_data, indices_data, input_height, input_width,
        output_channels, output_data, output_grad_data, output_height,
        output_width, input_grad_data);
  }
};
// Explicitly instantiate the GPU forward and backward functors for the
// floating-point element types this file supports.
template class Unpool2dMaxFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxFunctor<platform::GPUPlace, double>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, float>;
template class Unpool2dMaxGradFunctor<platform::GPUPlace, double>;
}  // namespace math
}  // namespace operators
}  // namespace paddle