提交 3c0aa0cc 编写于 作者: H hedaoyuan

Add GPU GemmConvFunction implementation

上级 3ce974b9
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "Function.h"
namespace paddle {
......
......@@ -19,8 +19,7 @@ limitations under the License. */
namespace paddle {
typedef Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> Compare2CpuFunction;
template <DeviceType DType1, DeviceType DType2>
class ConvolutionTest {
public:
ConvolutionTest(const std::string& conv1,
......@@ -50,13 +49,14 @@ public:
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2CpuFunction test(conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
Compare2Function<DType1, DType2> test(
conv1,
conv2,
FuncConfig()
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
batchSize, inputChannels, inputSize, inputSize};
......@@ -79,7 +79,13 @@ public:
};
TEST(Convolution, GEMM) {
ConvolutionTest test("NaiveConv-CPU", "GemmConv-CPU");
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test("NaiveConv-CPU",
"GemmConv-CPU");
}
TEST(Convolution, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test("GemmConv-CPU",
"GemmConv-GPU");
}
} // namespace paddle
......@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "paddle/math/MathFunctions.h"
#include "GemmConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle {
......@@ -24,7 +24,7 @@ namespace paddle {
* output_height, output_width]
*/
template <class T>
class Im2ColFunctor {
class Im2ColFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
......@@ -112,7 +112,8 @@ public:
resizeBuffer(size);
real* colData = reinterpret_cast<real*>(memory_->getBuf());
Im2ColFunctor<real> im2col;
Im2ColFunctor<Device, real> im2col;
GemmFunctor<Device, real> gemm;
size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
......@@ -136,19 +137,17 @@ public:
int M = outputChannels;
int N = outputHeight * outputWidth;
int K = inputChannels * filterHeight * filterWidth;
gemm<real>(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
colData,
N,
0.0f,
outputData + g * outputOffset,
N);
gemm(M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
colData,
N,
0.0f,
outputData + g * outputOffset,
N);
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
......@@ -166,5 +165,6 @@ private:
};
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "ConvOp.h"
namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <DeviceType Device, class T>
class Im2ColFunctor {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData);
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "ConvOp.h"
#include "GemmConvOp.h"
namespace paddle {
template<class T>
__global__
void im2col(const T* data_im, int numOuts, int height, int width,
int blockH, int blockW,
int strideH, int strideW,
int paddingH, int paddingW,
int height_col, int width_col,
T* data_col) {
int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < numOuts) {
int w_out = index % width_col;
index /= width_col;
int h_out = index % height_col;
int channel_in = index / height_col;
int channel_out = channel_in * blockH * blockW;
int h_in = h_out * strideH;
int w_in = w_out * strideW;
data_col += (channel_out * height_col + h_out) * width_col + w_out;
for (int i = 0; i < blockH; ++i) {
for (int j = 0; j < blockW; ++j) {
int rIdx = int(h_in+i);
int cIdx = int(w_in+j);
if ((rIdx-(int)paddingH) >= (int)height ||
(rIdx-(int)paddingH) < 0 ||
(cIdx-(int)paddingW) >= (int)width ||
(cIdx-(int)paddingW) < 0) {
*data_col = 0;
} else {
rIdx = rIdx + channel_in*height - paddingH;
cIdx = cIdx - paddingW;
*data_col = data_im[rIdx* width + cIdx];
}
data_col += height_col * width_col;
}
}
}
}
template <class T>
class Im2ColFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const T* imData,
int inputChannels,
int inputHeight,
int inputWidth,
int filterHeight,
int filterWidth,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int outputHeight,
int outputWidth,
T* colData) {
int numKernels = inputChannels * outputHeight * outputWidth;
int blocks = (numKernels + 1024 -1) / 1024;
int blockX = 512;
int blockY = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(blockX, blockY);
im2col<T><<< grid, threads, 0, STREAM_DEFAULT >>>
(imData, numKernels, inputHeight, inputWidth, filterHeight, filterWidth,
strideHeight, strideWidth, paddingHeight, paddingWidth,
outputHeight, outputWidth, colData);
CHECK_SYNC("Im2ColFunctor GPU failed");
}
};
template class Im2ColFunctor<DEVICE_TYPE_GPU, float>;
template class Im2ColFunctor<DEVICE_TYPE_GPU, double>;
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/math/MathFunctions.h"
namespace paddle {
// TODO(hedaoyuan): Since the hl_matrix_mul interface does not conform to the
// cblas_dgemm interface's parameter format, it is necessary to introduce
// GemmFunctor as a new interface. Later, when considering the implementation
// of MatMulFunction, we need to consider the reconstruction of hl_matrix_mul
// interface.
template <DeviceType Device, class T>
class GemmFunctor {
public:
void operator()(const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc);
};
template <class T>
class GemmFunctor<DEVICE_TYPE_CPU, T> {
public:
void operator()(const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
gemm<T>(CblasNoTrans,
CblasNoTrans,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
}
};
template <class T>
class GemmFunctor<DEVICE_TYPE_GPU, T> {
public:
void operator()(const int M,
const int N,
const int K,
const T alpha,
const T* A,
const int lda,
const T* B,
const int ldb,
const T beta,
T* C,
const int ldc) {
hl_matrix_mul((T*)A,
HPPL_OP_N,
(T*)B,
HPPL_OP_N,
C,
M,
N,
K,
alpha,
beta,
lda,
ldb,
ldc);
}
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册