/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstdint>
#include <vector>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/im2col.h"

namespace phi {
namespace funcs {

/*
 * Kernel that unfolds an image into the kCFO column buffer.
 *
 * Launch layout: the flat thread index is built from a 2-D grid (x, y) and a
 * 1-D block, i.e. (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x +
 * threadIdx.x. Threads whose index is >= num_outs do nothing.
 *
 * Each active thread owns one (input channel, h_out, w_out) triple and writes
 * filter_height * filter_width values into data_col, one per filter tap,
 * spaced col_height * col_width elements apart. Taps that fall outside the
 * image are written as zero.
 *
 * data_im is indexed as [C, H, W] unless data_layout is kNHWC, in which case
 * it is indexed as [H, W, C]. data_col is indexed the same way for both
 * layouts.
 *
 * col_height and col_width must be non-zero (they are used as divisors).
 */
template <class T>
__global__ void im2col(const T* data_im,
                       int num_outs,
                       int im_height,
                       int im_width,
                       int dilation_h,
                       int dilation_w,
                       int filter_height,
                       int filter_width,
                       int stride_height,
                       int stride_width,
                       int padding_height,
                       int padding_width,
                       int col_height,
                       int col_width,
                       T* data_col,
                       const DataLayout data_layout) {
  const int input_channels = num_outs / col_height / col_width;
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index >= num_outs) {
    return;
  }

  const bool channel_last = (data_layout == DataLayout::kNHWC);

  // Decompose the flat index into (channel_in, h_out, w_out). The order of
  // the decomposition follows the memory order of the image layout.
  const int w_out = channel_last ? (index / input_channels) % col_width
                                 : index % col_width;
  const int h_out = channel_last
                        ? (index / input_channels / col_width) % col_height
                        : (index / col_width) % col_height;
  const int channel_in = channel_last ? index % input_channels
                                      : index / col_width / col_height;

  const int channel_out = channel_in * filter_height * filter_width;
  const int h_in = h_out * stride_height - padding_height;
  const int w_in = w_out * stride_width - padding_width;

  // The column buffer holds filter_height * filter_width times more elements
  // than there are threads, so its offsets are computed in 64 bits to avoid
  // wrapping a 32-bit int on large inputs.
  const int64_t col_plane = static_cast<int64_t>(col_height) * col_width;
  data_col +=
      (static_cast<int64_t>(channel_out) * col_height + h_out) * col_width +
      w_out;

