From 11588b36700cc1dd444b524c4cff0d785fe7f769 Mon Sep 17 00:00:00 2001 From: xzl Date: Tue, 18 Jul 2017 22:07:26 +0800 Subject: [PATCH] support inputchannels != outputchannels of depthwiseconv --- paddle/function/DepthwiseConvOp.cpp | 13 ++- paddle/function/DepthwiseConvOp.h | 10 +- paddle/function/DepthwiseConvOpGpu.cu | 117 +++++++++++++----------- paddle/gserver/tests/test_LayerGrad.cpp | 2 +- 4 files changed, 85 insertions(+), 57 deletions(-) diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp index 0ac83f5824..d1430239bc 100644 --- a/paddle/function/DepthwiseConvOp.cpp +++ b/paddle/function/DepthwiseConvOp.cpp @@ -30,6 +30,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -53,6 +54,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -75,6 +77,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -122,6 +125,7 @@ public: size_t outputChannels = output[1]; size_t outputHeight = output[2]; size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; real* inputData = inputs[0].data(); real* filterData = inputs[1].data(); @@ -137,6 +141,7 @@ public: inputChannels, inputHeight, inputWidth, + filterMultiplier, filterHeight, filterWidth, strideH(), @@ -183,6 +188,7 @@ public: size_t outputChannels = output[1]; size_t outputHeight = output[2]; size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; real* outputGrad = inputs[0].data(); real* filterData = inputs[1].data(); @@ -198,6 +204,7 @@ public: inputChannels, inputHeight, inputWidth, + filterMultiplier, filterHeight, filterWidth, strideH(), @@ -243,13 +250,14 @@ public: size_t outputChannels = output[1]; size_t outputHeight = output[2]; size_t outputWidth = output[3]; + size_t filterMultiplier = outputChannels / groups_; real* outputGrad = inputs[0].data(); real* inputData = inputs[1].data(); real* filterGrad = outputs[0].data(); - int size = - inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; + int size = outputChannels * filterHeight * filterWidth * outputHeight * + outputWidth; resizeBuffer(size); real* colData = reinterpret_cast(memory_->getBuf()); @@ -264,6 +272,7 @@ public: inputChannels, inputHeight, inputWidth, + filterMultiplier, filterHeight, filterWidth, strideH(), diff --git a/paddle/function/DepthwiseConvOp.h b/paddle/function/DepthwiseConvOp.h index 2b9bef4cd7..1bf70e52f3 100644 --- a/paddle/function/DepthwiseConvOp.h +++ b/paddle/function/DepthwiseConvOp.h @@ -32,6 +32,7 @@ namespace paddle { * \param[in] inputChannels channels of inputData. * \param[in] inputHeight height of inputData. * \param[in] inputWidth width of inputData.. + * \param[in] filterMultiplier equals to outputChannels/groups_. * \param[in] filterHeight height of filter. * \param[in] filterWidth widht of filter. * \param[in] strideH stride size in height direction. @@ -53,6 +54,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -74,7 +76,8 @@ public: * \param[in] outputWidth width of outputData. * \param[in] inputChannels channels of input data. * \param[in] inputHeight height of inputData. - * \param[in] inputWidth width of inputData.. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. * \param[in] filterHeight height of filter. * \param[in] filterWidth widht of filter. * \param[in] strideH stride size in height direction. @@ -96,6 +99,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -116,7 +120,8 @@ public: * \param[in] outputWidth width of outputData. * \param[in] inputChannels channels of input data. * \param[in] inputHeight height of inputData. - * \param[in] inputWidth width of inputData.. + * \param[in] inputWidth width of inputData. + * \param[in] filterMultiplier equals to outputChannels/groups_. * \param[in] filterHeight height of filter. * \param[in] filterWidth widht of filter. * \param[in] strideH stride size in height direction. @@ -140,6 +145,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu index 7740b7022d..51aed9ffcf 100644 --- a/paddle/function/DepthwiseConvOpGpu.cu +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -25,7 +25,7 @@ void ConvolutionDepthwiseForward(const int nthreads, const T* const inputData, const T* const filterData, const int batchSize, const int outputChannels, const int outputHeight, const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, - const int filterHeight, const int filterWidth, const int strideH, + const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH, const int strideW, const int paddingH, const int paddingW, T* const outputData) { @@ -33,23 +33,25 @@ void ConvolutionDepthwiseForward(const int nthreads, (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if(index < nthreads) { - const int n = index / outputChannels / outputHeight / outputWidth; - const int c = (index / outputHeight / outputWidth) % outputChannels; - const int h = (index / outputWidth) % outputHeight; - const int w = index % outputWidth; - const T* weight = filterData + c * filterHeight * filterWidth; + const int batch = index / outputChannels / outputHeight / outputWidth; + const int c_out = (index / outputHeight / outputWidth) % outputChannels; + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; + + const int c_in = c_out / filterMultiplier; + const T* weight = filterData + c_out * filterHeight * filterWidth; T value = 0; - const int h_in_start = -paddingH + h * strideH; - const int w_in_start = -paddingW + w * strideW; - const int h_in_end = -paddingH + h * strideH + filterHeight - 1; - const int w_in_end = -paddingW + w * strideW + filterWidth - 1; + const int h_in_start = -paddingH + h_out * strideH; + const int w_in_start = -paddingW + w_out * strideW; + const int h_in_end = -paddingH + h_out * strideH + filterHeight - 1; + const int w_in_end = -paddingW + w_out * strideW + filterWidth - 1; if ((h_in_start >= 0) && (h_in_end < inputHeight) &&(w_in_start >= 0) && (w_in_end < inputWidth)) { for (int kh = 0; kh < filterHeight; ++kh) { for (int kw = 0; kw < filterWidth; ++kw) { - const int h_in = -paddingH + h * strideH + kh; - const int w_in = -paddingW + w * strideW + kw; - const int offset = ((n * inputChannels + c) * inputHeight + h_in) + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; + const int offset = ((batch * inputChannels + c_in) * inputHeight + h_in) * inputWidth + w_in; value += (*weight) * inputData[offset]; ++weight; @@ -58,11 +60,11 @@ void ConvolutionDepthwiseForward(const int nthreads, }else{ for (int kh = 0; kh < filterHeight; ++kh) { for (int kw = 0; kw < filterWidth; ++kw) { - const int h_in = -paddingH + h * strideH + kh; - const int w_in = -paddingW + w * strideW + kw; + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) { - const int offset = ((n * outputChannels + c) * inputHeight + h_in) + const int offset = ((batch * inputChannels + c_in) * inputHeight + h_in) * inputWidth + w_in; value += (*weight) * inputData[offset]; } @@ -81,38 +83,42 @@ void ConvolutionDepthwiseInputBackward(const int nthreads, const T* const top_diff, const T* const weight_data, const int num, const int outputChannels, const int outputHeight, const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, - const int filterHeight, const int filterWidth, const int strideH, + const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH, const int strideW, const int paddingH, const int paddingW, T* const bottom_diff) { int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if(index < nthreads) { - const int n = index / inputChannels / inputHeight / inputWidth; - const int c = (index / inputHeight / inputWidth) % inputChannels; - const int h = (index / inputWidth) % inputHeight; - const int w = index % inputWidth; - const T* weight = weight_data + c * filterHeight * filterWidth; + const int batch = index / inputChannels / inputHeight / inputWidth; + const int c_in = (index / inputHeight / inputWidth) % inputChannels; + const int h_in = (index / inputWidth) % inputHeight; + const int w_in = index % inputWidth; + const int c_out_start = c_in * filterMultiplier; T value = 0; - for (int kh = 0; kh < filterHeight; ++kh) { - for (int kw = 0; kw < filterWidth; ++kw) { - const int h_out_s = h + paddingH - kh; - const int w_out_s = w + paddingW - kw; - if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { - const int h_out = h_out_s / strideH; - const int w_out = w_out_s / strideW; - // TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize - if ((h_out >= 0) && (h_out < outputHeight) - && (w_out >= 0) && (w_out < outputWidth)) { - const int offset = ((n * outputChannels + c) * outputHeight + h_out) - * outputWidth + w_out; - value += (*weight) * top_diff[offset]; - } + for(int c_out = c_out_start; c_out < c_out_start + filterMultiplier; c_out ++){ + //weight bixu c_out + const T* weight = weight_data + c_out * filterHeight * filterWidth; + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { + const int h_out_s = h_in + paddingH - kh; + const int w_out_s = w_in + paddingW - kw; + if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { + const int h_out = h_out_s / strideH; + const int w_out = w_out_s / strideW; + // TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize + if ((h_out >= 0) && (h_out < outputHeight) + && (w_out >= 0) && (w_out < outputWidth)) { + const int offset = ((batch * outputChannels + c_out) * outputHeight + h_out) + * outputWidth + w_out; + value += (*weight) * top_diff[offset]; + } + } + ++weight; + } } - ++weight; - } } bottom_diff[index] += value; - } + } } // CUDA kernel to compute the depthwise convolution backprop w.r.t filter. @@ -122,26 +128,27 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, const T* const top_diff, const T* const inputData, const int num, const int outputChannels, const int outputHeight, const int outputWidth, const int inputChannels, const int inputHeight, const int inputWidth, - const int filterHeight, const int filterWidth, const int strideH, + const int filterMultiplier, const int filterHeight, const int filterWidth, const int strideH, const int strideW, const int paddingH, const int paddingW, T* const buffer_data) { int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if (index < nthreads) { - const int h = (index / outputWidth) % outputHeight; - const int w = index % outputWidth; + const int h_out = (index / outputWidth) % outputHeight; + const int w_out = index % outputWidth; const int kh = (index / filterWidth / outputHeight / outputWidth) % filterHeight; const int kw = (index / outputHeight / outputWidth) % filterWidth; - const int h_in = -paddingH + h * strideH + kh; - const int w_in = -paddingW + w * strideW + kw; + const int h_in = -paddingH + h_out * strideH + kh; + const int w_in = -paddingW + w_out * strideW + kw; if ((h_in >= 0) && (h_in < inputHeight) && (w_in >= 0) && (w_in < inputWidth)) { - const int c = index / filterHeight / filterWidth / outputHeight / outputWidth; - const int n = num_i; - const int top_offset = ((n * outputChannels + c) * outputHeight + h) - * outputWidth + w; - const int bottom_offset = ((n * inputChannels + c) * inputHeight + h_in) + const int c_out = index / filterHeight / filterWidth / outputHeight / outputWidth; + const int c_in = c_out / filterMultiplier; + const int batch = num_i; + const int top_offset = ((batch * outputChannels + c_out) * outputHeight + h_out) + * outputWidth + w_out; + const int bottom_offset = ((batch * inputChannels + c_in) * inputHeight + h_in) * inputWidth + w_in; buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; } else { @@ -162,6 +169,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -190,6 +198,7 @@ public: inputChannels, inputHeight, inputWidth, + filterMultiplier, filterHeight, filterWidth, strideH, @@ -212,6 +221,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -242,6 +252,7 @@ public: inputChannels, inputHeight, inputWidth, + filterMultiplier, filterHeight, filterWidth, strideH, @@ -264,6 +275,7 @@ public: int inputChannels, int inputHeight, int inputWidth, + int filterMultiplier, int filterHeight, int filterWidth, int strideH, @@ -273,14 +285,14 @@ public: T* colData, T* filterGrad){ - int colDataSize = inputChannels * filterHeight * filterWidth * outputHeight * outputWidth; + int colDataSize = outputChannels * filterHeight * filterWidth * outputHeight * outputWidth; size_t blocks = (colDataSize + 1024 -1) / 1024; size_t blockX = 512; size_t blockY = (blocks+512-1)/512; dim3 threads(1024, 1); dim3 grid(blockX, blockY); - BaseMatrix filterGradMatrix(inputChannels * filterHeight * filterWidth, 1, filterGrad, false, true); + BaseMatrix filterGradMatrix(outputChannels * filterHeight * filterWidth, 1, filterGrad, false, true); for(int i = 0; i < batchSize; i++) { ConvolutionDepthwiseFilterBackward @@ -296,6 +308,7 @@ public: inputChannels, inputHeight, inputWidth, + filterMultiplier, filterHeight, filterWidth, strideH, @@ -304,8 +317,8 @@ public: paddingW, colData ); - int M = colDataSize / outputHeight / outputWidth; int K = outputHeight * outputWidth; + int M = colDataSize / K; BaseMatrix colMatrix(M, K, colData, false, true); filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index 50e7a91d3f..2f28cec53e 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -355,7 +355,7 @@ void testDepthwiseConvLayer(const string& type, bool useGpu) { config.layerConfig.set_partial_sum(1); config.layerConfig.set_shared_biases(true); - config.inputDefs.push_back({INPUT_DATA, "layer_0", 2048, 96}); + config.inputDefs.push_back({INPUT_DATA, "layer_0", 2048, 192 / 2}); LayerInputConfig* input = config.layerConfig.add_inputs(); ConvConfig* conv = input->mutable_conv_conf(); conv->set_filter_size(2); -- GitLab