diff --git a/paddle/function/DepthwiseConvOp.cpp b/paddle/function/DepthwiseConvOp.cpp index 31eccda67d56773e5e4d891779e958bf4c317ca8..0ac83f5824b91f4882bb45d085e3e5970e71a520 100644 --- a/paddle/function/DepthwiseConvOp.cpp +++ b/paddle/function/DepthwiseConvOp.cpp @@ -15,7 +15,6 @@ limitations under the License. */ #include "DepthwiseConvOp.h" #include "ConvOp.h" #include "GemmFunctor.h" -//#include "paddle/math/MemoryHandle.h" namespace paddle { @@ -28,6 +27,7 @@ public: int outputChannels, int outputHeight, int outputWidth, + int inputChannels, int inputHeight, int inputWidth, int filterHeight, @@ -114,7 +114,7 @@ public: const TensorShape& output = outputs[0].shape(); size_t batchSize = input[0]; - // size_t inputChannels = input[1]; + size_t inputChannels = input[1]; size_t inputHeight = input[2]; size_t inputWidth = input[3]; size_t filterHeight = getFilterHeight(filter); @@ -134,6 +134,7 @@ public: outputChannels, outputHeight, outputWidth, + inputChannels, inputHeight, inputWidth, filterHeight, @@ -168,8 +169,6 @@ public: CHECK_EQ(numInputs_, inputs.size()); CHECK_EQ(numOutputs_, outputs.size()); check(inputs, outputs); - // Since the implementation of Col2ImFunctor is ADD_TO, - // this function only supports ADD_TO mode. CHECK_EQ(outputs[0].getArgType(), ADD_TO); const TensorShape& output = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); @@ -228,12 +227,11 @@ public: } void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { - // CHECK_EQ(numInputs_, inputs.size()); - // CHECK_EQ(numOutputs_, outputs.size()); + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); check(inputs, outputs); const TensorShape& output = inputs[0].shape(); const TensorShape& input = inputs[1].shape(); - // const TensorShape& multiplier = inputs[2].shape(); const TensorShape& filter = outputs[0].shape(); size_t batchSize = input[0]; diff --git a/paddle/function/DepthwiseConvOp.h b/paddle/function/DepthwiseConvOp.h index 356ff37c6a968ddd8aa92c4aa98275176b906da0..2b9bef4cd77c47f457e645eaf771e8fb42082040 100644 --- a/paddle/function/DepthwiseConvOp.h +++ b/paddle/function/DepthwiseConvOp.h @@ -29,6 +29,7 @@ namespace paddle { * \param[in] outputChannels channels of outputData. * \param[in] outputHeight height of outputData. * \param[in] outputWidth width of outputData. + * \param[in] inputChannels channels of inputData. * \param[in] inputHeight height of inputData. * \param[in] inputWidth width of inputData.. * \param[in] filterHeight height of filter. @@ -49,8 +50,9 @@ public: int outputChannels, int outputHeight, int outputWidth, + int inputChannels, int inputHeight, - int intputWidth, + int inputWidth, int filterHeight, int filterWidth, int strideH, diff --git a/paddle/function/DepthwiseConvOpGpu.cu b/paddle/function/DepthwiseConvOpGpu.cu index 737f091ab8fa77b01d23d441882353d5b2be4946..7740b7022dbf285df8d067c8a92910f49dba1836 100644 --- a/paddle/function/DepthwiseConvOpGpu.cu +++ b/paddle/function/DepthwiseConvOpGpu.cu @@ -24,7 +24,7 @@ __global__ void ConvolutionDepthwiseForward(const int nthreads, const T* const inputData, const T* const filterData, const int batchSize, const int outputChannels, const int outputHeight, - const int outputWidth, const int inputHeight, const int inputWidth, + const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, const int filterHeight, const int filterWidth, const int strideH, const int strideW, const int paddingH, const int paddingW, T* const outputData) { @@ -39,36 +39,36 @@ void ConvolutionDepthwiseForward(const int nthreads, const int w = index % outputWidth; const T* weight = filterData + c * filterHeight * filterWidth; T value = 0; - const int h_in_start = -paddingH + h * strideH; - const int w_in_start = -paddingW + w * strideW; - const int h_in_end = -paddingH + h * strideH + filterHeight - 1; - const int w_in_end = -paddingW + w * strideW + filterWidth - 1; + const int h_in_start = -paddingH + h * strideH; + const int w_in_start = -paddingW + w * strideW; + const int h_in_end = -paddingH + h * strideH + filterHeight - 1; + const int w_in_end = -paddingW + w * strideW + filterWidth - 1; if ((h_in_start >= 0) && (h_in_end < inputHeight) &&(w_in_start >= 0) && (w_in_end < inputWidth)) { - for (int kh = 0; kh < filterHeight; ++kh) { - for (int kw = 0; kw < filterWidth; ++kw) { - const int h_in = -paddingH + h * strideH + kh; - const int w_in = -paddingW + w * strideW + kw; - const int offset = ((n * outputChannels + c) * inputHeight + h_in) + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { + const int h_in = -paddingH + h * strideH + kh; + const int w_in = -paddingW + w * strideW + kw; + const int offset = ((n * inputChannels + c) * inputHeight + h_in) * inputWidth + w_in; - value += (*weight) * inputData[offset]; - ++weight; - } - } - }else{ - for (int kh = 0; kh < filterHeight; ++kh) { - for (int kw = 0; kw < filterWidth; ++kw) { - const int h_in = -paddingH + h * strideH + kh; - const int w_in = -paddingW + w * strideW + kw; - if ((h_in >= 0) && (h_in < inputHeight) - && (w_in >= 0) && (w_in < inputWidth)) { - const int offset = ((n * outputChannels + c) * inputHeight + h_in) - * inputWidth + w_in; - value += (*weight) * inputData[offset]; - } - ++weight; + value += (*weight) * inputData[offset]; + ++weight; } } + }else{ + for (int kh = 0; kh < filterHeight; ++kh) { + for (int kw = 0; kw < filterWidth; ++kw) { + const int h_in = -paddingH + h * strideH + kh; + const int w_in = -paddingW + w * strideW + kw; + if ((h_in >= 0) && (h_in < inputHeight) + && (w_in >= 0) && (w_in < inputWidth)) { + const int offset = ((n * outputChannels + c) * inputHeight + h_in) + * inputWidth + w_in; + value += (*weight) * inputData[offset]; + } + ++weight; + } + } } outputData[index] = value; } @@ -80,15 +80,15 @@ __global__ void ConvolutionDepthwiseInputBackward(const int nthreads, const T* const top_diff, const T* const weight_data, const int num, const int outputChannels, const int outputHeight, - const int outputWidth, const int inputHeight, const int inputWidth, + const int outputWidth,const int inputChannels, const int inputHeight, const int inputWidth, const int filterHeight, const int filterWidth, const int strideH, const int strideW, const int paddingH, const int paddingW, T* const bottom_diff) { int index = (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; if(index < nthreads) { - const int n = index / outputChannels / inputHeight / inputWidth; - const int c = (index / inputHeight / inputWidth) % outputChannels; + const int n = index / inputChannels / inputHeight / inputWidth; + const int c = (index / inputHeight / inputWidth) % inputChannels; const int h = (index / inputWidth) % inputHeight; const int w = index % inputWidth; const T* weight = weight_data + c * filterHeight * filterWidth; @@ -100,7 +100,7 @@ void ConvolutionDepthwiseInputBackward(const int nthreads, if (((h_out_s % strideH) == 0) && ((w_out_s % strideW) == 0)) { const int h_out = h_out_s / strideH; const int w_out = w_out_s / strideW; - // TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize + // TODO(zhaolong) : the 'if' affect the effectiveness, it needs to optimize if ((h_out >= 0) && (h_out < outputHeight) && (w_out >= 0) && (w_out < outputWidth)) { const int offset = ((n * outputChannels + c) * outputHeight + h_out) @@ -121,7 +121,7 @@ __global__ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, const T* const top_diff, const T* const inputData, const int num, const int outputChannels, const int outputHeight, - const int outputWidth, const int inputHeight, const int inputWidth, + const int outputWidth, const int inputChannels, const int inputHeight, const int inputWidth, const int filterHeight, const int filterWidth, const int strideH, const int strideW, const int paddingH, const int paddingW, T* const buffer_data) { @@ -141,7 +141,7 @@ void ConvolutionDepthwiseFilterBackward(const int num_i, const int nthreads, const int n = num_i; const int top_offset = ((n * outputChannels + c) * outputHeight + h) * outputWidth + w; - const int bottom_offset = ((n * outputChannels + c) * inputHeight + h_in) + const int bottom_offset = ((n * inputChannels + c) * inputHeight + h_in) * inputWidth + w_in; buffer_data[index] = top_diff[top_offset] * inputData[bottom_offset]; } else { @@ -159,6 +159,7 @@ public: int outputChannels, int outputHeight, int outputWidth, + int inputChannels, int inputHeight, int inputWidth, int filterHeight, @@ -186,6 +187,7 @@ public: outputChannels, outputHeight, outputWidth, + inputChannels, inputHeight, inputWidth, filterHeight, @@ -218,7 +220,7 @@ public: int paddingW, T* inputGrad){ - int inputSize = batchSize * inputChannels * inputHeight * inputWidth; + int inputSize = batchSize * inputChannels * inputHeight * inputWidth; size_t blocks = (inputSize + 1024 -1) / 1024; size_t blockX = 512; @@ -237,6 +239,7 @@ public: outputChannels, outputHeight, outputWidth, + inputChannels, inputHeight, inputWidth, filterHeight, @@ -277,11 +280,11 @@ public: size_t blockY = (blocks+512-1)/512; dim3 threads(1024, 1); dim3 grid(blockX, blockY); - BaseMatrix filterGradMatrix(inputChannels * filterHeight * filterWidth, 1, filterGrad, false, true); + BaseMatrix filterGradMatrix(inputChannels * filterHeight * filterWidth, 1, filterGrad, false, true); for(int i = 0; i < batchSize; i++) { - ConvolutionDepthwiseFilterBackward - <<< grid, threads, 0, STREAM_DEFAULT >>>( + ConvolutionDepthwiseFilterBackward + <<< grid, threads, 0, STREAM_DEFAULT >>>( i, colDataSize, outputGrad, @@ -290,6 +293,7 @@ public: outputChannels, outputHeight, outputWidth, + inputChannels, inputHeight, inputWidth, filterHeight, @@ -299,12 +303,12 @@ public: paddingH, paddingW, colData - ); - int M = colDataSize / outputHeight / outputWidth; - int K = outputHeight * outputWidth; + ); + int M = colDataSize / outputHeight / outputWidth; + int K = outputHeight * outputWidth; BaseMatrix colMatrix(M, K, colData, false, true); - filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); + filterGradMatrix.sumRows(colMatrix, (T)1.0, (T)1.0); } } };