/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>

#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/pten/backends/gpu/gpu_context.h"

namespace paddle {
namespace operators {
namespace math {

// CFO im2col kernel: one thread per (channel_in, h_out, w_out) output
// position copies the filter_height * filter_width input patch rooted at
// that position into the corresponding column of data_col, writing 0 for
// pixels that fall in the padding region.
// num_outs == input_channels * col_height * col_width; launched over a 2-D
// grid (see Im2ColFunctor<kCFO> for the launch configuration).
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int im_height,
                       int im_width, int dilation_h, int dilation_w,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int col_height, int col_width, T* data_col,
                       const DataLayout data_layout) {
  int input_channels = num_outs / col_height / col_width;
  // Flatten the 2-D grid into one global thread index.
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < num_outs) {
    // Decode (channel_in, h_out, w_out) from the flat index; the decode
    // order depends on the image memory layout (NCHW vs NHWC).
    int w_out = (data_layout != DataLayout::kNHWC
                     ? index % col_width
                     : (index / input_channels) % col_width);
    int h_out = (data_layout != DataLayout::kNHWC
                     ? (index / col_width) % col_height
                     : (index / input_channels / col_width) % col_height);
    int channel_in =
        (data_layout != DataLayout::kNHWC ? index / col_width / col_height
                                          : index % input_channels);
    int channel_out = channel_in * filter_height * filter_width;
    // Top-left corner of the receptive field in (possibly negative) image
    // coordinates after removing the padding offset.
    int h_in = h_out * stride_height - padding_height;
    int w_in = w_out * stride_width - padding_width;

    data_col += (channel_out * col_height + h_out) * col_width + w_out;
    for (int i = 0; i < filter_height; ++i) {
      for (int j = 0; j < filter_width; ++j) {
        int rIdx = h_in + i * dilation_h;
        int cIdx = w_in + j * dilation_w;
        int im_idx;
        if (data_layout != DataLayout::kNHWC) {
          im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
        } else {
          im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
        }
        // Taps that fall outside the image (padding region) read as zero.
        *data_col =
            (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
                ? 0
                : data_im[im_idx];
        // Successive filter taps are col_height * col_width apart in col.
        data_col += col_height * col_width;
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class DeviceContext, class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
                    T> {
 public:
  // Expands the 3-D image `im` into the 5-D column tensor `col` on the GPU
  // stream owned by `context`. Handles both NCHW (default) and NHWC layouts.
  void operator()(const DeviceContext& context, const framework::Tensor& im,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));

    // Image extents depend on the memory layout (CHW vs HWC).
    const bool is_nhwc = (data_layout == DataLayout::kNHWC);
    int im_channels = is_nhwc ? im.dims()[2] : im.dims()[0];
    int im_height = is_nhwc ? im.dims()[0] : im.dims()[1];
    int im_width = is_nhwc ? im.dims()[1] : im.dims()[2];
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];

    // One kernel thread per (channel, h_out, w_out) output element.
    int num_outputs = im_channels * col_height * col_width;
    int num_thread = 1024;
#ifdef WITH_NV_JETSON
    // Jetson devices may require a smaller thread-block size.
    platform::ChangeThreadNum(context, &num_thread);
#endif
    // Spread the required blocks over a 2-D grid with grid.x capped at 512.
    int blocks = (num_outputs + num_thread - 1) / num_thread;
    dim3 threads(num_thread, 1);
    dim3 grid(512, (blocks + 512 - 1) / 512);
    im2col<T><<<grid, threads, 0, context.stream()>>>(
        im.data<T>(), num_outputs, im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, col->data<T>(),
        data_layout);
  }
};

// Gather-style col2im: each thread owns one image element and sums every
// column entry that was generated from it, so no atomics are required.
// n == input_channels * im_height * im_width.
template <class T>
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                       int dilation_h, int dilation_w, int filter_height,
                       int filter_width, int stride_height, int stride_width,
                       int padding_height, int padding_width, int col_height,
                       int col_width, T* data_im,
                       const DataLayout data_layout) {
  // Flatten the 2-D grid into one global thread index.
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  // Effective (dilated) filter extents.
  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

  int input_channels = n / im_height / im_width;

  if (index < n) {
    T accum = 0;
    // Position of this image element in padded coordinates; the decode
    // order of the flat index depends on the layout (NCHW vs NHWC).
    int pw = (data_layout != DataLayout::kNHWC
                  ? index % im_width + padding_width
                  : (index / input_channels) % im_width + padding_width);
    int ph = (data_layout != DataLayout::kNHWC
                  ? (index / im_width) % im_height + padding_height
                  : (index / input_channels / im_width) % im_height +
                        padding_height);
    int ch = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
                                               : index % input_channels);

    // Range of output positions whose receptive field covers (ph, pw).
    int ow_begin =
        (pw < d_filter_width) ? 0 : (pw - d_filter_width) / stride_width + 1;
    int ow_end = min(pw / stride_width + 1, col_width);
    int oh_begin =
        (ph < d_filter_height) ? 0
                               : (ph - d_filter_height) / stride_height + 1;
    int oh_end = min(ph / stride_height + 1, col_height);

    for (int oh = oh_begin; oh < oh_end; ++oh) {
      for (int ow = ow_begin; ow < ow_end; ++ow) {
        int kh = (ph - oh * stride_height);
        int kw = (pw - ow * stride_width);
        // Only filter taps that land exactly on the dilation grid contribute.
        if (kh % dilation_h == 0 && kw % dilation_w == 0) {
          kh /= dilation_h;
          kw /= dilation_w;
          int col_idx =
              (((ch * filter_height + kh) * filter_width + kw) * col_height +
               oh) *
                  col_width +
              ow;
          accum += data_col[col_idx];
        }
      }
    }
    data_im[index] = accum;
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class DeviceContext, class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
                    T> {
 public:
  // Folds the 5-D column tensor `col` back into the 3-D image `im`,
  // summing overlapping patches. Inverse of the kCFO Im2ColFunctor.
  void operator()(const DeviceContext& context, const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));

    // Image extents depend on the memory layout (CHW vs HWC).
    const bool is_nhwc = (data_layout == DataLayout::kNHWC);
    int im_channels = is_nhwc ? im->dims()[2] : im->dims()[0];
    int im_height = is_nhwc ? im->dims()[0] : im->dims()[1];
    int im_width = is_nhwc ? im->dims()[1] : im->dims()[2];
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];

