diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 173ca228096d962f5bcffa18d66d2926295d0a7c..017d4e26f2b7b583737a8698ac87e63e69535ac2 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once + #include "Function.h" namespace paddle { diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index eb0084804814cd188b90f8e0933cc633657a9d19..896267141337f1476c8b48a13b99cc928485f11c 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -19,8 +19,7 @@ limitations under the License. */ namespace paddle { -typedef Compare2Function Compare2CpuFunction; - +template class ConvolutionTest { public: ConvolutionTest(const std::string& conv1, @@ -50,13 +49,14 @@ public: std::vector paddings = {padding, padding}; std::vector strides = {stride, stride}; - Compare2CpuFunction test(conv1, - conv2, - FuncConfig() - .set("paddings", paddings) - .set("strides", strides) - .set("groups", (size_t)1) - .set("algo", algo)); + Compare2Function test( + conv1, + conv2, + FuncConfig() + .set("paddings", paddings) + .set("strides", strides) + .set("groups", (size_t)1) + .set("algo", algo)); TensorShape shape0{ batchSize, inputChannels, inputSize, inputSize}; @@ -79,7 +79,13 @@ public: }; TEST(Convolution, GEMM) { - ConvolutionTest test("NaiveConv-CPU", "GemmConv-CPU"); + ConvolutionTest test("NaiveConv-CPU", + "GemmConv-CPU"); +} + +TEST(Convolution, GEMM2) { + ConvolutionTest test("GemmConv-CPU", + "GemmConv-GPU"); } } // namespace paddle diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index b8e44cc60bce40182c5ada6c8d975d9f297c3025..6857fe7482497b4a0c1f77e5ef1b0666f8074f5a 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "ConvOp.h" -#include "paddle/math/MathFunctions.h" +#include "GemmConvOp.h" +#include "GemmFunctor.h" #include "paddle/math/MemoryHandle.h" namespace paddle { @@ -24,7 +24,7 @@ namespace paddle { * output_height, output_width] */ template -class Im2ColFunctor { +class Im2ColFunctor { public: void operator()(const T* imData, int inputChannels, @@ -112,7 +112,8 @@ public: resizeBuffer(size); real* colData = reinterpret_cast(memory_->getBuf()); - Im2ColFunctor im2col; + Im2ColFunctor im2col; + GemmFunctor gemm; size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth; size_t outputOffset = (outputChannels / groups_) * outputHeight * outputWidth; @@ -136,19 +137,17 @@ public: int M = outputChannels; int N = outputHeight * outputWidth; int K = inputChannels * filterHeight * filterWidth; - gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1.0f, - filterData + g * filterOffset, - K, - colData, - N, - 0.0f, - outputData + g * outputOffset, - N); + gemm(M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + colData, + N, + 0.0f, + outputData + g * outputOffset, + N); inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; } @@ -166,5 +165,6 @@ private: }; REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); } // namespace paddle diff --git a/paddle/function/GemmConvOp.h b/paddle/function/GemmConvOp.h new file mode 100644 index 0000000000000000000000000000000000000000..652a64afba4a5a26c4edc8fe00237f1bceb781c7 --- /dev/null +++ b/paddle/function/GemmConvOp.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "ConvOp.h" + +namespace paddle { + +/* + * imData = [input_channels, input_height, input_width] + * colData = [input_channels, filter_height, filter_width, + * output_height, output_width] + */ +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* colData); +}; + +} // namespace paddle diff --git a/paddle/function/GemmConvOpGpu.cu b/paddle/function/GemmConvOpGpu.cu new file mode 100644 index 0000000000000000000000000000000000000000..06b9904261cd7df8a00aa655abb1985345f96cd6 --- /dev/null +++ b/paddle/function/GemmConvOpGpu.cu @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "ConvOp.h" +#include "GemmConvOp.h" + +namespace paddle { + +template +__global__ +void im2col(const T* data_im, int numOuts, int height, int width, + int blockH, int blockW, + int strideH, int strideW, + int paddingH, int paddingW, + int height_col, int width_col, + T* data_col) { + int index = + (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; + if (index < numOuts) { + int w_out = index % width_col; + index /= width_col; + int h_out = index % height_col; + int channel_in = index / height_col; + int channel_out = channel_in * blockH * blockW; + int h_in = h_out * strideH; + int w_in = w_out * strideW; + + data_col += (channel_out * height_col + h_out) * width_col + w_out; + for (int i = 0; i < blockH; ++i) { + for (int j = 0; j < blockW; ++j) { + int rIdx = int(h_in+i); + int cIdx = int(w_in+j); + if ((rIdx-(int)paddingH) >= (int)height || + (rIdx-(int)paddingH) < 0 || + (cIdx-(int)paddingW) >= (int)width || + (cIdx-(int)paddingW) < 0) { + *data_col = 0; + } else { + rIdx = rIdx + channel_in*height - paddingH; + cIdx = cIdx - paddingW; + *data_col = data_im[rIdx* width + cIdx]; + } + data_col += height_col * width_col; + } + } + } +} + +template +class Im2ColFunctor { +public: + void operator()(const T* imData, + int inputChannels, + int inputHeight, + int inputWidth, + int filterHeight, + int filterWidth, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int outputHeight, + int outputWidth, + T* colData) { + int numKernels = inputChannels * outputHeight * outputWidth; + int blocks = (numKernels + 1024 -1) / 1024; + int blockX = 512; + int blockY = (blocks + 512 - 1) / 512; + dim3 threads(1024, 1); + dim3 grid(blockX, blockY); + im2col<<< grid, threads, 0, STREAM_DEFAULT >>> + (imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth, + strideHeight, strideWidth, paddingHeight, paddingWidth, + outputHeight, outputWidth, colData); + CHECK_SYNC("Im2ColFunctor GPU failed"); + } +}; + +template class Im2ColFunctor; +template class Im2ColFunctor; + +} // namespace paddle diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h new file mode 100644 index 0000000000000000000000000000000000000000..5fb2f8a6d9e8fde7273e2e0b0de1101fe5989eb2 --- /dev/null +++ b/paddle/function/GemmFunctor.h @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/math/MathFunctions.h" + +namespace paddle { + +// TODO(hedaoyuan): Since the hl_matrix_mul interface does not conform to the +// cblas_dgemm interface's parameter format, it is necessary to introduce +// GemmFunctor as a new interface. Later, when considering the implementation +// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul +// interface. +template +class GemmFunctor { +public: + void operator()(const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc); +}; + +template +class GemmFunctor { +public: + void operator()(const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); + } +}; + +template +class GemmFunctor { +public: + void operator()(const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc) { + hl_matrix_mul((T*)A, + HPPL_OP_N, + (T*)B, + HPPL_OP_N, + C, + M, + N, + K, + alpha, + beta, + lda, + ldb, + ldc); + } +}; + +} // namespace paddle