/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <vector>
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {
namespace math {

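// im2col kernel (kCFO layout): each thread owns one (channel, h_out, w_out)
// position of the column buffer and copies the corresponding
// filter_height x filter_width input window into it, writing zeros where
// the window falls into the padding.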
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int im_height,
                       int im_width, int dilation_h, int dilation_w,
                       int filter_height, int filter_width, int stride_height,
                       int stride_width, int padding_height, int padding_width,
                       int col_height, int col_width, T* data_col,
                       const DataLayout data_layout) {
  int input_channels = num_outs / col_height / col_width;
  int channels_col = input_channels * filter_height * filter_width;
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
  if (index < num_outs) {
    int w_out = (data_layout != DataLayout::kNHWC
                     ? index % col_width
                     : (index / input_channels) % col_width);
    int h_out = (data_layout != DataLayout::kNHWC
                     ? (index / col_width) % col_height
                     : (index / input_channels / col_width) % col_height);
    int channel_in =
        (data_layout != DataLayout::kNHWC ? index / col_width / col_height
                                          : index % input_channels);
    int channel_out = channel_in * filter_height * filter_width;
    int h_in = h_out * stride_height - padding_height;
    int w_in = w_out * stride_width - padding_width;

    data_col += (channel_out * col_height + h_out) * col_width + w_out;
    for (int i = 0; i < filter_height; ++i) {
      for (int j = 0; j < filter_width; ++j) {
        int rIdx = h_in + i * dilation_h;
        int cIdx = w_in + j * dilation_w;
        int im_idx;
        if (data_layout != DataLayout::kNHWC) {
          im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
        } else {
          im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
        }
        *data_col =
            (rIdx >= im_height || rIdx < 0 || cIdx >= im_width || cIdx < 0)
                ? 0
                : data_im[im_idx];
        data_col += col_height * col_width;
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
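// For example, a [1, 5, 5] image with a 3 x 3 filter, stride 1 and zero
// padding gives output_height = output_width = (5 - 3) / 1 + 1 = 3, so col
// has shape [1, 3, 3, 3, 3].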
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));

    int im_channels =
        (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
    int im_height =
        (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
    int im_width =
        (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
    int filter_height = col->dims()[1];
    int filter_width = col->dims()[2];
    int col_height = col->dims()[3];
    int col_width = col->dims()[4];

    int num_outputs = im_channels * col_height * col_width;
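    // Lay the num_outputs-wide index space out as 1024-thread blocks on a
    // 512 x block_y grid; the kernel flattens (blockIdx, threadIdx) back
    // into a single output index.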
    int blocks = (num_outputs + 1024 - 1) / 1024;
    int block_x = 512;
    int block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);
    im2col<T><<<grid, threads, 0, context.stream()>>>(
        im.data<T>(), num_outputs, im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, col->data<T>(),
        data_layout);
  }
};

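// col2im kernel (kCFO layout): each thread owns one input pixel and gathers
// every column-buffer entry whose filter window covers that pixel, turning
// the naive scatter-add into a race-free gather.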
template <class T>
__global__ void col2im(int n, const T* data_col, int im_height, int im_width,
                       int dilation_h, int dilation_w, int filter_height,
                       int filter_width, int stride_height, int stride_width,
                       int padding_height, int padding_width, int col_height,
                       int col_width, T* data_im,
                       const DataLayout data_layout) {
  const int index =
      (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;

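  // Extent of the filter after dilation is applied.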
  const int d_filter_height = dilation_h * (filter_height - 1) + 1;
  const int d_filter_width = dilation_w * (filter_width - 1) + 1;

  int input_channels = n / im_height / im_width;

  if (index < n) {
    T val = 0;
    int w = (data_layout != DataLayout::kNHWC
                 ? index % im_width + padding_width
                 : (index / input_channels) % im_width + padding_width);
    int h = (data_layout != DataLayout::kNHWC
                 ? (index / im_width) % im_height + padding_height
                 : (index / input_channels / im_width) % im_height +
                       padding_height);
    int c = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
                                              : index % input_channels);

    // compute the start and end of the output
    int w_col_start =
        (w < d_filter_width) ? 0 : (w - d_filter_width) / stride_width + 1;
    int w_col_end = min(w / stride_width + 1, col_width);
    int h_col_start =
        (h < d_filter_height) ? 0 : (h - d_filter_height) / stride_height + 1;
    int h_col_end = min(h / stride_height + 1, col_height);

    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
        int h_off = (h - h_col * stride_height);
        int w_off = (w - w_col * stride_width);
        if (h_off % dilation_h == 0 && w_off % dilation_w == 0) {
          h_off /= dilation_h;
          w_off /= dilation_w;
          int data_col_index =
              (((c * filter_height + h_off) * filter_width + w_off) *
                   col_height +
               h_col) *
                  col_width +
              w_col;

          val += data_col[data_col_index];
        }
      }
    }
    data_im[index] = val;
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [input_channels, filter_height, filter_width, output_height, output_width]
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                    platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));

    int im_channels =
        (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
    int im_height =
        (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
    int im_width =
        (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
    int filter_height = col.dims()[1];
    int filter_width = col.dims()[2];
    int col_height = col.dims()[3];
    int col_width = col.dims()[4];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height, platform::errors::InvalidArgument(
                                      "col_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width, platform::errors::InvalidArgument(
                                     "col_width and padding(padding_left, "
                                     "padding_right) are inconsistent."));

    size_t num_kernels = im_channels * im_height * im_width;

    size_t blocks = (num_kernels + 1024 - 1) / 1024;
    size_t block_x = 512;
    size_t block_y = (blocks + 512 - 1) / 512;
    dim3 threads(1024, 1);
    dim3 grid(block_x, block_y);

    // To avoid atomic operations, we launch one thread per input (bottom)
    // element and accumulate all of its column (top) contributions inside
    // the kernel.
    col2im<T><<<grid, threads, 0, context.stream()>>>(
        num_kernels, col.data<T>(), im_height, im_width, dilation[0],
        dilation[1], filter_height, filter_width, stride[0], stride[1],
        padding[0], padding[1], col_height, col_width, im->data<T>(),
        data_layout);
  }
};

template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                             platform::CUDADeviceContext, double>;

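// im2colOCF kernel: one thread block per output position (blockIdx.x =
// w_out, blockIdx.y = h_out); the block's threads tile the
// channels x filter_height x filter_width window for that position.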
template <class T>
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* col_data) {
  int swid = blockIdx.x;
  int shid = blockIdx.y;
  for (int channelid = threadIdx.z; channelid < im_channels;
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        int width_offset = idx + swid * stride_width - padding_width;
        int height_offset = idy + shid * stride_height - padding_height;
        int im_offset = width_offset + height_offset * im_width +
                        channelid * im_height * im_width;

        int col_offset = idx + idy * filter_width +
                         channelid * filter_height * filter_width +
                         (shid * col_width + swid) *
                             (im_channels * filter_height * filter_width);

        col_data[col_offset] =
            (height_offset >= im_height || height_offset < 0 ||
             width_offset >= im_width || width_offset < 0)
                ? T(0)
                : im_data[im_offset];
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
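// For example, the same [1, 5, 5] image with a 3 x 3 filter, stride 1 and
// zero padding gives a [3, 3, 1, 3, 3] col tensor here, with the output
// position dimensions leading.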
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& im, const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* col,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im.dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im.dims()));
    PADDLE_ENFORCE_EQ(col->dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col->dims()));

    int im_channels = im.dims()[0];
    int im_height = im.dims()[1];
    int im_width = im.dims()[2];
    int filter_height = col->dims()[3];
    int filter_width = col->dims()[4];
    int col_height = col->dims()[0];
    int col_width = col->dims()[1];

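    // Choose a square 2-D block just large enough to cover the filter (up
    // to 32 x 32); the remaining threads of the 1024-thread budget go to
    // the channel dimension.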
    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    im2colOCF<T><<<grid, threads, 0, context.stream()>>>(
        im.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, col->data<T>());
  }
};

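// col2imOCF kernel: inverse of im2colOCF. Overlapping filter windows may
// contribute to the same image pixel, so values are accumulated with
// CudaAtomicAdd.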
template <class T>
__global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
                          int im_width, int filter_height, int filter_width,
                          int stride_height, int stride_width,
                          int padding_height, int padding_width, int col_height,
                          int col_width, T* im_data) {
  int swid = blockIdx.x;
  int shid = blockIdx.y;
  for (int channelid = threadIdx.z; channelid < im_channels;
       channelid += blockDim.z) {
    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
        int width_offset = idx + swid * stride_width - padding_width;
        int height_offset = idy + shid * stride_height - padding_height;
        int im_offset = width_offset + height_offset * im_width +
                        channelid * im_height * im_width;

        int col_offset = idx + idy * filter_width +
                         channelid * filter_height * filter_width +
                         (shid * col_width + swid) *
                             (im_channels * filter_height * filter_width);

        if (height_offset >= 0 && height_offset < im_height &&
            width_offset >= 0 && width_offset < im_width) {
          paddle::platform::CudaAtomicAdd(im_data + im_offset,
                                          col_data[col_offset]);
        }
      }
    }
  }
}

/*
 * im = [input_channels, input_height, input_width]
 * col =
 *   [output_height, output_width, input_channels, filter_height, filter_width]
 */
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                    platform::CUDADeviceContext, T> {
 public:
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& col,
                  const std::vector<int>& dilation,
                  const std::vector<int>& stride,
                  const std::vector<int>& padding, framework::Tensor* im,
                  const DataLayout data_layout) {
    PADDLE_ENFORCE_EQ(im->dims().size(), 3,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'im' should be 3. But got "
                          "the dims of tensor 'im' is [%s].",
                          im->dims()));
    PADDLE_ENFORCE_EQ(col.dims().size(), 5,
                      platform::errors::InvalidArgument(
                          "The dimension of tensor 'col' should be 5. But got "
                          "the dims of tensor 'col' is [%s].",
                          col.dims()));

    int im_channels = im->dims()[0];
    int im_height = im->dims()[1];
    int im_width = im->dims()[2];
    int filter_height = col.dims()[3];
    int filter_width = col.dims()[4];
    int col_height = col.dims()[0];
    int col_width = col.dims()[1];

    PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
                       (dilation[0] * (filter_height - 1) + 1)) /
                              stride[0] +
                          1,
                      col_height, platform::errors::InvalidArgument(
                                      "col_height and padding(padding_up, "
                                      "padding_down) are inconsistent."));
    PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
                       (dilation[1] * (filter_width - 1) + 1)) /
                              stride[1] +
                          1,
                      col_width, platform::errors::InvalidArgument(
                                     "col_width and padding(padding_left, "
                                     "padding_right) are inconsistent."));

    int block_dim_x = 0;
    int block_dim_y = 0;
    if (filter_height <= 4 && filter_width <= 4) {
      block_dim_x = 4;
      block_dim_y = 4;
    } else if (filter_height <= 8 && filter_width <= 8) {
      block_dim_x = 8;
      block_dim_y = 8;
    } else if (filter_height <= 16 && filter_width <= 16) {
      block_dim_x = 16;
      block_dim_y = 16;
    } else {
      block_dim_x = 32;
      block_dim_y = 32;
    }

    int block_dim_z = 1024 / block_dim_x / block_dim_y;
    dim3 threads(block_dim_x, block_dim_y, std::min(block_dim_z, im_channels));
    dim3 grid(col_width, col_height);
    col2imOCF<T><<<grid, threads, 0, context.stream()>>>(
        col.data<T>(), im_channels, im_height, im_width, filter_height,
        filter_width, stride[0], stride[1], padding[0], padding[1], col_height,
        col_width, im->data<T>());
  }
};

template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                             platform::CUDADeviceContext, double>;
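
// Usage sketch (illustrative only; assumes `context` is a live
// CUDADeviceContext, `im` a [C, H, W] GPU tensor and `col` a pre-allocated
// [C, FH, FW, OH, OW] tensor):
//   Im2ColFunctor<ColFormat::kCFO, platform::CUDADeviceContext, float>()(
//       context, im, /*dilation=*/{1, 1}, /*stride=*/{1, 1},
//       /*padding=*/{0, 0, 0, 0}, &col, DataLayout::kNCHW);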

}  // namespace math
}  // namespace operators
}  // namespace paddle