提交 45c8f9b2 编写于 作者: H hedaoyuan

Add context parameter and math namespace.

上级 abfac74c
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/im2col.h"
namespace paddle { namespace paddle {
namespace math {
/* /*
* im = [input_channels, input_height, input_width] * im = [input_channels, input_height, input_width]
...@@ -26,7 +27,7 @@ class Im2ColFunctor<kCFO, platform::CPUPlace, T> { ...@@ -26,7 +27,7 @@ class Im2ColFunctor<kCFO, platform::CPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -77,7 +78,7 @@ class Col2ImFunctor<kCFO, platform::CPUPlace, T> { ...@@ -77,7 +78,7 @@ class Col2ImFunctor<kCFO, platform::CPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -130,7 +131,7 @@ class Im2ColFunctor<kOCF, platform::CPUPlace, T> { ...@@ -130,7 +131,7 @@ class Im2ColFunctor<kOCF, platform::CPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -189,7 +190,7 @@ class Col2ImFunctor<kOCF, platform::CPUPlace, T> { ...@@ -189,7 +190,7 @@ class Col2ImFunctor<kOCF, platform::CPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -241,4 +242,5 @@ template class Im2ColFunctor<kOCF, platform::CPUPlace, double>; ...@@ -241,4 +242,5 @@ template class Im2ColFunctor<kOCF, platform::CPUPlace, double>;
template class Col2ImFunctor<kOCF, platform::CPUPlace, float>; template class Col2ImFunctor<kOCF, platform::CPUPlace, float>;
template class Col2ImFunctor<kOCF, platform::CPUPlace, double>; template class Col2ImFunctor<kOCF, platform::CPUPlace, double>;
} // namespace math
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/platform/cuda_helper.h" #include "paddle/platform/cuda_helper.h"
namespace paddle { namespace paddle {
namespace math {
template <class T> template <class T>
__global__ void im2col(const T* data_im, int num_outs, int height, int width, __global__ void im2col(const T* data_im, int num_outs, int height, int width,
...@@ -63,7 +64,7 @@ class Im2ColFunctor<kCFO, platform::GPUPlace, T> { ...@@ -63,7 +64,7 @@ class Im2ColFunctor<kCFO, platform::GPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -81,6 +82,7 @@ class Im2ColFunctor<kCFO, platform::GPUPlace, T> { ...@@ -81,6 +82,7 @@ class Im2ColFunctor<kCFO, platform::GPUPlace, T> {
int block_y = (blocks + 512 - 1) / 512; int block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(block_x, block_y); dim3 grid(block_x, block_y);
// TODO(hedaoyuan): launch kernel on specified stream
im2col<T><<<grid, threads>>>( im2col<T><<<grid, threads>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height, im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height, filter_width, stride_height, stride_width, padding_height,
...@@ -145,7 +147,7 @@ class Col2ImFunctor<kCFO, platform::GPUPlace, T> { ...@@ -145,7 +147,7 @@ class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -168,6 +170,7 @@ class Col2ImFunctor<kCFO, platform::GPUPlace, T> { ...@@ -168,6 +170,7 @@ class Col2ImFunctor<kCFO, platform::GPUPlace, T> {
// To avoid involving atomic operations, we will launch one kernel per // To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions. // bottom dimension, and then in the kernel add up the top dimensions.
// TODO(hedaoyuan): launch kernel on specified stream
col2im<T><<<grid, threads>>>( col2im<T><<<grid, threads>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height, num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height, input_width + 2 * padding_width, input_channels, filter_height,
...@@ -224,7 +227,7 @@ class Im2ColFunctor<kOCF, platform::GPUPlace, T> { ...@@ -224,7 +227,7 @@ class Im2ColFunctor<kOCF, platform::GPUPlace, T> {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -255,6 +258,7 @@ class Im2ColFunctor<kOCF, platform::GPUPlace, T> { ...@@ -255,6 +258,7 @@ class Im2ColFunctor<kOCF, platform::GPUPlace, T> {
dim3 threads(block_dim_x, block_dim_y, dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels)); std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height); dim3 grid(output_width, output_height);
// TODO(hedaoyuan): launch kernel on specified stream
im2colOCF<T><<<grid, threads>>>( im2colOCF<T><<<grid, threads>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
...@@ -304,7 +308,7 @@ class Col2ImFunctor<kOCF, platform::GPUPlace, T> { ...@@ -304,7 +308,7 @@ class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width) { int padding_width, platform::DeviceContext* context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -335,7 +339,8 @@ class Col2ImFunctor<kOCF, platform::GPUPlace, T> { ...@@ -335,7 +339,8 @@ class Col2ImFunctor<kOCF, platform::GPUPlace, T> {
dim3 threads(block_dim_x, block_dim_y, dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels)); std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height); dim3 grid(output_width, output_height);
col2imOCF<T><<<grid, threads, 0>>>( // TODO(hedaoyuan): launch kernel on specified stream
col2imOCF<T><<<grid, threads>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width);
...@@ -347,4 +352,5 @@ template class Im2ColFunctor<kOCF, platform::GPUPlace, double>; ...@@ -347,4 +352,5 @@ template class Im2ColFunctor<kOCF, platform::GPUPlace, double>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, float>; template class Col2ImFunctor<kOCF, platform::GPUPlace, float>;
template class Col2ImFunctor<kOCF, platform::GPUPlace, double>; template class Col2ImFunctor<kOCF, platform::GPUPlace, double>;
} // namespace math
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
namespace paddle { namespace paddle {
namespace math {
/* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */ /* The storage format of the coldata in the Im2ColFunctor and Col2ImFunctor. */
enum ColFormat { kCFO = 0, kOCF = 1 }; enum ColFormat { kCFO = 0, kOCF = 1 };
...@@ -72,7 +73,7 @@ class Im2ColFunctor { ...@@ -72,7 +73,7 @@ class Im2ColFunctor {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width); int padding_width, platform::DeviceContext* context);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
...@@ -80,7 +81,8 @@ class Col2ImFunctor { ...@@ -80,7 +81,8 @@ class Col2ImFunctor {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width); int padding_width, platform::DeviceContext* context);
}; };
} // namespace math
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册