  for (int i = 0; i < filter_height; ++i) {
    for (int j = 0; j < filter_width; ++j) {
      const int rIdx = h_in + i * dilation_h;
      const int cIdx = w_in + j * dilation_w;
      const bool inside =
          rIdx >= 0 && rIdx < im_height && cIdx >= 0 && cIdx < im_width;

      T value = T(0);
      if (inside) {
        const int64_t im_idx =
            channel_last
                ? (static_cast<int64_t>(rIdx) * im_width + cIdx) *
                          input_channels +
                      channel_in
                : (static_cast<int64_t>(channel_in) * im_height + rIdx) *
                          im_width +
                      cIdx;
        value = data_im[im_idx];
      }
      *data_col = value;
      // Advance to the same (h_out, w_out) position of the next filter tap.
      data_col += col_plane;
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 *      ([input_height, input_width, input_channels] when data_layout is kNHWC)
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * Launches the im2col kernel on context.stream(), one thread per
 * (input channel, output row, output column).
 *
 * Reads dilation[0..1], stride[0..1] and padding[0..1] as the height/width
 * pair of each setting. Raises InvalidArgument when `im` is not 3-D or `col`
 * is not 5-D.
 */
template <class DeviceContext, class T>
class Im2ColFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& im,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding,
                  phi::DenseTensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(),
                      3,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
    PADDLE_ENFORCE_EQ(col->dims().size(),
                      5,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));

    int im_channels =
        (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
    int im_height =
        (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
    int im_width =
        (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];

    int num_outputs = im_channels * col_height * col_width;
    // Nothing to write for an empty output. Launching anyway would request a
    // grid with y == 0, which is an invalid launch configuration.
    if (num_outputs == 0) {
      return;
    }

    int num_thread = 1024;
#ifdef WITH_NV_JETSON
    phi::backends::gpu::ChangeThreadNum(context, &num_thread);
#endif
    // The blocks are folded into a 512 x block_y grid; the kernel rebuilds
    // the flat index from (blockIdx.x, blockIdx.y) and discards the surplus
    // threads with its index < num_outs check.
    int blocks = (num_outputs + num_thread - 1) / num_thread;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
    dim3 threads(num_thread, 1);
    dim3 grid(block_x, block_y);
    im2col<T><<<grid, threads, 0, context.stream()>>>(im.data<T>(),
                                                      num_outputs,
                                                      im_height,
                                                      im_width,
                                                      dilation[0],
                                                      dilation[1],
                                                      filter_height,
                                                      filter_width,
                                                      stride[0],
                                                      stride[1],
                                                      padding[0],
                                                      padding[1],
                                                      col_height,
                                                      col_width,
                                                      col->data<T>(),
                                                      data_layout);
  }
};

/*
 * Kernel that folds a kCFO column buffer back into an image.
 *
 * Launch layout: the flat thread index is built from a 2-D grid (x, y) and a
 * 1-D block, i.e. (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x +
 * threadIdx.x. Threads whose index is >= n do nothing.
 *
 * Each active thread owns one image element: it sums every column entry that
 * maps onto that element and stores the sum with a plain write to
 * data_im[index], so no atomics are used and data_im is overwritten rather
 * than accumulated into.
 *
 * The flat index is decomposed as [C, H, W] unless data_layout is kNHWC, in
 * which case it is decomposed as [H, W, C].
 *
 * im_height and im_width must be non-zero (they are used as divisors).
 */
template <class T>
__global__ void col2im(int n,
                       const T* data_col,
                       int im_height,
                       int im_width,
                       int dilation_h,
                       int dilation_w,
                       int filter_height,
                       int filter_width,
                       int stride_height,
                       int stride_width,
                       int padding_height,
                       int padding_width,
                       int col_height,
                       int col_width,
                       T* data_im,
                       const DataLayout data_layout) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

  // Extent of the filter once dilation is applied.
  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

  const int input_channels = n / im_height / im_width;

  if (index >= n) {
    return;
  }

  const bool channel_last = (data_layout == DataLayout::kNHWC);

  // (c, h, w) of this thread's image element; h and w are expressed in
  // padded-image coordinates.
  const int w = (channel_last ? (index / input_channels) % im_width
                              : index % im_width) +
                padding_width;
  const int h = (channel_last ? (index / input_channels / im_width) % im_height
                              : (index / im_width) % im_height) +
                padding_height;
  const int c =
      channel_last ? index % input_channels : index / im_width / im_height;

  // Range of output positions whose (dilated) filter window covers (h, w).
  const int w_col_start =
      (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
  const int w_col_end = min(w / stride_width + 1, col_width);
  const int h_col_start =
      (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
  const int h_col_end = min(h / stride_height + 1, col_height);

  T val = T(0);
  for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
      int h_off = (h - h_col * stride_height);
      int w_off = (w - w_col * stride_width);
      // Only offsets that land exactly on a dilated tap contribute.
      if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
        h_off /= dilation_h;
        w_off /= dilation_w;
        // The column buffer can hold more than INT_MAX elements even when n
        // fits in an int, so the offset is computed in 64 bits.
        const int64_t data_col_index =
            (((static_cast<int64_t>(c) * filter_height + h_off) *
                  filter_width +
              w_off) *
                 col_height +
             h_col) *
                col_width +
            w_col;
        val += data_col[data_col_index];
      }
    }
  }
  data_im[index] = val;
}

/*
 * im = [input_channels, input_height, input_width]
 *      ([input_height, input_width, input_channels] when data_layout is kNHWC)
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 *
 * Launches the col2im kernel on context.stream(), one thread per image
 * element.
 *
 * `padding` must hold four entries: indices 0 and 2 are added to the height
 * and indices 1 and 3 to the width in the shape checks below; only
 * padding[0] and padding[1] are forwarded to the kernel. Raises
 * InvalidArgument when `im` is not 3-D, `col` is not 5-D, or the output
 * extents stored in `col` disagree with the values derived from the image
 * size, padding, dilation and stride.
 */
template <class DeviceContext, class T>
class Col2ImFunctor<phi::funcs::ColFormat::kCFO, DeviceContext, T> {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding,
                  phi::DenseTensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(),
                      3,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
    PADDLE_ENFORCE_EQ(col.dims().size(),
                      5,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));

