pooling.cc 13.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/pooling.h"

namespace paddle {
namespace operators {
namespace math {

template <typename PoolProcess, typename T>
class Pool2dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
 public:
24 25
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
26
                  std::vector<int>& ksize, std::vector<int>& strides,
27
                  std::vector<int>& paddings, PoolProcess pool_process) {
28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];

    const int input_stride = input_height * input_width;
    const int output_stride = output_height * output_width;

    const T* input_data = input.data<T>();
C
chengduoZH 已提交
45
    T* output_data = output.mutable_data<T>(context.GetPlace());
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

    for (int i = 0; i < batch_size; i++) {
      for (int c = 0; c < output_channels; ++c) {
        for (int ph = 0; ph < output_height; ++ph) {
          int hstart = ph * stride_height - padding_height;
          int hend = std::min(hstart + ksize_height, input_height);
          hstart = std::max(hstart, 0);
          for (int pw = 0; pw < output_width; ++pw) {
            int wstart = pw * stride_width - padding_width;
            int wend = std::min(wstart + ksize_width, input_width);
            wstart = std::max(wstart, 0);
            T ele = pool_process.initial();
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                pool_process.process(ele, input_data[h * input_width + w]);
              }
            }
            int pool_size = (hend - hstart) * (wend - wstart);
            pool_process.finalize(ele, (static_cast<T>(pool_size)));
            output_data[ph * output_width + pw] = ele;
          }
        }
        input_data += input_stride;
        output_data += output_stride;
      }
    }
  }
};

template <typename PoolProcess, class T>
class Pool2dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
 public:
78 79
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
80 81 82
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
83
                  PoolProcess pool_process) {
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];
    const int ksize_height = ksize[0];
    const int ksize_width = ksize[1];
    const int stride_height = strides[0];
    const int stride_width = strides[1];
    const int padding_height = paddings[0];
    const int padding_width = paddings[1];
    const int input_stride = input_height * input_width;
    const int output_stride = output_height * output_width;

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
102
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
103 104 105 106 107 108 109 110 111 112 113 114

    for (int i = 0; i < batch_size; i++) {
      for (int c = 0; c < output_channels; ++c) {
        for (int ph = 0; ph < output_height; ++ph) {
          int hstart = ph * stride_height - padding_height;
          int hend = std::min(hstart + ksize_height, input_height);
          hstart = std::max(hstart, 0);
          for (int pw = 0; pw < output_width; ++pw) {
            int wstart = pw * stride_width - padding_width;
            int wend = std::min(wstart + ksize_width, input_width);
            wstart = std::max(wstart, 0);
            int pool_size = (hend - hstart) * (wend - wstart);
115
            float scale = 1.0 / pool_size;
116 117 118 119 120 121 122
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                pool_process.gradProcess(
                    input_data[h * input_width + w],
                    output_data[ph * output_width + pw],
                    output_grad_data[ph * output_width + pw],
                    input_grad_data[h * input_width + w],
123
                    static_cast<T>(scale));
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
              }
            }
          }
        }
        input_data += input_stride;
        output_data += output_stride;
        input_grad_data += input_stride;
        output_grad_data += output_stride;
      }
    }
  }
};

// Explicit instantiation definitions of the CPU 2-D pooling functors for every
// combination of {forward, backward} x {maxPool, avePool} x {float, double}.

// Forward pass.
template class Pool2dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool2dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;

// Backward pass.
template class Pool2dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool2dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;

template <typename PoolProcess, class T>
class Pool3dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
 public:
157 158
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& output,
159
                  std::vector<int>& ksize, std::vector<int>& strides,
160
                  std::vector<int>& paddings, PoolProcess pool_process) {
161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182
    const int batch_size = input.dims()[0];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];

    const int input_stride = input_depth * input_height * input_width;
    const int output_stride = output_depth * output_height * output_width;

    const T* input_data = input.data<T>();
C
chengduoZH 已提交
183
    T* output_data = output.mutable_data<T>(context.GetPlace());
