提交 62da438e 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #4462 from chengduoZH/Add_vol2col_functor

Add vol2col functor
...@@ -3,11 +3,14 @@ if(WITH_GPU) ...@@ -3,11 +3,14 @@ if(WITH_GPU)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
else() else()
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator) cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
endif() endif()
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/vol2col.h"
namespace paddle {
namespace operators {
namespace math {
/*
* vol = [input_channels, input_depth, input_height, input_width]
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
*/
template <class T>
class Vol2ColFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
const T* vol_data = vol.data<T>();
T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset;
for (int h = 0; h < output_height; ++h) {
int h_pad = h * stride_height - padding_height + h_offset;
for (int w = 0; w < output_width; ++w) {
int w_pad = w * stride_width - padding_width + w_offset;
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
if (h_pad < 0 || h_pad >= input_height || w_pad < 0 ||
w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) {
col_data[col_idx] = static_cast<T>(0);
} else {
int vol_idx =
((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
col_data[col_idx] = vol_data[vol_idx];
}
}
}
}
}
}
};
/*
* vol = [input_channels,input_depth, input_height, input_width]
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
*/
template <class T>
class Col2VolFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
T* vol_data = vol.data<T>();
const T* col_data = col.data<T>();
for (int c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int d_offset = (c / filter_width / filter_height) % filter_depth;
int cIm = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * stride_depth - padding_depth + d_offset;
for (int h = 0; h < output_height; ++h) {
int h_pad = h * stride_height - padding_height + h_offset;
for (int w = 0; w < output_width; ++w) {
int w_pad = w * stride_width - padding_width + w_offset;
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
int vol_idx =
((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
int col_idx =
((c * output_depth + d) * output_height + h) * output_width +
w;
vol_data[vol_idx] += col_data[col_idx];
}
}
}
}
}
}
};
template class Vol2ColFunctor<platform::CPUPlace, float>;
template class Vol2ColFunctor<platform::CPUPlace, double>;
template class Col2VolFunctor<platform::CPUPlace, float>;
template class Col2VolFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/vol2col.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <class T>
__global__ void vol2col(int num_kernels, const T* data_vol, int depth,
int height, int width, int filter_depth,
int filter_height, int filter_width, int stride_depth,
int stride_height, int stride_width, int padding_depth,
int padding_height, int padding_width, int output_detph,
int output_height, int output_width, T* data_col) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % output_width;
int h_out = (index / output_width) % output_height;
int d_out = (index / output_width / output_height) % output_detph;
int channel_in = index / output_width / output_height / output_detph;
int channel_out = channel_in * filter_depth * filter_height * filter_width;
int w_in = w_out * stride_width - padding_width;
int h_in = h_out * stride_height - padding_height;
int d_in = d_out * stride_depth - padding_depth;
data_col += ((channel_out * output_detph + d_out) * output_height + h_out) *
output_width +
w_out;
data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filter_depth; ++k) {
for (int i = 0; i < filter_height; ++i) {
for (int j = 0; j < filter_width; ++j) {
int d = d_in + k;
int h = h_in + i;
int w = w_in + j;
*data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width)
? data_vol[(k * height + i) * width + j]
: 0;
data_col += output_detph * output_height * output_width;
}
}
}
}
}
/*
* im = [input_channels,intpu_depth, input_height, input_width]
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
*/
template <class T>
class Vol2ColFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int num_outputs =
input_channels * output_depth * output_height * output_width;
const int threads = 1024;
const int blocks = (num_outputs + 1024 - 1) / 1024;
vol2col<T><<<blocks, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
filter_depth, filter_height, filter_width, stride_depth, stride_height,
stride_width, padding_depth, padding_height, padding_width,
output_depth, output_height, output_width, col.data<T>());
}
};
template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
int height, int width, int filter_depth,
int filter_height, int filter_width, int stride_depth,
int stride_height, int stride_width, int padding_depth,
int padding_height, int padding_width, int output_detph,
int output_height, int output_width, T* data_vol) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
T src_val = 0;
int w = index % width + padding_width;
int h = (index / width) % height + padding_height;
int d = (index / width / height) % depth + padding_depth;
int c = index / width / height / depth;
// compute the start and end of the output
int w_col_start =
(w < filter_width) ? 0 : (w - filter_width) / stride_width + 1;
int w_col_end = min(w / stride_width + 1, output_width);
int h_col_start =
(h < filter_height) ? 0 : (h - filter_height) / stride_height + 1;
int h_col_end = min(h / stride_height + 1, output_height);
int d_col_start =
(d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1;
int d_col_end = min(d / stride_depth + 1, output_detph);
int offset = (c * filter_depth * filter_height * filter_width +
d * filter_width * filter_height + h * filter_width + w) *
output_detph * output_height * output_width;
int coeff_d_col =
(1 - stride_depth * filter_width * filter_height * output_detph) *
output_height * output_width;
int coeff_h_col =
(1 - stride_height * filter_width * output_detph * output_height) *
output_width;
int coeff_w_col =
(1 - stride_width * output_detph * output_height * output_width);
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
src_val += data_col[offset + d_col * coeff_d_col +
h_col * coeff_h_col + w_col * coeff_w_col];
}
}
}
data_vol[index] = src_val;
}
}
/*
* im = [input_channels, input_depth, input_height, input_width]
* col =
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
*/
template <class T>
class Col2VolFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const {
PADDLE_ENFORCE(vol.dims().size() == 4);
PADDLE_ENFORCE(col.dims().size() == 7);
int input_channels = vol.dims()[0];
int input_depth = vol.dims()[1];
int input_height = vol.dims()[2];
int input_width = vol.dims()[3];
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int num_kernels = input_channels * input_depth * input_height * input_width;
const int threads = 1024;
const int blocks = (num_kernels + 1024 - 1) / 1024;
col2vol<T><<<blocks, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width,
filter_depth, filter_height, filter_width, stride_depth, stride_height,
stride_width, padding_depth, padding_height, padding_width,
output_depth, output_height, output_width, vol.data<T>());
}
};
template class Vol2ColFunctor<platform::GPUPlace, float>;
template class Vol2ColFunctor<platform::GPUPlace, double>;
template class Col2VolFunctor<platform::GPUPlace, float>;
template class Col2VolFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/*
* \brief Converts the feature data of four dimensions(CDHW) into a colData of
* seven dimensions in the Vol2ColFunctor calculation,
* And in the Col2VolFunctor calculation, it is reversed.
*
* \param volData Vol data.
* \param volShape The shape of volData,
* [input_channels, input_depth, input_height, input_width].
* \param colData Column data.
* \param colShape The shape of colData.
*
* The shape of colData is:
* [input_channels, filter_depth, filter_height, filter_width, output_depth,
* output_height, output_width]
* So, it is easy to reshape into a convolution matrix for convolution
* calculation based on matrix multiplication.
* The shape of convolution matrix is [height, width], where the height is equal
* input_channels * filter_depth * filter_height * filter_width, and the width
* is equal output_depth * output_height * output_width.
*
* Reshape:
* shape of colData shape of convolution matrix
* [input_channels,
* filter_depth,
* filter_height,
* filter_width, ======> [height, width]
* output_depth,
* output_height,
* output_width]
*
* \note The caller needs to ensure that volShape.inputChannels is equal to
* colShape.inputChannels.
*/
template <typename Place, typename T>
class Vol2ColFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& vol, framework::Tensor& col,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const;
};
template <typename Place, typename T>
class Col2VolFunctor {
public:
void operator()(const platform::DeviceContext& context,
framework::Tensor& vol, const framework::Tensor& col,
int stride_depth, int stride_height, int stride_width,
int padding_depth, int padding_height,
int padding_width) const;
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/vol2col.h"
#include <gtest/gtest.h>
#include <iostream>
template <typename Place>
void testVol2col() {
paddle::framework::Tensor input;
paddle::framework::Tensor input_tmp;
paddle::framework::Tensor output;
paddle::framework::Tensor output_tmp;
auto* place = new Place();
paddle::platform::DeviceContext* context;
if (paddle::platform::is_cpu_place(*place)) {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
#ifdef PADDLE_WITH_CUDA
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
#else
PADDLE_THROW("no GPU support");
#endif // PADDLE_WITH_CUDA
}
/**
* input = [[0, 1, 2,
* 3, 4, 5]
* [6, 7, 8,
* 9, 10, 11]]
*
* output = [0, 1
* 1, 2
* 3, 4
* 4, 5
* 6, 7
* 7, 8
* 9, 10
* 10, 11]
*
* col2vol = [[0, 2, 2,
* 3, 8, 5]
* [6, 14, 8,
* 9, 20, 11]]
*
*/
int input_depth = 2;
int input_height = 2;
int input_width = 3;
int filter_size = 2;
int stride = 1;
int padding = 0;
int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1;
int output_height = (input_height - filter_size + 2 * padding) / stride + 1;
int output_width = (input_width - filter_size + 2 * padding) / stride + 1;
// Vol2Col test
float* input_ptr =
input_tmp.mutable_data<float>({1, input_depth, input_height, input_width},
paddle::platform::CPUPlace());
float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
memcpy(input_ptr, arr, 12 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom<float>(input_tmp, *place);
}
output.mutable_data<float>({1, filter_size, filter_size, filter_size,
output_depth, output_height, output_width},
*place);
paddle::operators::math::Vol2ColFunctor<Place, float> vol2col;
vol2col(*context, input, output, stride, stride, stride, padding, padding,
padding);
float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11};
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
out_cfo_ptr = output.data<float>();
} else {
output_tmp.CopyFrom<float>(output, paddle::platform::CPUPlace());
out_cfo_ptr = output_tmp.data<float>();
}
for (int i = 0; i < 16; ++i) {
EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]);
}
// Col2Vol test
float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11};
memset(input_ptr, 0, 12 * sizeof(float));
if (paddle::platform::is_cpu_place(*place)) {
input = input_tmp;
} else {
input.CopyFrom<float>(input_tmp, *place);
}
paddle::operators::math::Col2VolFunctor<Place, float> col2vol;
col2vol(*context, input, output, stride, stride, stride, padding, padding,
padding);
float* in_ptr;
if (paddle::platform::is_cpu_place(*place)) {
in_ptr = input.data<float>();
} else {
input_tmp.CopyFrom<float>(input, paddle::platform::CPUPlace());
in_ptr = input_tmp.data<float>();
}
for (int i = 0; i < 12; ++i) {
EXPECT_EQ(in_ptr[i], col_2_vol[i]);
}
}
TEST(math, vol2col) {
testVol2col<paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
testVol2col<paddle::platform::GPUPlace>();
#endif // PADDLE_WITH_CUDA
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册