From efae51ce240e83daff7d2042e14f7710286e9827 Mon Sep 17 00:00:00 2001
From: xzl <zlx_hg@163.com>
Date: Fri, 7 Jul 2017 21:36:02 +0800
Subject: [PATCH] add the MobileNet depthwise convolution GPU acceleration; the CPU implementation is in progress

---
 paddle/function/DepthwiseConvOp.cpp          |  19 +-
 paddle/function/DepthwiseConvOp.h            |   7 +-
 paddle/function/DepthwiseConvOpGpu.cu        | 201 +++++++++++--------
 paddle/gserver/layers/ConvBaseLayer.cpp      |   3 +-
 paddle/gserver/layers/DepthwiseConvLayer.cpp |   2 +
 5 files changed, 130 insertions(+), 102 deletions(-)
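
[Reviewer note, not part of the patch] The CPU specialization of
DepthwiseConvFunctor in DepthwiseConvOp.cpp is left as a TODO in this patch.
As a sanity reference, a minimal, untested CPU sketch that mirrors the
indexing of the GPU ConvolutionDepthwiseForward kernel could look like the
code below. The NCHW input/output layout and the
[channels, filterHeight, filterWidth] filter layout are read off the GPU
kernel; the helper name depthwiseConvCpuReference and its exact parameter
order are hypothetical and do not match any function in this patch.

    template <class T>
    void depthwiseConvCpuReference(const T* inputData, const T* filterData,
                                   int batchSize, int channels,
                                   int inputHeight, int inputWidth,
                                   int outputHeight, int outputWidth,
                                   int filterHeight, int filterWidth,
                                   int strideH, int strideW,
                                   int paddingH, int paddingW,
                                   T* outputData) {
      for (int n = 0; n < batchSize; ++n) {
        for (int c = 0; c < channels; ++c) {
          // Depthwise: output channel c uses only input channel c and its own
          // filterHeight x filterWidth filter.
          const T* filter = filterData + c * filterHeight * filterWidth;
          for (int h = 0; h < outputHeight; ++h) {
            for (int w = 0; w < outputWidth; ++w) {
              T value = 0;
              for (int kh = 0; kh < filterHeight; ++kh) {
                for (int kw = 0; kw < filterWidth; ++kw) {
                  const int h_in = -paddingH + h * strideH + kh;
                  const int w_in = -paddingW + w * strideW + kw;
                  if (h_in >= 0 && h_in < inputHeight &&
                      w_in >= 0 && w_in < inputWidth) {
                    const int offset =
                        ((n * channels + c) * inputHeight + h_in) * inputWidth
                        + w_in;
                    value += filter[kh * filterWidth + kw] * inputData[offset];
                  }
                }
              }
              outputData[((n * channels + c) * outputHeight + h) * outputWidth
                         + w] = value;
            }
          }
        }
      }
    }

The filter-gradient path ends in a GEMM with M = colDataSize / (outputHeight *
outputWidth), N = 1 and K = outputHeight * outputWidth, i.e. it multiplies
colData (M x K) by the multiplierData column (presumably all ones); as the
trailing "//gemv" comment suggests, this reduces each row of colData into
filterGrad.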

diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp
index ad332d2931b..d4272c72f24 100644
--- a/paddle/function/DepthwiseConvOp.cpp
+++ b/paddle/function/DepthwiseConvOp.cpp
@@ -18,11 +18,6 @@ limitations under the License. */
 
 namespace paddle {
 
-/*
- * imData = [input_channels, input_height, input_width]
- * colData = [input_channels, filter_height, filter_width,
- *            output_height, output_width]
- */
 template <class T>
 class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> {
 public:
@@ -33,6 +28,8 @@ public:
                   int outputChannels,
                   int outputHeight,
                   int outputWidth,
+                  int inputHeight,
+                  int inputWidth,
                   int filterHeight,
                   int filterWidth,
                   int strideH,
@@ -40,7 +37,7 @@ public:
                   int paddingH,
                   int paddingW,
                   T* outputData) {
-    // NO_IMPLEMENTATION
+    // TODO(zhaolong): CPU implementation of depthwise convolution
   }
 };
 
@@ -118,8 +115,8 @@ public:
 
     size_t batchSize = input[0];
     // size_t inputChannels = input[1];
-    // size_t inputHeight = input[2];
-    // size_t inputWidth = input[3];
+    size_t inputHeight = input[2];
+    size_t inputWidth = input[3];
     size_t filterHeight = getFilterHeight(filter);
     size_t filterWidth = getFilterWidth(filter);
     size_t outputChannels = output[1];
@@ -139,6 +136,8 @@ public:
                   outputChannels,
                   outputHeight,
                   outputWidth,
+                  inputHeight,
+                  inputWidth,
                   filterHeight,
                   filterWidth,
                   strideH(),
@@ -233,8 +232,8 @@ public:
   }
 
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
-    CHECK_EQ(numInputs_, inputs.size());
-    CHECK_EQ(numOutputs_, outputs.size());
+    // CHECK_EQ(numInputs_, inputs.size());
+    // CHECK_EQ(numOutputs_, outputs.size());
     check(inputs, outputs);
     const TensorShape& output = inputs[0].shape();
     const TensorShape& input = inputs[1].shape();
diff --git a/paddle/function/DepthwiseConvOp.h b/paddle/function/DepthwiseConvOp.h
index 8af1db974de..44290682def 100644
--- a/paddle/function/DepthwiseConvOp.h
+++ b/paddle/function/DepthwiseConvOp.h
@@ -18,11 +18,6 @@ limitations under the License. */
 
 namespace paddle {
 
-/*
- * imData = [input_channels, input_height, input_width]
- * colData = [input_channels, filter_height, filter_width,
- *            output_height, output_width]
- */
 template <DeviceType Device, class T>
 class DepthwiseConvFunctor {
 public:
@@ -33,6 +28,8 @@ public:
                   int outputChannels,
                   int outputHeight,
                   int outputWidth,
+                  int inputHeight,
+                  int inputWidth,
                   int filterHeight,
                   int filterWidth,
                   int strideH,
diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu
index 1b2d5d99ed2..08fe9221ac0 100644
--- a/paddle/function/DepthwiseConvOpGpu.cu
+++ b/paddle/function/DepthwiseConvOpGpu.cu
@@ -14,73 +14,95 @@ limitations under the License. */
 
 #include "ConvOp.h"
 #include "DepthwiseConvOp.h"
+#include "GemmFunctor.h"
+#include "paddle/math/MemoryHandle.h"
 
 namespace paddle {
 template <class T>
-__global__ void ConvolutionDepthwiseWeightForward(const int nthreads,
-    const T* const bottom_data, const T* const weight_data,
-    const int num, const int channels, const int top_height,
-    const int top_width, const int bottom_height, const int bottom_width,
-    const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_h, const int pad_w,
-    const int dilation_h, const int dilation_w, T* const top_data) {
+__global__
+void ConvolutionDepthwiseForward(const int nthreads,
+    const T* const inputData, const T* const filterData,
+    const int batchSize, const int outputChannels, const int outputHeight,
+    const int outputWidth, const int inputHeight, const int inputWidth,
+    const int filterHeight, const int filterWidth, const int strideH,
+    const int strideW, const int paddingH, const int paddingW,
+    T* const outputData) {
 
   int index =
     (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   
   if(index < nthreads) {
-    const int n = index / channels / top_height / top_width;
-    const int c = (index / top_height / top_width) % channels;
-    const int h = (index / top_width) % top_height;
-    const int w = index % top_width;
-    const T* weight = weight_data + c * kernel_h * kernel_w;
+    const int n = index / outputChannels / outputHeight / outputWidth;
+    const int c = (index / outputHeight / outputWidth) % outputChannels;
+    const int h = (index / outputWidth) % outputHeight;
+    const int w = index % outputWidth;
+    const T* weight = filterData + c * filterHeight * filterWidth;
     T value = 0;
-    for (int kh = 0; kh < kernel_h; ++kh) {
-      for (int kw = 0; kw < kernel_w; ++kw) {
-        const int h_in = -pad_h + h * stride_h + kh * dilation_h;
-        const int w_in = -pad_w + w * stride_w + kw * dilation_w;
-        if ((h_in >= 0) && (h_in < bottom_height)
-              && (w_in >= 0) && (w_in < bottom_width)) {
-          const int offset = ((n * channels + c) * bottom_height + h_in)
-                * bottom_width + w_in;
-          value += (*weight) * bottom_data[offset];
-        }
-        ++weight;
-      }
-    }
-    top_data[index] = value;
+    const int h_in_start = -paddingH + h * strideH;
+    const int w_in_start = -paddingW + w * strideW;
+    const int h_in_end = -paddingH + h * strideH + filterHeight - 1;
+    const int w_in_end = -paddingW + w * strideW + filterWidth - 1;
+    if ((h_in_start >= 0) && (h_in_end < inputHeight) &&
+        (w_in_start >= 0) && (w_in_end < inputWidth)) {
+      for (int kh = 0; kh < filterHeight; ++kh) {
+        for (int kw = 0; kw < filterWidth; ++kw) {
+          const int h_in = -paddingH + h * strideH + kh;
+          const int w_in = -paddingW + w * strideW + kw;
+          const int offset = ((n * outputChannels + c) * inputHeight + h_in)
+              * inputWidth + w_in;
+          value += (*weight) * inputData[offset];
+          ++weight;
+        }
+      }
+    } else {
+      for (int kh = 0; kh < filterHeight; ++kh) {
+        for (int kw = 0; kw < filterWidth; ++kw) {
+          const int h_in = -paddingH + h * strideH + kh;
+          const int w_in = -paddingW + w * strideW + kw;
+          if ((h_in >= 0) && (h_in < inputHeight) &&
+              (w_in >= 0) && (w_in < inputWidth)) {
+            const int offset = ((n * outputChannels + c) * inputHeight + h_in)
+                * inputWidth + w_in;
+            value += (*weight) * inputData[offset];
+          }
+          ++weight;
+        }
+      }
+    }
+    outputData[index] = value;
   }
 }
 
 template <class T>
-__global__ void ConvolutionDepthwiseBottomBackward(const int nthreads,
+__global__
+void ConvolutionDepthwiseInputBackward(const int nthreads,
     const T* const top_diff, const T* const weight_data,
-    const int num, const int channels, const int top_height,
-    const int top_width, const int bottom_height, const int bottom_width,
-    const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_h, const int pad_w,
-    const int dilation_h, const int dilation_w, T* const bottom_diff) {
+    const int num, const int outputChannels, const int outputHeight,
+    const int outputWidth, const int inputHeight, const int inputWidth,
+    const int filterHeight, const int filterWidth, const int strideH,
+    const int strideW, const int paddingH, const int paddingW,
+    T* const bottom_diff) {
   int index =
     (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   if(index < nthreads) {
-    const int n = index / channels / bottom_height / bottom_width;
-    const int c = (index / bottom_height / bottom_width) % channels;
-    const int h = (index / bottom_width) % bottom_height;
-    const int w = index % bottom_width;
-    const T* weight = weight_data + c * kernel_h * kernel_w;
+    const int n = index / outputChannels / inputHeight / inputWidth;
+    const int c = (index / inputHeight / inputWidth) % outputChannels;
+    const int h = (index / inputWidth) % inputHeight;
+    const int w = index % inputWidth;
+    const T* weight = weight_data + c * filterHeight * filterWidth;
     T value = 0;
-    for (int kh = 0; kh < kernel_h; ++kh) {
-      for (int kw = 0; kw < kernel_w; ++kw) {
-        const int h_out_s = h + pad_h - kh * dilation_h;
-        const int w_out_s = w + pad_w - kw * dilation_w;
-        if (((h_out_s % stride_h) == 0) && ((w_out_s % stride_w) == 0)) {
-          const int h_out = h_out_s / stride_h;
-          const int w_out = w_out_s / stride_w;
-	  //it affect the effectives
-          if ((h_out >= 0) && (h_out < top_height)
-                && (w_out >= 0) && (w_out < top_width)) {
-            const int offset = ((n * channels + c) * top_height + h_out)
-                  * top_width + w_out;
+    for (int kh = 0; kh < filterHeight; ++kh) {
+      for (int kw = 0; kw < filterWidth; ++kw) {
+        const int h_out_s = h + paddingH - kh;
+        const int w_out_s = w + paddingW - kw;
+        if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) {
+          const int h_out = h_out_s / strideH;
+          const int w_out = w_out_s / strideW;
+          // TODO(zhaolong): this bounds check hurts efficiency and needs to be optimized
+          if ((h_out >= 0) && (h_out < outputHeight)
+                && (w_out >= 0) && (w_out < outputWidth)) {
+            const int offset = ((n * outputChannels + c) * outputHeight + h_out)
+                  * outputWidth + w_out;
             value += (*weight) * top_diff[offset];
           }
         }
@@ -92,32 +114,33 @@ __global__ void ConvolutionDepthwiseBottomBackward(const int nthreads,
 }
 
 template <class T>
-__global__ void ConvolutionDepthwiseWeightBackward(const int num_i, const int nthreads,
-    const T* const top_diff, const T* const bottom_data,
-    const int num, const int channels, const int top_height,
-    const int top_width, const int bottom_height, const int bottom_width,
-    const int kernel_h, const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_h, const int pad_w,
-    const int dilation_h, const int dilation_w, T* const buffer_data) {
+__global__
+void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads,
+    const T* const top_diff, const T* const inputData,
+    const int num, const int outputChannels, const int outputHeight,
+    const int outputWidth, const int inputHeight, const int inputWidth,
+    const int filterHeight, const int filterWidth, const int strideH,
+    const int strideW, const int paddingH, const int paddingW,
+    T* const buffer_data) {
   int index =
     (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
   if (index < nthreads) {
-    const int h = (index / top_width) % top_height;
-    const int w = index % top_width;
-    const int kh = (index / kernel_w / top_height / top_width)
-          % kernel_h;
-    const int kw = (index / top_height / top_width) % kernel_w;
-    const int h_in = -pad_h + h * stride_h + kh * dilation_h;
-    const int w_in = -pad_w + w * stride_w + kw * dilation_w;
-    if ((h_in >= 0) && (h_in < bottom_height)
-          && (w_in >= 0) && (w_in < bottom_width)) {
-      const int c = index / kernel_h / kernel_w / top_height / top_width;
+    const int h = (index / outputWidth) % outputHeight;
+    const int w = index % outputWidth;
+    const int kh = (index / filterWidth / outputHeight / outputWidth)
+          % filterHeight;
+    const int kw = (index / outputHeight / outputWidth) % filterWidth;
+    const int h_in = -paddingH + h * strideH + kh;
+    const int w_in = -paddingW + w * strideW + kw;
+    if ((h_in >= 0) && (h_in < inputHeight)
+          && (w_in >= 0) && (w_in < inputWidth)) {
+      const int c = index / filterHeight / filterWidth / outputHeight / outputWidth;
       const int n = num_i;
-      const int top_offset = ((n * channels + c) * top_height + h)
-            * top_width + w;
-      const int bottom_offset = ((n * channels + c) * bottom_height + h_in)
-            * bottom_width + w_in;
-      buffer_data[index] = top_diff[top_offset] * bottom_data[bottom_offset];
+      const int top_offset = ((n * outputChannels + c) * outputHeight + h)
+            * outputWidth + w;
+      const int bottom_offset = ((n * outputChannels + c) * inputHeight + h_in)
+            * inputWidth + w_in;
+      buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset];
     } else {
       buffer_data[index] = 0;
     }
@@ -134,6 +157,8 @@ public:
             int outputChannels,
             int outputHeight,
             int outputWidth,
+            int inputHeight,
+            int inputWidth,
             int filterHeight,
             int filterWidth,
             int strideH,
@@ -148,7 +173,7 @@ public:
     dim3 threads(1024, 1);
     dim3 grid(blockX, blockY);
     
-    ConvolutionDepthwiseWeightForward<T>
+    ConvolutionDepthwiseForward<T>
         <<< grid, threads, 0, STREAM_DEFAULT >>>(
             outputSize, 
             inputData, 
@@ -157,6 +182,8 @@ public:
             outputChannels,
             outputHeight,
             outputWidth,
+            inputHeight,
+            inputWidth,
             filterHeight,
             filterWidth,
             strideH,
@@ -193,7 +220,7 @@ public:
     dim3 threads(1024, 1);
     dim3 grid(blockX, blockY);
 
-    ConvolutionDepthwiseBottomBackward<T>
+    ConvolutionDepthwiseInputBackward<T>
           // NOLINT_NEXT_LINE(whitespace/operators)
         <<< grid, threads, 0, STREAM_DEFAULT >>>(
             inputSize,
@@ -244,10 +271,10 @@ public:
         dim3 threads(1024, 1);
         dim3 grid(blockX, blockY);
 
-	    ConvolutionDepthwiseWeightBackward<T>
+        ConvolutionDepthwiseFilterBackward<T>
             <<< grid, threads, 0, STREAM_DEFAULT >>>(
-                i,
-                size,
+                num_i,
+                colDataSize,
                 outputGrad,
                 inputData,
                 batchSize,
@@ -264,8 +291,8 @@ public:
                 paddingW,
                 colData
             );
-        GemmFunctor<Device, real> gemm;
-        int M = size / outputHeight / outputWidth;
+        GemmFunctor<DEVICE_TYPE_GPU, real> gemm;
+        int M = colDataSize / outputHeight / outputWidth;
         int N = 1;
         int K = outputHeight * outputWidth;
         gemm(CblasNoTrans,
@@ -273,23 +300,25 @@ public:
             M,
             N,
             K,
-            1.0f,
+            (T)1.0,
             colData,
             K,
             multiplierData,
             N,
-            1.0f,
+            (T)1.0,
             filterGrad,
             N);
         //gemv
     }
 };
 
-template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, float>;
-template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, double>;
-template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, float>;
-template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, double>;
-template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, float>;
-template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, double>;
+#ifdef PADDLE_TYPE_DOUBLE
+using real = double;
+#else
+using real = float;
+#endif
+template class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, real>;
+template class DepthwiseConvFunctor<DEVICE_TYPE_GPU, real>;
+template class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, real>;
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp
index e161d89c38a..765c627c308 100644
--- a/paddle/gserver/layers/ConvBaseLayer.cpp
+++ b/paddle/gserver/layers/ConvBaseLayer.cpp
@@ -21,7 +21,8 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
                          const ParameterMap& parameterMap) {
   /* Initialize the basic parent class */
   Layer::init(layerMap, parameterMap);
-  isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv")
+  isDeconv_ = (config_.type() == "exconv" || config_.type() == "cudnn_conv" ||
+               config_.type() == "depthwise_conv")
                   ? false
                   : true;
 
diff --git a/paddle/gserver/layers/DepthwiseConvLayer.cpp b/paddle/gserver/layers/DepthwiseConvLayer.cpp
index 9df8a9df7cc..f07100d9497 100644
--- a/paddle/gserver/layers/DepthwiseConvLayer.cpp
+++ b/paddle/gserver/layers/DepthwiseConvLayer.cpp
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "DepthwiseConvLayer.h"
 #include "paddle/utils/Logging.h"
 #include "paddle/utils/Stat.h"
+#include <iostream>
 
 namespace paddle {
 
@@ -79,6 +80,7 @@ void DepthwiseConvLayer::forward(PassType passType) {
   Layer::forward(passType);
 
   size_t batchSize = inputLayers_[0]->getOutputValue()->getHeight();
+  // std::cout << "outputSize" << getOutputSize() <<std::endl;
   resetOutput(batchSize, getOutputSize());
 
   // Calculate the shape of the input, output, and filter.
-- 
GitLab