Commit 0140f3f9 authored by hedaoyuan, committed by GitHub

Merge pull request #3753 from hedaoyuan/conv_op

Add im2col functor
paddle/operators/math/CMakeLists.txt

 if(WITH_GPU)
-    nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
+    nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
+               im2col.cu DEPS cblas device_context)
 else()
-    cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
+    cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context)
 endif()

 nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
+cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
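To see the new API end to end before reading the sources, here is a minimal usage sketch. The free function and its name are illustrative only; the functor instantiation and the device-context setup mirror the unit test at the end of this commit, and a CPU-only build is assumed.

#include "paddle/operators/math/im2col.h"

// Illustrative helper: fills `col` from a CHW image tensor with stride 1
// and no padding, on the CPU. `col` must already be allocated with the
// kCFO shape [channels, filter_height, filter_width, output_height,
// output_width].
void RunIm2ColCPU(const paddle::framework::Tensor& im,
                  paddle::framework::Tensor& col) {
  paddle::operators::math::Im2ColFunctor<
      paddle::operators::math::ColFormat::kCFO,
      paddle::platform::CPUPlace, float>
      im2col;
  paddle::platform::CPUDeviceContext context(paddle::platform::CPUPlace());
  im2col(im, col, /*stride_height=*/1, /*stride_width=*/1,
         /*padding_height=*/0, /*padding_width=*/0, &context);
}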
paddle/operators/math/im2col.cc

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/im2col.h"
namespace paddle {
namespace operators {
namespace math {
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
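// Element-wise, the loops below compute
//   col[c, kh, kw, h, w] = im[c, h * stride_height + kh - padding_height,
//                             w * stride_width + kw - padding_width],
// writing zero whenever the image coordinate falls inside the padding.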
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
int channels_col = input_channels * filter_height * filter_width;
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) < 0 ||
(im_row_idx - padding_height) >= input_height ||
(im_col_idx - padding_width) < 0 ||
(im_col_idx - padding_width) >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0);
} else {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx];
}
}
}
}
}
};
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
int channels_col = input_channels * filter_height * filter_width;
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_height) >= 0 &&
(im_row_idx - padding_height) < input_height &&
(im_col_idx - padding_width) >= 0 &&
(im_col_idx - padding_width) < input_width) {
im_row_idx += c_im * input_height - padding_height;
im_col_idx -= padding_width;
im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w];
}
}
}
}
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, double>;
/*
* im = [input_channels, input_height, input_width]
* col =
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
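// Element-wise, this gathers the same window as kCFO but stores it as
//   col[h, w, c, kh, kw] = im[c, h * stride_height + kh - padding_height,
//                             w * stride_width + kw - padding_width],
// so the patch for each output position is contiguous in memory.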
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;
if (im_row_offset < 0 || im_row_offset >= input_height ||
im_col_offset < 0 || im_col_offset >= input_width) {
col_data[col_offset] = T(0);
} else {
int im_offset =
(channel * input_height + im_row_offset) * input_width +
im_col_offset;
col_data[col_offset] = im_data[im_offset];
}
}
}
}
}
}
}
};
/*
* im = [input_channels, input_height, input_width]
* col =
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
for (int col_row_idx = 0; col_row_idx < output_height; ++col_row_idx) {
for (int col_col_idx = 0; col_col_idx < output_width; ++col_col_idx) {
for (int channel = 0; channel < input_channels; ++channel) {
for (int filter_row_idx = 0; filter_row_idx < filter_height;
++filter_row_idx) {
for (int filter_col_idx = 0; filter_col_idx < filter_width;
++filter_col_idx) {
int im_row_offset =
col_row_idx * stride_height + filter_row_idx - padding_height;
int im_col_offset =
col_col_idx * stride_width + filter_col_idx - padding_width;
int col_offset = (((col_row_idx * output_width + col_col_idx) *
input_channels +
channel) *
filter_height +
filter_row_idx) *
filter_width +
filter_col_idx;
if (im_row_offset >= 0 && im_row_offset < input_height &&
im_col_offset >= 0 && im_col_offset < input_width) {
int im_offset =
(channel * input_height + im_row_offset) * input_width +
im_col_offset;
im_data[im_offset] += col_data[col_offset];
}
}
}
}
}
}
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
paddle/operators/math/im2col.cu

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
__global__ void im2col(const T* data_im, int num_outs, int height, int width,
int filter_height, int filter_width, int stride_height,
int stride_width, int padding_height, int padding_width,
int output_height, int output_width, T* data_col) {
int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) {
int w_out = index % output_width;
index /= output_width;
int h_out = index % output_height;
int channel_in = index / output_height;
int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height;
int w_in = w_out * stride_width;
data_col += (channel_out * output_height + h_out) * output_width + w_out;
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int rIdx = int(h_in + i);
int cIdx = int(w_in + j);
if ((rIdx - (int)padding_height) >= (int)height ||
(rIdx - (int)padding_height) < 0 ||
(cIdx - (int)padding_width) >= (int)width ||
(cIdx - (int)padding_width) < 0) {
*data_col = 0;
} else {
rIdx = rIdx + channel_in * height - padding_height;
cIdx = cIdx - padding_width;
*data_col = data_im[rIdx * width + cIdx];
}
data_col += output_height * output_width;
}
}
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
int num_outputs = input_channels * output_height * output_width;
int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512;
int block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(block_x, block_y);
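    // Launch configuration: one thread per (channel, h_out, w_out) output
    // position (num_outputs in total); each thread writes the
    // filter_height * filter_width column entries for its position.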
im2col<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>());
}
};
template <class T>
__global__ void col2im(size_t n, const T* data_col, size_t height, size_t width,
size_t channels, size_t filter_height,
size_t filter_width, size_t stride_height,
size_t stride_width, size_t padding_height,
size_t padding_width, size_t output_height,
size_t output_width, T* data_im) {
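  // Note: `height` and `width` here are the zero-padded dimensions
  // (input + 2 * padding); the final store below converts back to
  // unpadded image coordinates.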
size_t index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < n) {
T val = 0;
int w = int(index % width);
int h = int((index / width) % height);
int c = int(index / (width * height));
if ((w - (int)padding_width) >= 0 &&
(w - (int)padding_width) < (width - 2 * padding_width) &&
(h - (int)padding_height) >= 0 &&
(h - padding_height) < (height - 2 * padding_height)) {
// compute the start and end of the output
int w_col_start = (w < (int)filter_width)
? 0
: (w - int(filter_width)) / (int)stride_width + 1;
int w_col_end =
min((int)(w / (int)stride_width + 1), (int)(output_width));
int h_col_start = (h < (int)filter_height)
? 0
: (h - (int)filter_height) / (int)stride_height + 1;
int h_col_end = min(int(h / stride_height + 1), int(output_height));
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
// the col location: [c * width * height + h_out, w_out]
int c_col = int(c * filter_height * filter_width) +
(h - h_col * (int)stride_height) * (int)filter_width +
(w - w_col * (int)stride_width);
val +=
data_col[(c_col * output_height + h_col) * output_width + w_col];
}
}
h -= padding_height;
w -= padding_width;
data_im[c * ((width - 2 * padding_width) *
(height - 2 * padding_height)) +
h * (width - 2 * padding_width) + w] += val;
}
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int output_height = col.dims()[3];
int output_width = col.dims()[4];
size_t num_kernels = input_channels * (input_height + 2 * padding_height) *
(input_width + 2 * padding_width);
size_t blocks = (num_kernels + 1024 - 1) / 1024;
size_t block_x = 512;
size_t block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(block_x, block_y);
    // To avoid atomic operations, we launch one thread per element of the
    // (padded) input image and have each thread accumulate all of the
    // column entries that reference that element.
col2im<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, im.data<T>());
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, double>;
template <class T>
__global__ void im2colOCF(const T* im_data, T* col_data, int input_channels,
int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width,
int output_height, int output_width) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels;
channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width;
int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width +
(shid * output_width + swid) *
(input_channels * filter_height * filter_width);
if (height_offset >= input_height || height_offset < 0 ||
width_offset >= input_width || width_offset < 0) {
col_data[col_offset] = T(0);
} else {
col_data[col_offset] = im_data[im_offset];
}
}
}
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
block_dim_x = 4;
block_dim_y = 4;
} else if (filter_height <= 8 && filter_width <= 8) {
block_dim_x = 8;
block_dim_y = 8;
} else if (filter_height <= 16 && filter_width <= 16) {
block_dim_x = 16;
block_dim_y = 16;
} else {
block_dim_x = 32;
block_dim_y = 32;
}
int block_dim_z = 1024 / block_dim_x / block_dim_y;
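    // 1024 is the per-block thread limit on current CUDA devices; the
    // remaining budget in z lets one block cover several input channels.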
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
im2colOCF<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
}
};
template <class T>
__global__ void col2imOCF(T* im_data, const T* col_data, int input_channels,
int input_height, int input_width, int filter_height,
int filter_width, int stride_height, int stride_width,
int padding_height, int padding_width,
int output_height, int output_width) {
int swid = blockIdx.x;
int shid = blockIdx.y;
for (int channelid = threadIdx.z; channelid < input_channels;
channelid += blockDim.z) {
for (int idy = threadIdx.y; idy < filter_height; idy += blockDim.y) {
for (int idx = threadIdx.x; idx < filter_width; idx += blockDim.x) {
int width_offset = idx + swid * stride_width - padding_width;
int height_offset = idy + shid * stride_height - padding_height;
int im_offset = width_offset + height_offset * input_width +
channelid * input_height * input_width;
int col_offset = idx + idy * filter_width +
channelid * filter_height * filter_width +
(shid * output_width + swid) *
(input_channels * filter_height * filter_width);
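        // Adjacent filter windows overlap when the stride is smaller than
        // the filter, so several col entries can map to the same image
        // pixel; the accumulation must therefore be atomic.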
if (height_offset >= 0 && height_offset < input_height &&
width_offset >= 0 && width_offset < input_width) {
paddle::platform::CudaAtomicAdd(im_data + im_offset,
col_data[col_offset]);
}
}
}
}
}
/*
* im = [input_channels, input_height, input_width]
* col =
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
int input_height = im.dims()[1];
int input_width = im.dims()[2];
int filter_height = col.dims()[3];
int filter_width = col.dims()[4];
int output_height = col.dims()[0];
int output_width = col.dims()[1];
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
block_dim_x = 4;
block_dim_y = 4;
} else if (filter_height <= 8 && filter_width <= 8) {
block_dim_x = 8;
block_dim_y = 8;
} else if (filter_height <= 16 && filter_width <= 16) {
block_dim_x = 16;
block_dim_y = 16;
} else {
block_dim_x = 32;
block_dim_y = 32;
}
int block_dim_z = 1024 / block_dim_x / block_dim_y;
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
col2imOCF<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
}
};
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
paddle/operators/math/im2col.h

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/* The storage format of the col data used by Im2ColFunctor and Col2ImFunctor. */
enum class ColFormat { kCFO = 0, kOCF = 1 };
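// kCFO lays col out in [channels, filter, output] order; kOCF in
// [output, channels, filter] order (see the shape notes below).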
/*
* \brief Im2ColFunctor converts three-dimensional (CHW) image data into
* five-dimensional column data; Col2ImFunctor performs the reverse
* transformation.
*
* \param im Image tensor of shape
* [input_channels, input_height, input_width].
* \param col Column tensor; its five-dimensional shape depends on the
* Format template argument, as described below.
*
* If the template argument Format is kCFO, the shape of the col tensor is
* [input_channels, filter_height, filter_width, output_height, output_width],
* which is easy to reshape into the convolution matrix used by
* matrix-multiplication-based convolution.
* The convolution matrix has shape [height, width], where height equals
* input_channels * filter_height * filter_width and width equals
* output_height * output_width.
*
* Reshape:
* shape of colData shape of convolution matrix
* [input_channels,
* filter_height,
* filter_width, ======> [height, width]
* output_height,
* output_width]
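*
* With this layout, a convolution with out_channels filters reduces to a
* single matrix multiplication: an [out_channels, height] filter matrix
* times the [height, width] convolution matrix yields the
* [out_channels, output_height * output_width] output.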
*
* If the template argument Format is kOCF, the shape of the col tensor is
* [output_height, output_width, input_channels, filter_height, filter_width],
* which is easy to reshape into the sequence matrix used by RNN computation.
* The sequence matrix has shape [seq_length, step_size], where seq_length
* equals output_height * output_width and step_size equals
* input_channels * filter_height * filter_width.
*
* Reshape:
* shape of colData shape of sequence matrix
* [output_height,
* output_width,
* input_channels, ======> [seq_length, step_size]
* filter_height,
* filter_width]
*
* \note The caller must ensure that the channel dimension of im equals the
* channel dimension of col.
*/
template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context);
};
template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context);
};
} // namespace math
} // namespace operators
} // namespace paddle
paddle/operators/math/im2col_test.cc

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/im2col.h"
#include <gtest/gtest.h>
#include <iostream>
template <typename Place>
void testIm2col() {
paddle::framework::Tensor input_tmp;
paddle::framework::Tensor input;
paddle::framework::Tensor output_cfo;
paddle::framework::Tensor output_ocf;
paddle::framework::Tensor output_tmp;
/**
* input = [0, 1, 2,
* 3, 4, 5]
*
* output_cfo = [0, 1
* 1, 2
* 3, 4
* 4, 5]
*
* output_ocf = [0, 1, 3, 4
* 1, 2, 4, 5]
*/
int input_height = 2;
int input_width = 3;
int filter_size = 2;
int stride = 1;
int padding = 0;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
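  // For this example: output_height = (2 - 2 + 0) / 1 + 1 = 1 and
  // output_width = (3 - 2 + 0) / 1 + 1 = 2.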
float* input_ptr = input_tmp.mutable_data<float>(
{1, input_height, input_width}, paddle::platform::CPUPlace());
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input_ptr, arr, 6 * sizeof(float));
auto* place = new Place();
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom<float>(input_tmp, *place);
}
output_cfo.mutable_data<float>(
{1, filter_size, filter_size, output_height, output_width}, *place);
output_ocf.mutable_data<float>(
{output_height, output_width, 1, filter_size, filter_size}, *place);
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, float>
im2col;
paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kOCF, Place, float>
im2col_ocf;
paddle::platform::DeviceContext* context;
if (paddle::platform::is_cpu_place(*place)) {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
}
im2col(input, output_cfo, stride, stride, padding, padding, context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output_cfo.data<float>();
} else {
output_tmp.CopyFrom<float>(output_cfo, paddle::platform::CPUPlace());
out_cfo_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_cfo_ptr[0], 0);
EXPECT_EQ(out_cfo_ptr[1], 1);
EXPECT_EQ(out_cfo_ptr[2], 1);
EXPECT_EQ(out_cfo_ptr[3], 2);
EXPECT_EQ(out_cfo_ptr[4], 3);
EXPECT_EQ(out_cfo_ptr[5], 4);
EXPECT_EQ(out_cfo_ptr[6], 4);
EXPECT_EQ(out_cfo_ptr[7], 5);
float* out_ocf_ptr;
if (paddle::platform::is_cpu_place(*place)) {
out_ocf_ptr = output_ocf.data<float>();
} else {
output_tmp.CopyFrom<float>(output_ocf, paddle::platform::CPUPlace());
out_ocf_ptr = output_tmp.data<float>();
}
EXPECT_EQ(out_ocf_ptr[0], 0);
EXPECT_EQ(out_ocf_ptr[1], 1);
EXPECT_EQ(out_ocf_ptr[2], 3);
EXPECT_EQ(out_ocf_ptr[3], 4);
EXPECT_EQ(out_ocf_ptr[4], 1);
EXPECT_EQ(out_ocf_ptr[5], 2);
EXPECT_EQ(out_ocf_ptr[6], 4);
EXPECT_EQ(out_ocf_ptr[7], 5);
}
TEST(math, im2col) {
testIm2col<paddle::platform::CPUPlace>();
#ifndef PADDLE_ONLY_CPU
testIm2col<paddle::platform::GPUPlace>();
#endif
}