    // Check that col's spatial extent matches the conv output size implied
    // by the image size, 4-sided padding, dilation and stride.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height, platform::errors::InvalidArgument(
                                      "Output_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width, platform::errors::InvalidArgument(
                                     "col_width and padding(padding_left, "
                                     "padding_right) are inconsistent."));

    // One kernel thread per image element (gather, no atomics needed).
    size_t num_kernels = im_channels * im_height * im_width;

    int num_thread = 1024;
#ifdef WITH_NV_JETSON
    // Jetson devices may require a smaller thread-block size.
    platform::ChangeThreadNum(context, &num_thread);
#endif
    // Spread the required blocks over a 2-D grid with grid.x capped at 512.
    size_t blocks = (num_kernels + num_thread - 1) / num_thread;
    dim3 threads(num_thread, 1);
    dim3 grid(512, (blocks + 512 - 1) / 512);

    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
    col2im<T><<<grid, threads, 0, context.stream()>>>(
        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, im->data<T>(),
        data_layout);
  }
};

// Explicit instantiations of the kCFO functors for the supported device
// contexts and element types.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             pten::GPUContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             pten::GPUContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             pten::GPUContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             pten::GPUContext, double>;

// OCF im2col kernel. Launch config: grid = (col_width, col_height), i.e. one
// block per output position; the block's threads tile the
// (filter_w, filter_h, channel) patch and stride over any remainder.
// The input image is indexed as CHW.
template <class T>
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* col_data) {
  int out_w = blockIdx.x;
  int out_h = blockIdx.y;
  for (int c = threadIdx.z; c < im_channels; c += blockDim.z) {
    for (int ky = threadIdx.y; ky < filter_height; ky += blockDim.y) {
      for (int kx = threadIdx.x; kx < filter_width; kx += blockDim.x) {
        // Source pixel for this filter tap (may fall in the padding region).
        int in_w = kx + out_w * stride_width - padding_width;
        int in_h = ky + out_h * stride_height - padding_height;
        int im_offset = in_w + in_h * im_width + c * im_height * im_width;

        // col layout: [col_height, col_width, channels, filter_h, filter_w].
        int col_offset = kx + ky * filter_width +
                         c * filter_height * filter_width +
                         (out_h * col_width + out_w) *
                             (im_channels * filter_height * filter_width);

        // Out-of-image taps read as zero.
        col_data[col_offset] =
            (in_h >= im_height || in_h < 0 || in_w >= im_width || in_w < 0)
                ? T(0)
                : im_data[im_offset];
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class DeviceContext, class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
                    T> {
 public:
  // OCF variant: one CUDA block per output position; its threads tile the
  // (channel, filter_h, filter_w) patch. NOTE(review): `data_layout` is not
  // used on this path — the image is always indexed as CHW; confirm callers
  // only pass NCHW data here.
  void operator()(const DeviceContext& context, const framework::Tensor& im,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));

    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];

    // Square thread tile just large enough for the filter (capped at 32x32).
    int filter_dim = std::max(filter_height, filter_width);
    int block_dim_x =
        filter_dim <= 4 ? 4 : filter_dim <= 8 ? 8 : filter_dim <= 16 ? 16 : 32;
    int block_dim_y = block_dim_x;
    // Spend the remaining threads (out of 1024) on the channel dimension.
    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    im2colOCF<T><<<grid, threads, 0, context.stream()>>>(
        im.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, col->data<T>());
  }
};

// OCF col2im kernel: mirror of im2colOCF that scatters column values back
// into the image; overlapping patches are combined with atomic adds.
// NOTE(review): assumes im_data was zero-initialized by the caller — confirm.
template <class T>
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* im_data) {
  int out_w = blockIdx.x;
  int out_h = blockIdx.y;
  for (int c = threadIdx.z; c < im_channels; c += blockDim.z) {
    for (int ky = threadIdx.y; ky < filter_height; ky += blockDim.y) {
      for (int kx = threadIdx.x; kx < filter_width; kx += blockDim.x) {
        // Destination pixel for this filter tap (CHW indexing).
        int in_w = kx + out_w * stride_width - padding_width;
        int in_h = ky + out_h * stride_height - padding_height;
        int im_offset = in_w + in_h * im_width + c * im_height * im_width;

        // col layout: [col_height, col_width, channels, filter_h, filter_w].
        int col_offset = kx + ky * filter_width +
                         c * filter_height * filter_width +
                         (out_h * col_width + out_w) *
                             (im_channels * filter_height * filter_width);

        // Skip taps that land in the padding region; concurrent writes to
        // the same pixel from different output positions are accumulated
        // atomically.
        if (in_h >= 0 && in_h < im_height && in_w >= 0 && in_w < im_width) {
          paddle::platform::CudaAtomicAdd(im_data + im_offset,
                                          col_data[col_offset]);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class DeviceContext, class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
                    T> {
 public:
  // Folds the OCF-format column tensor back into the image via atomic
  // accumulation. NOTE(review): `data_layout` is not used on this path —
  // the image is always indexed as CHW; confirm callers only pass NCHW.
  void operator()(const DeviceContext& context, const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));

    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];

    // Check that col's spatial extent matches the conv output size implied
    // by the image size, 4-sided padding, dilation and stride.
    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height, platform::errors::InvalidArgument(
                                      "Output_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width, platform::errors::InvalidArgument(
                                     "col_width and padding(padding_left, "
                                     "padding_right) are inconsistent."));

    // Square thread tile just large enough for the filter (capped at 32x32).
    int filter_dim = std::max(filter_height, filter_width);
    int block_dim_x =
        filter_dim <= 4 ? 4 : filter_dim <= 8 ? 8 : filter_dim <= 16 ? 16 : 32;
    int block_dim_y = block_dim_x;
    // Spend the remaining threads (out of 1024) on the channel dimension.
    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    col2imOCF<T><<<grid, threads, 0, context.stream()>>>(
        col.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, im->data<T>());
  }
};

// Explicit instantiations of the kOCF functors for the supported device
// contexts and element types.
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             pten::GPUContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             pten::GPUContext, double>;

template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             pten::GPUContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             pten::GPUContext, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle