Share the im2col scratch buffer between the convolution functions and
implement the GEMM-based filter-gradient computation.

* ConvOp.h: memory_ and resizeBuffer() are declared next to the stride and
  padding accessors so that every convolution function can use them.
* GemmConvOp.cpp: the forward pass uses the shared buffer; the filter
  gradient is computed per sample and per group as
  outputGrad * im2col(input)^T.
* GemmFunctor.h: GemmFunctor takes transpose flags for A and B.

NOTE(review): the template argument lists (<...>) were missing from the
flattened copy of this patch and are restored here from how the names are
used; confirm them against the repository before applying. The blob hashes
on the "index" lines describe the patch as first generated.

diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h
index 8f2c0c4cb8ab4a5e380055b831852325bbdbe3b5..9ad1785fbb47f423622f753a5b92e9a196b846cd 100644
--- a/paddle/function/ConvOp.h
+++ b/paddle/function/ConvOp.h
@@ -89,11 +89,13 @@ public:
 protected:
   std::vector<size_t> strides_;
   std::vector<size_t> paddings_;
+
   /// Group size, refer to grouped convolution in
   /// Alex Krizhevsky's paper: when group=2, the first half of the
   /// filters are only connected to the first half of the input channels,
   /// and the second half only connected to the second half.
   size_t groups_;
+
   inline int strideH() const { return strides_[0]; }
 
   inline int strideW() const { return strides_[1]; }
@@ -101,6 +103,22 @@ protected:
   inline int paddingH() const { return paddings_[0]; }
 
   inline int paddingW() const { return paddings_[1]; }
+
+  // A temporary memory in convolution calculation.
+  MemoryHandlePtr memory_;
+
+  /// Make sure memory_ can hold at least newSize elements of type real.
+  /// An existing buffer that is already large enough is kept as is.
+  template <DeviceType Device>
+  void resizeBuffer(size_t newSize) {
+    if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
+      if (Device == DEVICE_TYPE_CPU) {
+        memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
+      } else {
+        memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
+      }
+    }
+  }
 };
 
 }  // namespace paddle
diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp
index 109ed20ab0666815f922fda1433c68af5540e5f0..6b5db1d62ed4030a984eb467d4b946dbdb712dd4 100644
--- a/paddle/function/GemmConvOp.cpp
+++ b/paddle/function/GemmConvOp.cpp
@@ -110,7 +110,7 @@ public:
 
     size_t size = inputChannels / groups_ * filterHeight * filterWidth *
                   outputHeight * outputWidth;
-    resizeBuffer(size);
+    resizeBuffer<Device>(size);
     real* colData = reinterpret_cast<real*>(memory_->getBuf());
 
     Im2ColFunctor<Device, real> im2col;
@@ -120,7 +120,7 @@ public:
         (outputChannels / groups_) * outputHeight * outputWidth;
     size_t filterOffset = inputs[1].shape().getElements() / groups_;
     for (size_t i = 0; i < batchSize; i++) {
-      for (int g = 0; g < groups_; g++) {
+      for (size_t g = 0; g < groups_; g++) {
         im2col(inputData + g * inputOffset,
                inputChannels / groups_,
                inputHeight,
@@ -138,7 +138,9 @@ public:
         int M = outputChannels / groups_;
         int N = outputHeight * outputWidth;
         int K = inputChannels / groups_ * filterHeight * filterWidth;
-        gemm(M,
+        gemm(CblasNoTrans,
+             CblasNoTrans,
+             M,
              N,
              K,
              1.0f,
@@ -154,19 +156,6 @@ public:
       outputData += outputChannels * outputHeight * outputWidth;
     }
   }
-
-  void resizeBuffer(size_t newSize) {
-    if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
-      if (Device == DEVICE_TYPE_CPU) {
-        memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
-      } else {
-        memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
-      }
-    }
-  }
-
-private:
-  MemoryHandlePtr memory_;
 };
 
 /*
@@ -202,10 +191,77 @@ public:
   void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
     CHECK_EQ(numInputs_, inputs.size());
     CHECK_EQ(numOutputs_, outputs.size());
-    const TensorShape& outputGrad = inputs[0].shape();
+    CHECK_EQ(outputs[0].getArgType(), ASSIGN_TO);
+    const TensorShape& output = inputs[0].shape();
     const TensorShape& input = inputs[1].shape();
-    const TensorShape& filterGrad = outputs[0].shape();
-    check(input, filterGrad, outputGrad);
+    const TensorShape& filter = outputs[0].shape();
+    check(input, filter, output);
+
+    size_t batchSize = input[0];
+    size_t inputChannels = input[1];
+    size_t inputHeight = input[2];
+    size_t inputWidth = input[3];
+    size_t filterHeight = filter[2];
+    size_t filterWidth = filter[3];
+    size_t outputChannels = output[1];
+    size_t outputHeight = output[2];
+    size_t outputWidth = output[3];
+
+    real* outputGrad = inputs[0].data<real>();
+    real* inputData = inputs[1].data<real>();
+    real* filterGrad = outputs[0].data<real>();
+
+    size_t size = inputChannels / groups_ * filterHeight * filterWidth *
+                  outputHeight * outputWidth;
+    resizeBuffer<Device>(size);
+    real* colData = reinterpret_cast<real*>(memory_->getBuf());
+
+    Im2ColFunctor<Device, real> im2col;
+    GemmFunctor<Device, real> gemm;
+    size_t inputOffset = (inputChannels / groups_) * inputHeight * inputWidth;
+    size_t outputOffset =
+        (outputChannels / groups_) * outputHeight * outputWidth;
+    size_t filterOffset = filter.getElements() / groups_;
+    for (size_t i = 0; i < batchSize; i++) {
+      for (size_t g = 0; g < groups_; g++) {
+        im2col(inputData + g * inputOffset,
+               inputChannels / groups_,
+               inputHeight,
+               inputWidth,
+               filterHeight,
+               filterWidth,
+               strideH(),
+               strideW(),
+               paddingH(),
+               paddingW(),
+               outputHeight,
+               outputWidth,
+               colData);
+
+        int M = outputChannels / groups_;
+        int K = outputHeight * outputWidth;
+        int N = inputChannels / groups_ * filterHeight * filterWidth;
+        // Per sample: outputGrad (M x K) * colData (N x K)^T -> (M x N).
+        // The output is ASSIGN_TO, so the first sample overwrites
+        // filterGrad (beta = 0) and the following samples accumulate.
+        gemm(CblasNoTrans,
+             CblasTrans,
+             M,
+             N,
+             K,
+             1.0f,
+             outputGrad + g * outputOffset,
+             K,
+             colData,
+             K,
+             i == 0 ? 0.0f : 1.0f,
+             filterGrad + g * filterOffset,
+             N);
+      }
+      // Step to the next sample before the next batch iteration.
+      inputData += inputChannels * inputHeight * inputWidth;
+      outputGrad += outputChannels * outputHeight * outputWidth;
+    }
   }
 };
 
diff --git a/paddle/function/GemmFunctor.h b/paddle/function/GemmFunctor.h
index 5fb2f8a6d9e8fde7273e2e0b0de1101fe5989eb2..d5db5cf5e7a855d89b262fe8cf42aa2c55f419f1 100644
--- a/paddle/function/GemmFunctor.h
+++ b/paddle/function/GemmFunctor.h
@@ -26,7 +26,9 @@ namespace paddle {
 template <DeviceType Device, class T>
 class GemmFunctor {
 public:
-  void operator()(const int M,
+  void operator()(const CBLAS_TRANSPOSE transA,
+                  const CBLAS_TRANSPOSE transB,
+                  const int M,
                   const int N,
                   const int K,
                   const T alpha,
@@ -42,7 +44,9 @@ public:
 template <class T>
 class GemmFunctor<DEVICE_TYPE_CPU, T> {
 public:
-  void operator()(const int M,
+  void operator()(const CBLAS_TRANSPOSE transA,
+                  const CBLAS_TRANSPOSE transB,
+                  const int M,
                   const int N,
                   const int K,
                   const T alpha,
@@ -53,26 +57,16 @@ public:
                   const T beta,
                   T* C,
                   const int ldc) {
-    gemm<T>(CblasNoTrans,
-            CblasNoTrans,
-            M,
-            N,
-            K,
-            alpha,
-            A,
-            lda,
-            B,
-            ldb,
-            beta,
-            C,
-            ldc);
+    gemm<T>(transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
   }
 };
 
 template <class T>
 class GemmFunctor<DEVICE_TYPE_GPU, T> {
 public:
-  void operator()(const int M,
+  void operator()(const CBLAS_TRANSPOSE transA,
+                  const CBLAS_TRANSPOSE transB,
+                  const int M,
                   const int N,
                   const int K,
                   const T alpha,
@@ -84,9 +78,9 @@ public:
                   T* C,
                   const int ldc) {
     hl_matrix_mul((T*)A,
-                  HPPL_OP_N,
+                  transA == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
                   (T*)B,
-                  HPPL_OP_N,
+                  transB == CblasNoTrans ? HPPL_OP_N : HPPL_OP_T,
                   C,
                   M,
                   N,