    int im_channels =
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
    int im_height =
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
    int im_width =
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];

    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] -
         (dilation[0] * (filter_height - 1) + 1)) /
                stride[0] +
            1,
        col_height,
        phi::errors::InvalidArgument("Output_height and padding(padding_up, "
                                     "padding_down) are inconsistent."));
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] -
         (dilation[1] * (filter_width - 1) + 1)) /
                stride[1] +
            1,
        col_width,
        phi::errors::InvalidArgument("col_width and padding(padding_left, "
                                     "padding_right) are inconsistent."));

    // Widen before multiplying so the product is not evaluated in int.
    size_t num_kernels =
        static_cast<size_t>(im_channels) * im_height * im_width;
    // Nothing to write for an empty image. Launching anyway would request a
    // grid with y == 0, which is an invalid launch configuration.
    if (num_kernels == 0) {
      return;
    }

    int num_thread = 1024;
#ifdef WITH_NV_JETSON
    phi::backends::gpu::ChangeThreadNum(context, &num_thread);
#endif
    size_t blocks = (num_kernels + num_thread - 1) / num_thread;
    size_t block_x = 512;
    size_t block_y = (blocks + 512 - 1) / 512;
    dim3 threads(num_thread, 1);
    dim3 grid(block_x, block_y);

    // To avoid involving atomic operations, we will launch one kernel per
    // bottom dimension, and then in the kernel add up the top dimensions.
    // NOTE(review): the kernel takes the element count as int, so images with
    // more than INT_MAX elements are not supported by this path.
    col2im<T><<<grid, threads, 0, context.stream()>>>(
        static_cast<int>(num_kernels),
        col.data<T>(),
        im_height,
        im_width,
        dilation[0],
        dilation[1],
        filter_height,
        filter_width,
        stride[0],
        stride[1],
        padding[0],
        padding[1],
        col_height,
        col_width,
        im->data<T>(),
        data_layout);
  }
};

// Explicit instantiations of the kCFO functors for phi::GPUContext.
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
                             phi::GPUContext,
                             float>;
template class Im2ColFunctor<phi::funcs::ColFormat::kCFO,
                             phi::GPUContext,
                             double>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
                             phi::GPUContext,
                             float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kCFO,
                             phi::GPUContext,
                             double>;

/*
 * Kernel that unfolds an image into the kOCF column buffer.
 *
 * Launch layout: one block per output position, with blockIdx.x as the
 * output column and blockIdx.y as the output row. Inside a block, threads
 * stride over (channel, filter row, filter column) using threadIdx.z/y/x and
 * blockDim.z/y/x, so any block shape covers the whole filter window.
 *
 * im_data is indexed as [C, H, W]; col_data as
 * [output_height, output_width, C, filter_height, filter_width].
 * Filter taps that fall outside the image are written as T(0).
 * This kernel has no dilation parameter.
 */