184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226

    for (int i = 0; i < batch_size; i++) {
      for (int c = 0; c < output_channels; ++c) {
        for (int pd = 0; pd < output_depth; ++pd) {
          int dstart = pd * stride_depth - padding_depth;
          int dend = std::min(dstart + ksize_depth, input_depth);
          dstart = std::max(dstart, 0);
          for (int ph = 0; ph < output_height; ++ph) {
            int hstart = ph * stride_height - padding_height;
            int hend = std::min(hstart + ksize_height, input_height);
            hstart = std::max(hstart, 0);
            for (int pw = 0; pw < output_width; ++pw) {
              int wstart = pw * stride_width - padding_width;
              int wend = std::min(wstart + ksize_width, input_width);
              wstart = std::max(wstart, 0);
              int output_idx = (pd * output_height + ph) * output_width + pw;
              T ele = pool_process.initial();
              for (int d = dstart; d < dend; ++d) {
                for (int h = hstart; h < hend; ++h) {
                  for (int w = wstart; w < wend; ++w) {
                    pool_process.process(
                        ele,
                        input_data[(d * input_height + h) * input_width + w]);
                  }
                }
              }
              int pool_size =
                  (dend - dstart) * (hend - hstart) * (wend - wstart);
              pool_process.finalize(ele, static_cast<T>(pool_size));
              output_data[output_idx] = ele;
            }
          }
        }
        input_data += input_stride;
        output_data += output_stride;
      }
    }
  }
};

template <typename PoolProcess, class T>
class Pool3dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
 public:
227 228
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor& input_grad,
229 230 231
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, std::vector<int>& ksize,
                  std::vector<int>& strides, std::vector<int>& paddings,
232
                  PoolProcess pool_process) {
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255
    const int batch_size = input.dims()[0];
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const int ksize_depth = ksize[0];
    const int ksize_height = ksize[1];
    const int ksize_width = ksize[2];
    const int stride_depth = strides[0];
    const int stride_height = strides[1];
    const int stride_width = strides[2];
    const int padding_depth = paddings[0];
    const int padding_height = paddings[1];
    const int padding_width = paddings[2];
    const int input_stride = input_depth * input_height * input_width;
    const int output_stride = output_depth * output_height * output_width;

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
C
chengduoZH 已提交
256
    T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275

    for (int i = 0; i < batch_size; i++) {
      for (int c = 0; c < output_channels; ++c) {
        for (int pd = 0; pd < output_depth; ++pd) {
          int dstart = pd * stride_depth - padding_depth;
          int dend = std::min(dstart + ksize_depth, input_depth);
          dstart = std::max(dstart, 0);
          for (int ph = 0; ph < output_height; ++ph) {
            int hstart = ph * stride_height - padding_height;
            int hend = std::min(hstart + ksize_height, input_height);
            hstart = std::max(hstart, 0);

            for (int pw = 0; pw < output_width; ++pw) {
              int wstart = pw * stride_width - padding_width;
              int wend = std::min(wstart + ksize_width, input_width);
              wstart = std::max(wstart, 0);

              int pool_size =
                  (dend - dstart) * (hend - hstart) * (wend - wstart);
276
              float scale = 1.0 / pool_size;
277 278 279 280 281 282 283 284 285
              for (int d = dstart; d < dend; ++d) {
                for (int h = hstart; h < hend; ++h) {
                  for (int w = wstart; w < wend; ++w) {
                    int input_idx = (d * input_height + h) * input_width + w;
                    int output_idx =
                        (pd * output_height + ph) * output_width + pw;
                    pool_process.gradProcess(
                        input_data[input_idx], output_data[output_idx],
                        output_grad_data[output_idx],
286
                        input_grad_data[input_idx], static_cast<T>(scale));
287 288 289 290 291 292
                  }
                }
              }
            }
          }
        }
293 294 295 296
        input_data += input_stride;
        output_data += output_stride;
        input_grad_data += input_stride;
        output_grad_data += output_stride;
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320
      }
    }
  }
};

// Explicit instantiation definitions of the CPU 3-D pooling functors for every
// combination of {forward, backward} x {maxPool, avePool} x {float, double}.

// Forward pass.
template class Pool3dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool3dForwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;

// Backward pass.
template class Pool3dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<float>, float>;
template class Pool3dBackwardFunctor<
    platform::CPUPlace, paddle::operators::math::pool::avePool<double>, double>;
}  // namespace math
}  // namespace operators
}  // namespace paddle