/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/math/unpooling.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {
/*
 * Max-unpooling forward kernel: scatters every input element into the
 * (batch, channel) plane of the output it belongs to, at the flat in-plane
 * position given by indices_data.
 *
 * Launch: 1-D grid of 1-D blocks. The loop advances by the total number of
 * launched threads (blockDim.x * gridDim.x), so any grid size covers all
 * `nthreads` elements.
 *
 *   nthreads      number of input elements to process.
 *   input_data    input values, laid out NCHW-contiguous.
 *   indices_data  one entry per input element: offset inside the matching
 *                 output plane (presumably row * output_width + col).
 *   channels      channel count; not needed for the plane index computed
 *                 below, kept for interface compatibility.
 *   output_data   NCHW output. Only positions named by indices_data are
 *                 written; every other element is left untouched.
 *
 * An index outside [0, output_height * output_width) is reported through
 * PADDLE_ENFORCE instead of being written out of bounds.
 */
template <typename T>
__global__ void KernelUnpool2dMax(const int nthreads, const T* input_data,
                                  const int* indices_data,
                                  const int input_height, const int input_width,
                                  const int channels, T* output_data,
                                  const int output_height,
                                  const int output_width) {
  const int in_c_stride = input_height * input_width;
  const int out_c_stride = output_height * output_width;
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Planes are contiguous in NCHW, so i / (H_in * W_in) equals
    // batch * channels + channel: the flat index of this element's plane.
    const int out_offset = (i / in_c_stride) * out_c_stride;
    const int out_index = indices_data[i];
    // Check both bounds: a negative index would otherwise write before the
    // start of this plane. The operands are `int`, hence %d specifiers.
    PADDLE_ENFORCE(out_index >= 0 && out_index < out_c_stride,
                   "out_index must be in [0, out_c_stride). Got out_index = "
                   "%d, out_c_stride = %d. Please check input value.",
                   out_index, out_c_stride);
    output_data[out_offset + out_index] = input_data[i];
  }
}
/*
 * Max-unpooling backward kernel: for every input element, gathers the
 * gradient stored in output_grad at the position its index points to and
 * writes it to input_grad at the same flat position as the input element.
 *
 * Launch: 1-D grid of 1-D blocks. The loop advances by the total number of
 * launched threads (blockDim.x * gridDim.x), so any grid size covers all
 * `nthreads` elements.
 *
 *   nthreads      number of input elements to process.
 *   input_data    forward input; not read by this kernel.
 *   indices_data  one entry per input element: offset inside the matching
 *                 output plane.
 *   channels      channel count; not needed for the plane index computed
 *                 below, kept for interface compatibility.
 *   output_data   forward output; not read by this kernel.
 *   output_grad   gradient w.r.t. the output, NCHW-contiguous.
 *   input_grad    gradient w.r.t. the input; element i is written for every
 *                 i < nthreads.
 *
 * An index outside [0, output_height * output_width) is reported through
 * PADDLE_ENFORCE instead of being read out of bounds.
 */
template <typename T>
__global__ void KernelUnpool2dMaxGrad(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_height, const int input_width, const int channels,
    const T* output_data, const T* output_grad, const int output_height,
    const int output_width, T* input_grad) {
  const int in_c_stride = input_height * input_width;
  const int out_c_stride = output_height * output_width;
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  const int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    // Planes are contiguous in NCHW, so i / (H_in * W_in) equals
    // batch * channels + channel: the flat index of this element's plane.
    const int out_offset = (i / in_c_stride) * out_c_stride;
    const int out_index = indices_data[i];
    // Check both bounds: a negative index would otherwise read before the
    // start of this plane. The operands are `int`, hence %d specifiers.
    PADDLE_ENFORCE(out_index >= 0 && out_index < out_c_stride,
                   "out_index must be in [0, out_c_stride). Got out_index = "
                   "%d, out_c_stride = %d. Please check input value.",
                   out_index, out_c_stride);
    input_grad[i] = output_grad[out_offset + out_index];
  }
}
/*
 * All tensors are in NCHW format.
 *
 * Max-unpooling forward pass on a CUDA device. Each element of `input` is
 * copied into `output` at the in-plane position given by the matching
 * element of `indices`.
 *
 *   input    N x C x H_in x W_in values to scatter.
 *   indices  int tensor indexed element-for-element with `input`.
 *   output   N x C x H_out x W_out. Its dims must already be set, because
 *            C, H_out and W_out are read from output->dims() before the
 *            buffer is allocated. Input and output are assumed to share
 *            N and C (TODO confirm at call site).
 *
 * The kernel is enqueued on context.stream() without synchronizing.
 * Positions of `output` not referenced by `indices` are NOT written here.
 * Presumably the caller zero-fills `output` first (TODO confirm).
 */
template <typename T>
class Unpool2dMaxFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    // One kernel thread per input element. The count is narrowed to the
    // kernel's `int` parameter; assumes it fits in int (TODO confirm).
    const auto numel = input.numel();
    // A zero-sized grid is an invalid launch configuration, so nothing is
    // launched for empty input (output has still been allocated above).
    if (numel <= 0) return;
    const int threads = 1024;
    const int grid = (numel + threads - 1) / threads;
    KernelUnpool2dMax<T><<<grid, threads, 0, context.stream()>>>(
        numel, input_data, indices_data, input_height, input_width,
        output_channels, output_data, output_height, output_width);
  }
};
/*
 * All tensors are in NCHW format.
 *
 * Max-unpooling backward pass on a CUDA device. For every element of
 * `input`, the gradient at the position named by the matching element of
 * `indices` is gathered from `output_grad` into `input_grad`.
 *
 *   input        forward input, N x C x H_in x W_in; supplies H_in, W_in and
 *                the element count the kernel is launched over.
 *   indices      int tensor indexed element-for-element with `input`.
 *   output       forward output; supplies C, H_out and W_out.
 *   output_grad  gradient w.r.t. `output`.
 *   input_grad   gradient w.r.t. `input`; assumed to have input's shape
 *                (TODO confirm), since input.numel() elements are produced.
 *
 * The kernel is enqueued on context.stream() without synchronizing.
 */
template <typename T>
class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    // One kernel thread per input element. The count is narrowed to the
    // kernel's `int` parameter; assumes it fits in int (TODO confirm).
    const auto numel = input.numel();
    // A zero-sized grid is an invalid launch configuration, so nothing is
    // launched for empty input (input_grad has still been allocated above).
    if (numel <= 0) return;
    const int threads = 1024;
    const int grid = (numel + threads - 1) / threads;
    KernelUnpool2dMaxGrad<T><<<grid, threads, 0, context.stream()>>>(
        numel, input_data, indices_data, input_height, input_width,
        output_channels, output_data, output_grad_data, output_height,
        output_width, input_grad_data);
  }
};
// Explicit instantiations of the CUDA unpooling functors, grouped by element
// type: forward and backward for float, then forward and backward for double.
template class Unpool2dMaxFunctor<platform::CUDADeviceContext, float>;
template class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, float>;
template class Unpool2dMaxFunctor<platform::CUDADeviceContext, double>;
template class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, double>;
}  // namespace math
}  // namespace operators
}  // namespace paddle