Commit e967645c authored by hedaoyuan

Refine the gpu code.

Parent f7be9cb9
 if(WITH_GPU)
-    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc DEPS cblas device_context)
+    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
+        im2col.cu DEPS cblas device_context)
 else()
     cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
 endif()

@@ -12,86 +12,89 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "Im2Col.h"
-#include "hl_device_functions.cuh"
+#include "paddle/operators/math/im2col.h"
+#include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
 
 template <class T>
-__global__ void im2col(const T* data_im, int numOuts, int height, int width,
-                       int blockH, int blockW, int strideH, int strideW,
-                       int paddingH, int paddingW, int height_col,
-                       int width_col, T* data_col) {
+__global__ void im2col(const T* data_im, int num_outs, int height, int width,
+                       int filter_height, int filter_width, int stride_height,
+                       int stride_width, int padding_height, int padding_width,
+                       int output_height, int output_width, T* data_col) {
   int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
-  if (index < numOuts) {
-    int w_out = index % width_col;
-    index /= width_col;
-    int h_out = index % height_col;
-    int channel_in = index / height_col;
-    int channel_out = channel_in * blockH * blockW;
-    int h_in = h_out * strideH;
-    int w_in = w_out * strideW;
-    data_col += (channel_out * height_col + h_out) * width_col + w_out;
-    for (int i = 0; i < blockH; ++i) {
-      for (int j = 0; j < blockW; ++j) {
+  if (index < num_outs) {
+    int w_out = index % output_width;
+    index /= output_width;
+    int h_out = index % output_height;
+    int channel_in = index / output_height;
+    int channel_out = channel_in * filter_height * filter_width;
+    int h_in = h_out * stride_height;
+    int w_in = w_out * stride_width;
+    data_col += (channel_out * output_height + h_out) * output_width + w_out;
+    for (int i = 0; i < filter_height; ++i) {
+      for (int j = 0; j < filter_width; ++j) {
         int rIdx = int(h_in + i);
         int cIdx = int(w_in + j);
-        if ((rIdx - (int)paddingH) >= (int)height ||
-            (rIdx - (int)paddingH) < 0 ||
-            (cIdx - (int)paddingW) >= (int)width ||
-            (cIdx - (int)paddingW) < 0) {
+        if ((rIdx - (int)padding_height) >= (int)height ||
+            (rIdx - (int)padding_height) < 0 ||
+            (cIdx - (int)padding_width) >= (int)width ||
+            (cIdx - (int)padding_width) < 0) {
           *data_col = 0;
         } else {
-          rIdx = rIdx + channel_in * height - paddingH;
-          cIdx = cIdx - paddingW;
+          rIdx = rIdx + channel_in * height - padding_height;
+          cIdx = cIdx - padding_width;
           *data_col = data_im[rIdx * width + cIdx];
         }
-        data_col += height_col * width_col;
+        data_col += output_height * output_width;
       }
     }
   }
 }
 
 /*
- * imShape = [inputChannels, inputHeight, inputWidth]
- * colShape =
- * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
+ * im = [input_channels, input_height, input_width]
+ * col =
+ * [input_channels, filter_height, filter_width, output_height, output_width]
  */
 template <class T>
-class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, T> {
+class Im2ColFunctor<kCFO, platform::GPUPlace, T> {
  public:
-  void operator()(const T* imData, const TensorShape& imShape, T* colData,
-                  const TensorShape& colShape, int strideHeight,
-                  int strideWidth, int paddingHeight, int paddingWidth) {
-    int inputChannels = imShape[0];
-    int inputHeight = imShape[1];
-    int inputWidth = imShape[2];
-    int filterHeight = colShape[1];
-    int filterWidth = colShape[2];
-    int outputHeight = colShape[3];
-    int outputWidth = colShape[4];
-
-    int numKernels = inputChannels * outputHeight * outputWidth;
-    int blocks = (numKernels + 1024 - 1) / 1024;
-    int blockX = 512;
-    int blockY = (blocks + 512 - 1) / 512;
+  void operator()(const framework::Tensor& im, framework::Tensor& col,
+                  int stride_height, int stride_width, int padding_height,
+                  int padding_width) {
+    PADDLE_ENFORCE(im.dims().size() == 3);
+    PADDLE_ENFORCE(col.dims().size() == 5);
+
+    int input_channels = im.dims()[0];
+    int input_height = im.dims()[1];
+    int input_width = im.dims()[2];
+    int filter_height = col.dims()[1];
+    int filter_width = col.dims()[2];
+    int output_height = col.dims()[3];
+    int output_width = col.dims()[4];
+
+    int num_outputs = input_channels * output_height * output_width;
+    int blocks = (num_outputs + 1024 - 1) / 1024;
+    int block_x = 512;
+    int block_y = (blocks + 512 - 1) / 512;
     dim3 threads(1024, 1);
-    dim3 grid(blockX, blockY);
-    im2col<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
-        imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
-        strideHeight, strideWidth, paddingHeight, paddingWidth, outputHeight,
-        outputWidth, colData);
-    CHECK_SYNC("Im2ColFunctor GPU failed");
+    dim3 grid(block_x, block_y);
+    im2col<T><<<grid, threads>>>(
+        im.data<T>(), num_outputs, input_height, input_width, filter_height,
+        filter_width, stride_height, stride_width, padding_height,
+        padding_width, output_height, output_width, col.data<T>());
   }
 };
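
The kCFO comment above fixes the col layout to [input_channels, filter_height, filter_width, output_height, output_width]; the functor reads the filter and output sizes back from col.dims(), so the caller is assumed to have sized col with the usual im2col convention. A minimal, standalone sketch of that shape arithmetic (illustrative only; OutputSize is a hypothetical helper, not part of this commit):

#include <cassert>

// Usual im2col output-size convention assumed by the caller:
// output = (input + 2 * padding - filter) / stride + 1
int OutputSize(int input, int filter, int padding, int stride) {
  return (input + 2 * padding - filter) / stride + 1;
}

int main() {
  int input_height = 5, input_width = 5;
  int filter_height = 3, filter_width = 3;
  int stride = 1, padding = 0;
  int output_height = OutputSize(input_height, filter_height, padding, stride);
  int output_width = OutputSize(input_width, filter_width, padding, stride);
  assert(output_height == 3 && output_width == 3);
  // With the kCFO layout, each input channel expands to
  // filter_height * filter_width * output_height * output_width
  // = 3 * 3 * 3 * 3 = 81 col elements in this example.
  return 0;
}
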

 template <class T>
 __global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
-                       size_t channels, size_t blockH, size_t blockW,
-                       size_t strideH, size_t strideW, size_t paddingH,
-                       size_t paddingW, size_t height_col, size_t width_col,
-                       T* data_im) {
+                       size_t channels, size_t filter_height,
+                       size_t filter_width, size_t stride_height,
+                       size_t stride_width, size_t padding_height,
+                       size_t padding_width, size_t output_height,
+                       size_t output_width, T* data_im) {
   size_t index =
       (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   if (index < n) {
@@ -99,104 +102,112 @@ __global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
     int w = int(index % width);
     int h = int((index / width) % height);
     int c = int(index / (width * height));
-    if ((w - (int)paddingW) >= 0 &&
-        (w - (int)paddingW) < (width - 2 * paddingW) &&
-        (h - (int)paddingH) >= 0 && (h - paddingH) < (height - 2 * paddingH)) {
+    if ((w - (int)padding_width) >= 0 &&
+        (w - (int)padding_width) < (width - 2 * padding_width) &&
+        (h - (int)padding_height) >= 0 &&
+        (h - padding_height) < (height - 2 * padding_height)) {
       // compute the start and end of the output
-      int w_col_start =
-          (w < (int)blockW) ? 0 : (w - int(blockW)) / (int)strideW + 1;
-      int w_col_end = min((int)(w / (int)strideW + 1), (int)(width_col));
-      int h_col_start =
-          (h < (int)blockH) ? 0 : (h - (int)blockH) / (int)strideH + 1;
-      int h_col_end = min(int(h / strideH + 1), int(height_col));
+      int w_col_start = (w < (int)filter_width)
+                            ? 0
+                            : (w - int(filter_width)) / (int)stride_width + 1;
+      int w_col_end =
+          min((int)(w / (int)stride_width + 1), (int)(output_width));
+      int h_col_start = (h < (int)filter_height)
+                            ? 0
+                            : (h - (int)filter_height) / (int)stride_height + 1;
+      int h_col_end = min(int(h / stride_height + 1), int(output_height));
       for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
         for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
           // the col location: [c * width * height + h_out, w_out]
-          int c_col = int(c * blockH * blockW) +
-                      (h - h_col * (int)strideH) * (int)blockW +
-                      (w - w_col * (int)strideW);
-          val += data_col[(c_col * height_col + h_col) * width_col + w_col];
+          int c_col = int(c * filter_height * filter_width) +
+                      (h - h_col * (int)stride_height) * (int)filter_width +
+                      (w - w_col * (int)stride_width);
+          val +=
+              data_col[(c_col * output_height + h_col) * output_width + w_col];
         }
       }
-      h -= paddingH;
-      w -= paddingW;
-      data_im[c * ((width - 2 * paddingW) * (height - 2 * paddingH)) +
-              h * (width - 2 * paddingW) + w] += val;
+      h -= padding_height;
+      w -= padding_width;
+      data_im[c * ((width - 2 * padding_width) *
+                   (height - 2 * padding_height)) +
+              h * (width - 2 * padding_width) + w] += val;
     }
   }
 }
 
 /*
- * imShape = [inputChannels, inputHeight, inputWidth]
- * colShape =
- * [inputChannels, filterHeight, filterWidth, outputHeight, outputWidth]
+ * im = [input_channels, input_height, input_width]
+ * col =
+ * [input_channels, filter_height, filter_width, output_height, output_width]
  */
 template <class T>
-class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, T> {
+class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
  public:
-  void operator()(T* imData, const TensorShape& imShape, const T* colData,
-                  const TensorShape& colShape, int strideHeight,
-                  int strideWidth, int paddingHeight, int paddingWidth) {
-    int inputChannels = imShape[0];
-    int inputHeight = imShape[1];
-    int inputWidth = imShape[2];
-    int filterHeight = colShape[1];
-    int filterWidth = colShape[2];
-    int outputHeight = colShape[3];
-    int outputWidth = colShape[4];
-
-    size_t numKernels = inputChannels * (inputHeight + 2 * paddingHeight) *
-                        (inputWidth + 2 * paddingWidth);
-    size_t blocks = (numKernels + 1024 - 1) / 1024;
-    size_t blockX = 512;
-    size_t blockY = (blocks + 512 - 1) / 512;
+  void operator()(framework::Tensor& im, const framework::Tensor& col,
+                  int stride_height, int stride_width, int padding_height,
+                  int padding_width) {
+    PADDLE_ENFORCE(im.dims().size() == 3);
+    PADDLE_ENFORCE(col.dims().size() == 5);
+
+    int input_channels = im.dims()[0];
+    int input_height = im.dims()[1];
+    int input_width = im.dims()[2];
+    int filter_height = col.dims()[1];
+    int filter_width = col.dims()[2];
+    int output_height = col.dims()[3];
+    int output_width = col.dims()[4];
+
+    size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
+                         (input_width + 2 * padding_width);
+    size_t blocks = (num_kernels + 1024 - 1) / 1024;
+    size_t block_x = 512;
+    size_t block_y = (blocks + 512 - 1) / 512;
     dim3 threads(1024, 1);
-    dim3 grid(blockX, blockY);
+    dim3 grid(block_x, block_y);
     // To avoid involving atomic operations, we will launch one kernel per
     // bottom dimension, and then in the kernel add up the top dimensions.
-    col2im<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
-        numKernels, colData, inputHeight + 2 * paddingHeight,
-        inputWidth + 2 * paddingWidth, inputChannels, filterHeight, filterWidth,
-        strideHeight, strideWidth, paddingHeight, paddingWidth, outputHeight,
-        outputWidth, imData);
-    CHECK_SYNC("Col2ImFunctor GPU failed");
+    col2im<T><<<grid, threads>>>(
+        num_kernels, col.data<T>(), input_height + 2 * padding_height,
+        input_width + 2 * padding_width, input_channels, filter_height,
+        filter_width, stride_height, stride_width, padding_height,
+        padding_width, output_height, output_width, im.data<T>());
   }
 };
 
-template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, float>;
-template class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, double>;
-template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, float>;
-template class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, double>;
+template class Im2ColFunctor<kCFO, platform::GPUPlace, float>;
+template class Im2ColFunctor<kCFO, platform::GPUPlace, double>;
+template class Col2ImFunctor<kCFO, platform::GPUPlace, float>;
+template class Col2ImFunctor<kCFO, platform::GPUPlace, double>;
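
The comment in Col2ImFunctor<kCFO, ...> explains why this path needs no atomics: each thread owns one element of the (zero-padded) image and gathers every col entry that was copied from it. A self-contained CPU sketch of that per-thread gather, mirroring the kernel's index math (illustrative only; GatherOneImageElement is a hypothetical name, not part of this commit):

#include <algorithm>
#include <cassert>

// Sum all kCFO col entries that map back to padded image coordinate (c, h, w).
float GatherOneImageElement(const float* data_col, int c, int h, int w,
                            int filter_height, int filter_width,
                            int stride_height, int stride_width,
                            int output_height, int output_width) {
  float val = 0;
  int w_col_start =
      (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1;
  int w_col_end = std::min(w / stride_width + 1, output_width);
  int h_col_start =
      (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1;
  int h_col_end = std::min(h / stride_height + 1, output_height);
  for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
    for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
      int c_col = c * filter_height * filter_width +
                  (h - h_col * stride_height) * filter_width +
                  (w - w_col * stride_width);
      val += data_col[(c_col * output_height + h_col) * output_width + w_col];
    }
  }
  return val;
}

int main() {
  // 1 channel, 2x2 image, 1x1 filter, stride 1, no padding: col is a plain
  // copy of the image, so gathering (h = 1, w = 0) must return im[1][0] == 3.
  float data_col[4] = {1.f, 2.f, 3.f, 4.f};
  assert(GatherOneImageElement(data_col, /*c=*/0, /*h=*/1, /*w=*/0,
                               /*filter_height=*/1, /*filter_width=*/1,
                               /*stride_height=*/1, /*stride_width=*/1,
                               /*output_height=*/2,
                               /*output_width=*/2) == 3.f);
  return 0;
}
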

 template <class T>
-__global__ void im2colOCF(const T* imData, T* colData, int inputChannels,
-                          int inputHeight, int inputWidth, int filterHeight,
-                          int filterWidth, int strideHeight, int strideWidth,
-                          int paddingHeight, int paddingWidth, int outputHeight,
-                          int outputWidth) {
-  int swId = blockIdx.x;
-  int shId = blockIdx.y;
-  for (int channelId = threadIdx.z; channelId < inputChannels;
-       channelId += blockDim.z) {
-    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
-      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
-        int widthOffset = idx + swId * strideWidth - paddingWidth;
-        int heightOffset = idy + shId * strideHeight - paddingHeight;
-        int imOffset = widthOffset + heightOffset * inputWidth +
-                       channelId * inputHeight * inputWidth;
-        int colOffset = idx + idy * filterWidth +
-                        channelId * filterHeight * filterWidth +
-                        (shId * outputWidth + swId) *
-                            (inputChannels * filterHeight * filterWidth);
-        if (heightOffset >= inputHeight || heightOffset < 0 ||
-            widthOffset >= inputWidth || widthOffset < 0) {
-          colData[colOffset] = T(0);
+__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
+                          int input_height, int input_width, int filter_height,
+                          int filter_width, int stride_height, int stride_width,
+                          int padding_height, int padding_width,
+                          int output_height, int output_width) {
+  int swid = blockIdx.x;
+  int shid = blockIdx.y;
+  for (int channelid = threadIdx.z; channelid < input_channels;
+       channelid += blockDim.z) {
+    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
+      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
+        int width_offset = idx + swid * stride_width - padding_width;
+        int height_offset = idy + shid * stride_height - padding_height;
+        int im_offset = width_offset + height_offset * input_width +
+                        channelid * input_height * input_width;
+        int col_offset = idx + idy * filter_width +
+                         channelid * filter_height * filter_width +
+                         (shid * output_width + swid) *
+                             (input_channels * filter_height * filter_width);
+        if (height_offset >= input_height || height_offset < 0 ||
+            width_offset >= input_width || width_offset < 0) {
+          col_data[col_offset] = T(0);
         } else {
-          colData[colOffset] = imData[imOffset];
+          col_data[col_offset] = im_data[im_offset];
        }
      }
    }
@@ -204,76 +215,79 @@ __global__ void im2colOCF(const T* imData, T* colData, int inputChannels,
  }
}
 
 /*
- * imShape = [inputChannels, inputHeight, inputWidth]
- * colShape =
- * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
+ * im = [input_channels, input_height, input_width]
+ * col =
+ * [output_height, output_width, input_channels, filter_height, filter_width]
  */
 template <class T>
-class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, T> {
+class Im2ColFunctor<kOCF, platform::GPUPlace, T> {
  public:
-  void operator()(const T* imData, const TensorShape& imShape, T* colData,
-                  const TensorShape& colShape, int strideHeight,
-                  int strideWidth, int paddingHeight, int paddingWidth) {
-    int inputChannels = imShape[0];
-    int inputHeight = imShape[1];
-    int inputWidth = imShape[2];
-    int filterHeight = colShape[3];
-    int filterWidth = colShape[4];
-    int outputHeight = colShape[0];
-    int outputWidth = colShape[1];
-
-    int blockDimX = 0;
-    int blockDimY = 0;
-    if (filterHeight <= 4 && filterWidth <= 4) {
-      blockDimX = 4;
-      blockDimY = 4;
-    } else if (filterHeight <= 8 && filterWidth <= 8) {
-      blockDimX = 8;
-      blockDimY = 8;
-    } else if (filterHeight <= 16 && filterWidth <= 16) {
-      blockDimX = 16;
-      blockDimY = 16;
+  void operator()(const framework::Tensor& im, framework::Tensor& col,
+                  int stride_height, int stride_width, int padding_height,
+                  int padding_width) {
+    PADDLE_ENFORCE(im.dims().size() == 3);
+    PADDLE_ENFORCE(col.dims().size() == 5);
+    int input_channels = im.dims()[0];
+    int input_height = im.dims()[1];
+    int input_width = im.dims()[2];
+    int filter_height = col.dims()[3];
+    int filter_width = col.dims()[4];
+    int output_height = col.dims()[0];
+    int output_width = col.dims()[1];
+
+    int block_dim_x = 0;
+    int block_dim_y = 0;
+    if (filter_height <= 4 && filter_width <= 4) {
+      block_dim_x = 4;
+      block_dim_y = 4;
+    } else if (filter_height <= 8 && filter_width <= 8) {
+      block_dim_x = 8;
+      block_dim_y = 8;
+    } else if (filter_height <= 16 && filter_width <= 16) {
+      block_dim_x = 16;
+      block_dim_y = 16;
     } else {
-      blockDimX = 32;
-      blockDimY = 32;
+      block_dim_x = 32;
+      block_dim_y = 32;
     }
 
-    int blockDimZ = 1024 / blockDimX / blockDimY;
-    dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
-    dim3 grid(outputWidth, outputHeight);
-    im2colOCF<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
-        imData, colData, inputChannels, inputHeight, inputWidth, filterHeight,
-        filterWidth, strideHeight, strideWidth, paddingHeight, paddingWidth,
-        outputHeight, outputWidth);
-    CHECK_SYNC("Im2ColFunctor GPU failed");
+    int block_dim_z = 1024 / block_dim_x / block_dim_y;
+    dim3 threads(block_dim_x, block_dim_y,
+                 std::min(block_dim_z, input_channels));
+    dim3 grid(output_width, output_height);
+    im2colOCF<T><<<grid, threads>>>(
+        im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
+        filter_height, filter_width, stride_height, stride_width,
+        padding_height, padding_width, output_height, output_width);
   }
 };
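
The OCF functors derive their launch configuration from the filter size: the smallest square thread tile of 4, 8, 16, or 32 that covers the filter, with the remaining threads of a 1024-thread block spent on channels, and one block per output position. A hedged restatement of that heuristic with a worked example (PickBlockDims is a hypothetical helper, not part of this commit):

#include <algorithm>
#include <cassert>

// Pick the thread-block shape the way the OCF functors above do.
void PickBlockDims(int filter_height, int filter_width, int input_channels,
                   int* block_x, int* block_y, int* block_z) {
  int side = 32;
  if (filter_height <= 4 && filter_width <= 4) {
    side = 4;
  } else if (filter_height <= 8 && filter_width <= 8) {
    side = 8;
  } else if (filter_height <= 16 && filter_width <= 16) {
    side = 16;
  }
  *block_x = side;
  *block_y = side;
  // Fill the rest of a 1024-thread block with channels.
  *block_z = std::min(1024 / (side * side), input_channels);
}

int main() {
  int bx, by, bz;
  // A 3x3 filter over 64 input channels gives threads(4, 4, 64); for a
  // 28x28 output the grid would be grid(28, 28).
  PickBlockDims(3, 3, 64, &bx, &by, &bz);
  assert(bx == 4 && by == 4 && bz == 64);
  return 0;
}
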

 template <class T>
-__global__ void col2imOCF(T* imData, const T* colData, int inputChannels,
-                          int inputHeight, int inputWidth, int filterHeight,
-                          int filterWidth, int strideHeight, int strideWidth,
-                          int paddingHeight, int paddingWidth, int outputHeight,
-                          int outputWidth) {
-  int swId = blockIdx.x;
-  int shId = blockIdx.y;
-  for (int channelId = threadIdx.z; channelId < inputChannels;
-       channelId += blockDim.z) {
-    for (int idy = threadIdx.y; idy < filterHeight; idy += blockDim.y) {
-      for (int idx = threadIdx.x; idx < filterWidth; idx += blockDim.x) {
-        int widthOffset = idx + swId * strideWidth - paddingWidth;
-        int heightOffset = idy + shId * strideHeight - paddingHeight;
-        int imOffset = widthOffset + heightOffset * inputWidth +
-                       channelId * inputHeight * inputWidth;
-        int colOffset = idx + idy * filterWidth +
-                        channelId * filterHeight * filterWidth +
-                        (shId * outputWidth + swId) *
-                            (inputChannels * filterHeight * filterWidth);
-        if (heightOffset >= 0 && heightOffset < inputHeight &&
-            widthOffset >= 0 && widthOffset < inputWidth) {
-          paddle::paddleAtomicAdd(imData + imOffset, colData[colOffset]);
+__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
+                          int input_height, int input_width, int filter_height,
+                          int filter_width, int stride_height, int stride_width,
+                          int padding_height, int padding_width,
+                          int output_height, int output_width) {
+  int swid = blockIdx.x;
+  int shid = blockIdx.y;
+  for (int channelid = threadIdx.z; channelid < input_channels;
+       channelid += blockDim.z) {
+    for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
+      for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
+        int width_offset = idx + swid * stride_width - padding_width;
+        int height_offset = idy + shid * stride_height - padding_height;
+        int im_offset = width_offset + height_offset * input_width +
+                        channelid * input_height * input_width;
+        int col_offset = idx + idy * filter_width +
+                         channelid * filter_height * filter_width +
+                         (shid * output_width + swid) *
+                             (input_channels * filter_height * filter_width);
+        if (height_offset >= 0 && height_offset < input_height &&
+            width_offset >= 0 && width_offset < input_width) {
+          paddle::platform::CudaAtomicAdd(im_data + im_offset,
+                                          col_data[col_offset]);
        }
      }
    }
@@ -281,54 +295,56 @@ __global__ void col2imOCF(T* imData, const T* colData, int inputChannels,
  }
}
 
 /*
- * imShape = [inputChannels, inputHeight, inputWidth]
- * colShape =
- * [outputHeight, outputWidth, inputChannels, filterHeight, filterWidth]
+ * im = [input_channels, input_height, input_width]
+ * col =
+ * [output_height, output_width, input_channels, filter_height, filter_width]
  */
 template <class T>
-class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, T> {
+class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
  public:
-  void operator()(T* imData, const TensorShape& imShape, const T* colData,
-                  const TensorShape& colShape, int strideHeight,
-                  int strideWidth, int paddingHeight, int paddingWidth) {
-    int inputChannels = imShape[0];
-    int inputHeight = imShape[1];
-    int inputWidth = imShape[2];
-    int filterHeight = colShape[3];
-    int filterWidth = colShape[4];
-    int outputHeight = colShape[0];
-    int outputWidth = colShape[1];
-
-    int blockDimX = 0;
-    int blockDimY = 0;
-    if (filterHeight <= 4 && filterWidth <= 4) {
-      blockDimX = 4;
-      blockDimY = 4;
-    } else if (filterHeight <= 8 && filterWidth <= 8) {
-      blockDimX = 8;
-      blockDimY = 8;
-    } else if (filterHeight <= 16 && filterWidth <= 16) {
-      blockDimX = 16;
-      blockDimY = 16;
+  void operator()(framework::Tensor& im, const framework::Tensor& col,
+                  int stride_height, int stride_width, int padding_height,
+                  int padding_width) {
+    PADDLE_ENFORCE(im.dims().size() == 3);
+    PADDLE_ENFORCE(col.dims().size() == 5);
+    int input_channels = im.dims()[0];
+    int input_height = im.dims()[1];
+    int input_width = im.dims()[2];
+    int filter_height = col.dims()[3];
+    int filter_width = col.dims()[4];
+    int output_height = col.dims()[0];
+    int output_width = col.dims()[1];
+
+    int block_dim_x = 0;
+    int block_dim_y = 0;
+    if (filter_height <= 4 && filter_width <= 4) {
+      block_dim_x = 4;
+      block_dim_y = 4;
+    } else if (filter_height <= 8 && filter_width <= 8) {
+      block_dim_x = 8;
+      block_dim_y = 8;
+    } else if (filter_height <= 16 && filter_width <= 16) {
+      block_dim_x = 16;
+      block_dim_y = 16;
     } else {
-      blockDimX = 32;
-      blockDimY = 32;
+      block_dim_x = 32;
+      block_dim_y = 32;
     }
 
-    int blockDimZ = 1024 / blockDimX / blockDimY;
-    dim3 threads(blockDimX, blockDimY, std::min(blockDimZ, inputChannels));
-    dim3 grid(outputWidth, outputHeight);
-    col2imOCF<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
-        imData, colData, inputChannels, inputHeight, inputWidth, filterHeight,
-        filterWidth, strideHeight, strideWidth, paddingHeight, paddingWidth,
-        outputHeight, outputWidth);
-    CHECK_SYNC("Col2ImFunctor GPU failed");
+    int block_dim_z = 1024 / block_dim_x / block_dim_y;
+    dim3 threads(block_dim_x, block_dim_y,
+                 std::min(block_dim_z, input_channels));
+    dim3 grid(output_width, output_height);
+    col2imOCF<T><<<grid, threads, 0>>>(
+        im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
+        filter_height, filter_width, stride_height, stride_width,
+        padding_height, padding_width, output_height, output_width);
   }
 };
 
-template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, float>;
-template class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, double>;
-template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, float>;
-template class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, double>;
+template class Im2ColFunctor<kOCF, platform::GPUPlace, float>;
+template class Im2ColFunctor<kOCF, platform::GPUPlace, double>;
+template class Col2ImFunctor<kOCF, platform::GPUPlace, float>;
+template class Col2ImFunctor<kOCF, platform::GPUPlace, double>;
 
 }  // namespace paddle
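
Unlike the kCFO col2im path, col2imOCF scatters col entries back to the image, so overlapping filter windows can write the same image element from different threads; that is why it accumulates through paddle::platform::CudaAtomicAdd rather than a plain store. A minimal standalone CUDA sketch of the same idea, using the built-in atomicAdd as a stand-in (illustrative only, not part of this commit):

#include <cstdio>
#include <cuda_runtime.h>

// Several col entries may target the same image element, so add atomically.
__global__ void scatter_add(float* im, const float* col, const int* im_index,
                            int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAdd(&im[im_index[i]], col[i]);
}

int main() {
  const int n = 4;
  float host_col[n] = {1.f, 2.f, 3.f, 4.f};
  int host_idx[n] = {0, 1, 1, 1};  // three col entries hit image element 1
  float *im, *col;
  int* idx;
  cudaMalloc(&im, 2 * sizeof(float));
  cudaMalloc(&col, n * sizeof(float));
  cudaMalloc(&idx, n * sizeof(int));
  cudaMemset(im, 0, 2 * sizeof(float));
  cudaMemcpy(col, host_col, sizeof(host_col), cudaMemcpyHostToDevice);
  cudaMemcpy(idx, host_idx, sizeof(host_idx), cudaMemcpyHostToDevice);
  scatter_add<<<1, n>>>(im, col, idx, n);
  float host_im[2];
  cudaMemcpy(host_im, im, sizeof(host_im), cudaMemcpyDeviceToHost);
  printf("im = [%g, %g]\n", host_im[0], host_im[1]);  // expect [1, 9]
  cudaFree(im);
  cudaFree(col);
  cudaFree(idx);
  return 0;
}
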