template <class T>
__global__ void im2colOCF(const T* im_data,
                          int im_channels,
                          int im_height,
                          int im_width,
                          int filter_height,
                          int filter_width,
                          int stride_height,
                          int stride_width,
                          int padding_height,
                          int padding_width,
                          int col_height,
                          int col_width,
                          T* col_data) {
  const int swid = blockIdx.x;
  const int shid = blockIdx.y;

  // Offsets are computed in 64 bits: the column buffer holds
  // col_height * col_width windows and can exceed the range of int.
  const int64_t window_size =
      static_cast<int64_t>(im_channels) * filter_height * filter_width;
  const int64_t window_base =
      (static_cast<int64_t>(shid) * col_width + swid) * window_size;

  for (int channelid = threadIdx.z; channelid < im_channels;
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        const int width_offset = idx + swid * stride_width - padding_width;
        const int height_offset = idy + shid * stride_height - padding_height;

        const int64_t col_offset =
            idx + idy * filter_width +
            static_cast<int64_t>(channelid) * filter_height * filter_width +
            window_base;

        const bool inside = height_offset >= 0 && height_offset < im_height &&
                            width_offset >= 0 && width_offset < im_width;
        if (inside) {
          const int64_t im_offset =
              width_offset + static_cast<int64_t>(height_offset) * im_width +
              static_cast<int64_t>(channelid) * im_height * im_width;
          col_data[col_offset] = im_data[im_offset];
        } else {
          col_data[col_offset] = T(0);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * Launches the im2colOCF kernel on context.stream() with one block per
 * output position.
 *
 * Reads stride[0..1] and padding[0..1]. `dilation` and `data_layout` are
 * accepted for interface compatibility but are not read by this
 * implementation. Raises InvalidArgument when `im` is not 3-D or `col` is
 * not 5-D.
 */
template <class DeviceContext, class T>
class Im2ColFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& im,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding,
                  phi::DenseTensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(),
                      3,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
    PADDLE_ENFORCE_EQ(col->dims().size(),
                      5,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));

    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];

    // Nothing to write when there are no channels or no output positions.
    // Launching anyway would request a zero block or grid dimension, which is
    // an invalid launch configuration.
    if (im_channels == 0 || col_height == 0 || col_width == 0) {
      return;
    }

    // Pick the smallest square x/y tile that covers the filter, capped at
    // 32 x 32 (larger filters are handled by the strided loops in the kernel).
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    // Spend the rest of the 1024-thread budget on the channel axis.
    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    im2colOCF<T><<<grid, threads, 0, context.stream()>>>(im.data<T>(),
                                                         im_channels,
                                                         im_height,
                                                         im_width,
                                                         filter_height,
                                                         filter_width,
                                                         stride[0],
                                                         stride[1],
                                                         padding[0],
                                                         padding[1],
                                                         col_height,
                                                         col_width,
                                                         col->data<T>());
  }
};

/*
 * Kernel that folds a kOCF column buffer back into an image.
 *
 * Launch layout: one block per output position, with blockIdx.x as the
 * output column and blockIdx.y as the output row. Inside a block, threads
 * stride over (channel, filter row, filter column) using threadIdx.z/y/x and
 * blockDim.z/y/x.
 *
 * col_data is indexed as
 * [output_height, output_width, C, filter_height, filter_width]; im_data as
 * [C, H, W]. Every in-bounds column entry is added to its image element with
 * phi::CudaAtomicAdd, so im_data is accumulated into rather than overwritten.
 * NOTE(review): this presumably requires the caller to zero im_data first -
 * confirm against the call sites.
 * This kernel has no dilation parameter.
 */
