提交 efae51ce 编写于 作者: X xzl

Add MobileNet (depthwise convolution) GPU acceleration; the CPU implementation is still in progress

上级 eeb17c26
...@@ -18,11 +18,6 @@ limitations under the License. */ ...@@ -18,11 +18,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <class T> template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> {
public: public:
...@@ -33,6 +28,8 @@ public: ...@@ -33,6 +28,8 @@ public:
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputHeight,
int inputWidth,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -40,7 +37,7 @@ public: ...@@ -40,7 +37,7 @@ public:
int paddingH, int paddingH,
int paddingW, int paddingW,
T* outputData) { T* outputData) {
// NO_IMPLEMENTATION // TODO(zhaolong) : cpu implementation of depthwise convolution
} }
}; };
...@@ -118,8 +115,8 @@ public: ...@@ -118,8 +115,8 @@ public:
size_t batchSize = input[0]; size_t batchSize = input[0];
// size_t inputChannels = input[1]; // size_t inputChannels = input[1];
// size_t inputHeight = input[2]; size_t inputHeight = input[2];
// size_t inputWidth = input[3]; size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter); size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter); size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1]; size_t outputChannels = output[1];
...@@ -139,6 +136,8 @@ public: ...@@ -139,6 +136,8 @@ public:
outputChannels, outputChannels,
outputHeight, outputHeight,
outputWidth, outputWidth,
inputHeight,
inputWidth,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH(), strideH(),
...@@ -233,8 +232,8 @@ public: ...@@ -233,8 +232,8 @@ public:
} }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size()); // CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size()); // CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs); check(inputs, outputs);
const TensorShape& output = inputs[0].shape(); const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[1].shape(); const TensorShape& input = inputs[1].shape();
......
...@@ -18,11 +18,6 @@ limitations under the License. */ ...@@ -18,11 +18,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
/*
* imData = [input_channels, input_height, input_width]
* colData = [input_channels, filter_height, filter_width,
* output_height, output_width]
*/
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvFunctor { class DepthwiseConvFunctor {
public: public:
...@@ -33,6 +28,8 @@ public: ...@@ -33,6 +28,8 @@ public:
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputHeight,
int intputWidth,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
......
...@@ -14,73 +14,95 @@ limitations under the License. */ ...@@ -14,73 +14,95 @@ limitations under the License. */
#include "ConvOp.h" #include "ConvOp.h"
#include "DepthwiseConvOp.h" #include "DepthwiseConvOp.h"
#include "GemmFunctor.h"
#include "paddle/math/MemoryHandle.h"
namespace paddle { namespace paddle {
template <class T> template <class T>
__global__ void ConvolutionDepthwiseWeightForward(const int nthreads, __global__
const T* const bottom_data, const T* const weight_data, void ConvolutionDepthwiseForward(const int nthreads,
const int num, const int channels, const int top_height, const T* const inputData, const T* const filterData,
const int top_width, const int bottom_height, const int bottom_width, const int batchSize, const int outputChannels, const int outputHeight,
const int kernel_h, const int kernel_w, const int stride_h, const int outputWidth, const int inputHeight, const int inputWidth,
const int stride_w, const int pad_h, const int pad_w, const int filterHeight, const int filterWidth, const int strideH,
const int dilation_h, const int dilation_w, T* const top_data) { const int strideW, const int paddingH, const int paddingW,
T* const outputData) {
int index = int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if(index < nthreads) { if(index < nthreads) {
const int n = index / channels / top_height / top_width; const int n = index / outputChannels / outputHeight / outputWidth;
const int c = (index / top_height / top_width) % channels; const int c = (index / outputHeight / outputWidth) % outputChannels;
const int h = (index / top_width) % top_height; const int h = (index / outputWidth) % outputHeight;
const int w = index % top_width; const int w = index % outputWidth;
const T* weight = weight_data + c * kernel_h * kernel_w; const T* weight = filterData + c * filterHeight * filterWidth;
T value = 0; T value = 0;
for (int kh = 0; kh < kernel_h; ++kh) { const int h_in_start = -paddingH + h * strideH;
for (int kw = 0; kw < kernel_w; ++kw) { const int w_in_start = -paddingW + w * strideW;
const int h_in = -pad_h + h * stride_h + kh * dilation_h; const int h_in_end = -paddingH + h * strideH + filterHeight - 1;
const int w_in = -pad_w + w * stride_w + kw * dilation_w; const int w_in_end = -paddingW + w * strideW + filterWidth - 1;
if ((h_in >= 0) && (h_in < bottom_height) if ((h_in_start >= 0) && (h_in_end < inputHeight)
&& (w_in >= 0) && (w_in < bottom_width)) { &&(w_in_start >= 0) && (w_in_end < inputWidth)) {
const int offset = ((n * channels + c) * bottom_height + h_in) for (int kh = 0; kh < filterHeight; ++kh) {
* bottom_width + w_in; for (int kw = 0; kw < filterWidth; ++kw) {
value += (*weight) * bottom_data[offset]; const int h_in = -paddingH + h * strideH + kh;
} const int w_in = -paddingW + w * strideW + kw;
++weight; const int offset = ((n * outputChannels + c) * inputHeight + h_in)
} * inputWidth + w_in;
} value += (*weight) * inputData[offset];
top_data[index] = value; ++weight;
}
}
}else{
for (int kh = 0; kh < filterHeight; ++kh) {
for (int kw = 0; kw < filterWidth; ++kw) {
const int h_in = -paddingH + h * strideH + kh;
const int w_in = -paddingW + w * strideW + kw;
if ((h_in >= 0) && (h_in < inputHeight)
&& (w_in >= 0) && (w_in < inputWidth)) {
const int offset = ((n * outputChannels + c) * inputHeight + h_in)
* inputWidth + w_in;
value += (*weight) * inputData[offset];
}
++weight;
}
}
}
outputData[index] = value;
} }
} }
template <class T> template <class T>
__global__ void ConvolutionDepthwiseBottomBackward(const int nthreads, __global__
void ConvolutionDepthwiseInputBackward(const int nthreads,
const T* const top_diff, const T* const weight_data, const T* const top_diff, const T* const weight_data,
const int num, const int channels, const int top_height, const int num, const int outputChannels, const int outputHeight,
const int top_width, const int bottom_height, const int bottom_width, const int outputWidth, const int inputHeight, const int inputWidth,
const int kernel_h, const int kernel_w, const int stride_h, const int filterHeight, const int filterWidth, const int strideH,
const int stride_w, const int pad_h, const int pad_w, const int strideW, const int paddingH, const int paddingW,
const int dilation_h, const int dilation_w, T* const bottom_diff) { T* const bottom_diff) {
int index = int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if(index < nthreads) { if(index < nthreads) {
const int n = index / channels / bottom_height / bottom_width; const int n = index / outputChannels / inputHeight / inputWidth;
const int c = (index / bottom_height / bottom_width) % channels; const int c = (index / inputHeight / inputWidth) % outputChannels;
const int h = (index / bottom_width) % bottom_height; const int h = (index / inputWidth) % inputHeight;
const int w = index % bottom_width; const int w = index % inputWidth;
const T* weight = weight_data + c * kernel_h * kernel_w; const T* weight = weight_data + c * filterHeight * filterWidth;
T value = 0; T value = 0;
for (int kh = 0; kh < kernel_h; ++kh) { for (int kh = 0; kh < filterHeight; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) { for (int kw = 0; kw < filterWidth; ++kw) {
const int h_out_s = h + pad_h - kh * dilation_h; const int h_out_s = h + paddingH - kh;
const int w_out_s = w + pad_w - kw * dilation_w; const int w_out_s = w + paddingW - kw;
if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) { if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) {
const int h_out = h_out_s / stride_h; const int h_out = h_out_s / strideH;
const int w_out = w_out_s / stride_w; const int w_out = w_out_s / strideW;
//it affect the effectives // TODO(zhaolong) : this 'if' hurts performance and needs to be optimized
if ((h_out >= 0) && (h_out < top_height) if ((h_out >= 0) && (h_out < outputHeight)
&& (w_out >= 0) && (w_out < top_width)) { && (w_out >= 0) && (w_out < outputWidth)) {
const int offset = ((n * channels + c) * top_height + h_out) const int offset = ((n * outputChannels + c) * outputHeight + h_out)
* top_width + w_out; * outputWidth + w_out;
value += (*weight) * top_diff[offset]; value += (*weight) * top_diff[offset];
} }
} }
...@@ -92,32 +114,33 @@ __global__ void ConvolutionDepthwiseBottomBackward(const int nthreads, ...@@ -92,32 +114,33 @@ __global__ void ConvolutionDepthwiseBottomBackward(const int nthreads,
} }
template <class T> template <class T>
__global__ void ConvolutionDepthwiseWeightBackward(const int num_i, const int nthreads, __global__
const T* const top_diff, const T* const bottom_data, void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
const int num, const int channels, const int top_height, const T* const top_diff, const T* const inputData,
const int top_width, const int bottom_height, const int bottom_width, const int num, const int outputChannels, const int outputHeight,
const int kernel_h, const int kernel_w, const int stride_h, const int outputWidth, const int inputHeight, const int inputWidth,
const int stride_w, const int pad_h, const int pad_w, const int filterHeight, const int filterWidth, const int strideH,
const int dilation_h, const int dilation_w, T* const buffer_data) { const int strideW, const int paddingH, const int paddingW,
T* const buffer_data) {
int index = int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < nthreads) { if (index < nthreads) {
const int h = (index / top_width) % top_height; const int h = (index / outputWidth) % outputHeight;
const int w = index % top_width; const int w = index % outputWidth;
const int kh = (index / kernel_w / top_height / top_width) const int kh = (index / filterWidth / outputHeight / outputWidth)
% kernel_h; % filterHeight;
const int kw = (index / top_height / top_width) % kernel_w; const int kw = (index / outputHeight / outputWidth) % filterWidth;
const int h_in = -pad_h + h * stride_h + kh * dilation_h; const int h_in = -paddingH + h * strideH + kh;
const int w_in = -pad_w + w * stride_w + kw * dilation_w; const int w_in = -paddingW + w * strideW + kw;
if ((h_in >= 0) && (h_in < bottom_height) if ((h_in >= 0) && (h_in < inputHeight)
&& (w_in >= 0) && (w_in < bottom_width)) { && (w_in >= 0) && (w_in < inputWidth)) {
const int c = index / kernel_h / kernel_w / top_height / top_width; const int c = index / filterHeight / filterWidth / outputHeight / outputWidth;
const int n = num_i; const int n = num_i;
const int top_offset = ((n * channels + c) * top_height + h) const int top_offset = ((n * outputChannels + c) * outputHeight + h)
* top_width + w; * outputWidth + w;
const int bottom_offset = ((n * channels + c) * bottom_height + h_in) const int bottom_offset = ((n * outputChannels + c) * inputHeight + h_in)
* bottom_width + w_in; * inputWidth + w_in;
buffer_data[index] = top_diff[top_offset] * bottom_data[bottom_offset]; buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
} else { } else {
buffer_data[index] = 0; buffer_data[index] = 0;
} }
...@@ -134,6 +157,8 @@ public: ...@@ -134,6 +157,8 @@ public:
int outputChannels, int outputChannels,
int outputHeight, int outputHeight,
int outputWidth, int outputWidth,
int inputHeight,
int inputWidth,
int filterHeight, int filterHeight,
int filterWidth, int filterWidth,
int strideH, int strideH,
...@@ -148,7 +173,7 @@ public: ...@@ -148,7 +173,7 @@ public:
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
ConvolutionDepthwiseWeightForward<T> ConvolutionDepthwiseForward<T>
<<< grid, threads, 0, STREAM_DEFAULT >>>( <<< grid, threads, 0, STREAM_DEFAULT >>>(
outputSize, outputSize,
inputData, inputData,
...@@ -157,6 +182,8 @@ public: ...@@ -157,6 +182,8 @@ public:
outputChannels, outputChannels,
outputHeight, outputHeight,
outputWidth, outputWidth,
inputHeight,
inputWidth,
filterHeight, filterHeight,
filterWidth, filterWidth,
strideH, strideH,
...@@ -193,7 +220,7 @@ public: ...@@ -193,7 +220,7 @@ public:
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
ConvolutionDepthwiseBottomBackward<T> ConvolutionDepthwiseInputBackward<T>
// NOLINT_NEXT_LINE(whitespace/operators) // NOLINT_NEXT_LINE(whitespace/operators)
<<< grid, threads, 0, STREAM_DEFAULT >>>( <<< grid, threads, 0, STREAM_DEFAULT >>>(
inputSize, inputSize,
...@@ -244,10 +271,10 @@ public: ...@@ -244,10 +271,10 @@ public:
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(blockX, blockY); dim3 grid(blockX, blockY);
ConvolutionDepthwiseWeightBackward<T> ConvolutionDepthwiseFilterBackward<T>
<<< grid, threads, 0, STREAM_DEFAULT >>>( <<< grid, threads, 0, STREAM_DEFAULT >>>(
i, num_i,
size, colDataSize,
outputGrad, outputGrad,
inputData, inputData,
batchSize, batchSize,
...@@ -264,8 +291,8 @@ public: ...@@ -264,8 +291,8 @@ public:
paddingW, paddingW,
colData colData
); );
GemmFunctor<Device, real> gemm; GemmFunctor<DEVICE_TYPE_GPU, real> gemm;
int M = size / outputHeight / outputWidth; int M = colDataSize / outputHeight / outputWidth;
int N = 1; int N = 1;
int K = outputHeight * outputWidth; int K = outputHeight * outputWidth;
gemm(CblasNoTrans, gemm(CblasNoTrans,
...@@ -273,23 +300,25 @@ public: ...@@ -273,23 +300,25 @@ public:
M, M,
N, N,
K, K,
1.0f, (T)1.0,
colData, colData,
K, K,
multiplierData, multiplierData,
N, N,
1.0f, (T)1.0,
filterGrad, filterGrad,
N); N);
//gemv //gemv
} }
}; };
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>; #ifdef PADDLE_TYPE_DOUBLE
template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>; using real=double;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>; #else
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>; using real=float;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>; #endif
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>; template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, real>;
template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, real>;
template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, real>;
} // namespace paddle } // namespace paddle
...@@ -21,7 +21,8 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, ...@@ -21,7 +21,8 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
/* Initialize the basic parent class */ /* Initialize the basic parent class */
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv") isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv" ||
config_.type() == "depthwise_conv")
? false ? false
: true; : true;
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#include "DepthwiseConvLayer.h" #include "DepthwiseConvLayer.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include <iostream>
namespace paddle { namespace paddle {
...@@ -79,6 +80,7 @@ void DepthwiseConvLayer::forward(PassType passType) { ...@@ -79,6 +80,7 @@ void DepthwiseConvLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight(); size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();
// std::cout << "outputSize" << getOutputSize() <<std::endl;
resetOutput(batchSize, getOutputSize()); resetOutput(batchSize, getOutputSize());
// Calculate the shape of the input, output, and filter. // Calculate the shape of the input, output, and filter.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册