template <class T>
__global__ void col2imOCF(const T* col_data,
                          int im_channels,
                          int im_height,
                          int im_width,
                          int filter_height,
                          int filter_width,
                          int stride_height,
                          int stride_width,
                          int padding_height,
                          int padding_width,
                          int col_height,
                          int col_width,
                          T* im_data) {
  const int swid = blockIdx.x;
  const int shid = blockIdx.y;

  // Offsets are computed in 64 bits: the column buffer holds
  // col_height * col_width windows and can exceed the range of int.
  const int64_t window_size =
      static_cast<int64_t>(im_channels) * filter_height * filter_width;
  const int64_t window_base =
      (static_cast<int64_t>(shid) * col_width + swid) * window_size;

  for (int channelid = threadIdx.z; channelid < im_channels;
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        const int width_offset = idx + swid * stride_width - padding_width;
        const int height_offset = idy + shid * stride_height - padding_height;

        // Entries that came from the padding region have no image element.
        if (height_offset >= 0 && height_offset < im_height &&
            width_offset >= 0 && width_offset < im_width) {
          const int64_t im_offset =
              width_offset + static_cast<int64_t>(height_offset) * im_width +
              static_cast<int64_t>(channelid) * im_height * im_width;
          const int64_t col_offset =
              idx + idy * filter_width +
              static_cast<int64_t>(channelid) * filter_height * filter_width +
              window_base;
          phi::CudaAtomicAdd(im_data + im_offset, col_data[col_offset]);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 *
 * Launches the col2imOCF kernel on context.stream() with one block per
 * output position.
 *
 * `padding` must hold four entries: indices 0 and 2 are added to the height
 * and indices 1 and 3 to the width in the shape checks below; only
 * padding[0] and padding[1] are forwarded to the kernel. `dilation` is used
 * by the shape checks only and is not forwarded to the kernel; `data_layout`
 * is not read. Raises InvalidArgument when `im` is not 3-D, `col` is not
 * 5-D, or the output extents stored in `col` disagree with the values
 * derived from the image size, padding, dilation and stride.
 */
template <class DeviceContext, class T>
class Col2ImFunctor<phi::funcs::ColFormat::kOCF, DeviceContext, T> {
 public:
  void operator()(const DeviceContext& context,
                  const phi::DenseTensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding,
                  phi::DenseTensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(),
                      3,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
    PADDLE_ENFORCE_EQ(col.dims().size(),
                      5,
                      phi::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));

    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];

    PADDLE_ENFORCE_EQ(
        (im_height + padding[0] + padding[2] -
         (dilation[0] * (filter_height - 1) + 1)) /
                stride[0] +
            1,
        col_height,
        phi::errors::InvalidArgument("Output_height and padding(padding_up, "
                                     "padding_down) are inconsistent."));
    PADDLE_ENFORCE_EQ(
        (im_width + padding[1] + padding[3] -
         (dilation[1] * (filter_width - 1) + 1)) /
                stride[1] +
            1,
        col_width,
        phi::errors::InvalidArgument("col_width and padding(padding_left, "
                                     "padding_right) are inconsistent."));

    // Nothing to accumulate when there are no channels or no output
    // positions. Launching anyway would request a zero block or grid
    // dimension, which is an invalid launch configuration.
    if (im_channels == 0 || col_height == 0 || col_width == 0) {
      return;
    }

    // Pick the smallest square x/y tile that covers the filter, capped at
    // 32 x 32 (larger filters are handled by the strided loops in the kernel).
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    // Spend the rest of the 1024-thread budget on the channel axis.
    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    col2imOCF<T><<<grid, threads, 0, context.stream()>>>(col.data<T>(),
                                                         im_channels,
                                                         im_height,
                                                         im_width,
                                                         filter_height,
                                                         filter_width,
                                                         stride[0],
                                                         stride[1],
                                                         padding[0],
                                                         padding[1],
                                                         col_height,
                                                         col_width,
                                                         im->data<T>());
  }
};

// Explicit instantiations of the kOCF functors for phi::GPUContext.
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
                             phi::GPUContext,
                             float>;
template class Im2ColFunctor<phi::funcs::ColFormat::kOCF,
                             phi::GPUContext,
                             double>;

template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
                             phi::GPUContext,
                             float>;
template class Col2ImFunctor<phi::funcs::ColFormat::kOCF,
                             phi::GPUContext,
                             double>;

}  // namespace funcs
}  // namespace phi