From 53da530d90553ea37ea859314e772c5857a8e8d1 Mon Sep 17 00:00:00 2001 From: Luo Tao Date: Tue, 28 Mar 2017 17:25:37 +0800 Subject: [PATCH] package avg_gpu_backward --- paddle/cuda/include/hl_sequence.h | 6 +++ paddle/cuda/include/stub/hl_sequence_stub.h | 6 +++ paddle/cuda/src/hl_cuda_sequence.cu | 51 +++++++++++++++++++-- paddle/gserver/layers/AverageLayer.cpp | 42 ++--------------- paddle/gserver/layers/AverageLayer.h | 2 - paddle/math/Matrix.cpp | 49 ++++++++++++++++++++ paddle/math/Matrix.h | 8 ++++ paddle/math/tests/test_matrixCompare.cpp | 16 +++++-- 8 files changed, 133 insertions(+), 47 deletions(-) diff --git a/paddle/cuda/include/hl_sequence.h b/paddle/cuda/include/hl_sequence.h index 9f9d8f972e3..973ddcceed9 100644 --- a/paddle/cuda/include/hl_sequence.h +++ b/paddle/cuda/include/hl_sequence.h @@ -159,4 +159,10 @@ extern void hl_sequence_avg_forward(real* dst, int width, const int mode); +extern void hl_sequence_avg_backward(real* dst, + real* src, + const int* starts, + int height, + int width, + const int mode); #endif /* HL_SEQUENCE_H_ */ diff --git a/paddle/cuda/include/stub/hl_sequence_stub.h b/paddle/cuda/include/stub/hl_sequence_stub.h index 05e51bce9e1..920b417b1c7 100644 --- a/paddle/cuda/include/stub/hl_sequence_stub.h +++ b/paddle/cuda/include/stub/hl_sequence_stub.h @@ -57,4 +57,10 @@ inline void hl_sequence_avg_forward(real* dst, int width, const int mode) {} +inline void hl_sequence_avg_backward(real* dst, + real* src, + const int* starts, + int height, + int width, + const int mode) {} #endif // HL_SEQUENCE_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_sequence.cu b/paddle/cuda/src/hl_cuda_sequence.cu index ba823de2720..0fe2877f89f 100644 --- a/paddle/cuda/src/hl_cuda_sequence.cu +++ b/paddle/cuda/src/hl_cuda_sequence.cu @@ -325,12 +325,12 @@ __global__ void KeSequenceAvgForward(real* dst, int seqLength = end - start; if (seqLength == 0) return; real sum = 0.0; - for (int i = 0; i < seqLength; i++) { - sum += src[(start + i) * width + col]; + for (int i = start; i < end; i++) { + sum += src[i * width + col]; } sum = mode == 1 ? sum : (mode == 0 ? sum / seqLength : sum * my_rsqrt((real)seqLength)); - dst[row * width + col] = sum; + dst[gid] = sum; } } @@ -354,3 +354,48 @@ void hl_sequence_avg_forward(real* dst, (dst, src, starts, height, width, mode); CHECK_SYNC("hl_sequence_avg_forward failed"); } + +__global__ void KeSequenceAvgBackward(real* dst, + real* src, + const int* starts, + int height, + int width, + const int mode) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int row = gid / width; + int col = gid % width; + + if (gid < height * width) { + int start = starts[row]; + int end = starts[row + 1]; + int seqLength = end - start; + if (seqLength == 0) return; + real grad = src[gid]; + grad = mode == 1 ? grad : + (mode == 0 ? grad / seqLength : grad * my_rsqrt((real)seqLength)); + for (int i = start; i < end; i++) { + dst[i * width + col] += grad; + } + } +} + +void hl_sequence_avg_backward(real* dst, + real* src, + const int* starts, + int height, + int width, + const int mode) { + CHECK_NOTNULL(dst); + CHECK_NOTNULL(src); + CHECK_NOTNULL(starts); + + int block = 512; + int grid = DIVUP(width * height, 512); + + CHECK(mode == 0 || mode == 1 || mode == 2) + << "mode error in hl_sequence_avg_backward!"; + + KeSequenceAvgBackward<<< grid, block, 0, STREAM_DEFAULT >>> + (dst, src, starts, height, width, mode); + CHECK_SYNC("hl_sequence_avg_backward failed"); +} diff --git a/paddle/gserver/layers/AverageLayer.cpp b/paddle/gserver/layers/AverageLayer.cpp index b8955ab04f2..96cc4288c6f 100644 --- a/paddle/gserver/layers/AverageLayer.cpp +++ b/paddle/gserver/layers/AverageLayer.cpp @@ -26,8 +26,6 @@ bool AverageLayer::init(const LayerMap& layerMap, const ParameterMap& parameterMap) { SequencePoolLayer::init(layerMap, parameterMap); - dataMtx_ = Matrix::create(nullptr, 1, 1, false, useGpu_); - outMtx_ = Matrix::create(nullptr, 1, getSize(), false, useGpu_); // average strategy if (config_.average_strategy() == "average") { mode_ = kAverage; @@ -60,43 +58,9 @@ void AverageLayer::forward(PassType passType) { void AverageLayer::backward(const UpdateCallback& callback) { SequencePoolLayer::backward(callback); - const int* starts = startPositions_->getData(false); - MatrixPtr grad = getInputGrad(0); - - if (grad) { - size_t dim = getSize(); - real* gradientData = getInputGrad(0)->getData(); - real* gradient = getOutputGrad()->getData(); - size_t numSequences = startPositions_->getSize() - 1; - for (size_t sequenceId = 0; sequenceId < numSequences; ++sequenceId) { - // TODO(Dangqingqing) optimization for GPU - int sequenceLength = starts[sequenceId + 1] - starts[sequenceId]; - if (0 == sequenceLength) { - // empty sequence - continue; - } - dataMtx_->setData( - gradientData + starts[sequenceId] * dim, sequenceLength, dim); - outMtx_->setData(gradient + sequenceId * dim); - switch (mode_) { - case kAverage: { - // plain average - dataMtx_->addBias(*outMtx_, 1.0f / sequenceLength); - break; - } - case kSum: { - // sum instead of average - dataMtx_->addBias(*outMtx_, 1.0f); - break; - } - case kAverageSquareRootN: { - // divide by square root of sequenceLength - dataMtx_->addBias(*outMtx_, 1.0f / sqrt(sequenceLength)); - break; - } - default: { LOG(FATAL) << "should not reach here"; } - } - } + if (getInputGrad(0)) { + getInputGrad(0)->sequenceAvgBackward( + *getOutputGrad(), *startPositions_->getVector(useGpu_), mode_); } } diff --git a/paddle/gserver/layers/AverageLayer.h b/paddle/gserver/layers/AverageLayer.h index 621e1d7bb12..332552a3047 100644 --- a/paddle/gserver/layers/AverageLayer.h +++ b/paddle/gserver/layers/AverageLayer.h @@ -45,8 +45,6 @@ public: void backward(const UpdateCallback& callback = nullptr) override; protected: - MatrixPtr outMtx_; - MatrixPtr dataMtx_; int mode_; }; } // namespace paddle diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 9eead5b62c6..5f30a15f2eb 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -483,6 +483,20 @@ void GpuMatrix::sequenceAvgForward(Matrix& a, hl_sequence_avg_forward(dst, src, starts, height, width, mode); } +void GpuMatrix::sequenceAvgBackward(Matrix& a, + const IVector& startsPos, + int mode) { + size_t height = a.getHeight(); + size_t width = getWidth(); + CHECK_EQ(height, startsPos.getSize() - 1); + CHECK_EQ(width, a.getWidth()); + real* dst = getData(); + real* src = a.getData(); + const int* starts = startsPos.getData(); + + hl_sequence_avg_backward(dst, src, starts, height, width, mode); +} + /* this = scaleAB*(a*b) + scaleT*this */ void GpuMatrix::mul(const GpuMatrix& a, const GpuMatrix& b, @@ -2304,6 +2318,41 @@ void CpuMatrix::sequenceAvgForward(Matrix& a, } } +void CpuMatrix::sequenceAvgBackward(Matrix& a, + const IVector& startsPos, + int mode) { + size_t height = a.getHeight(); + size_t width = getWidth(); + CHECK_EQ(height, startsPos.getSize() - 1); + CHECK_EQ(width, a.getWidth()); + real* dst = getData(); + real* src = a.getData(); + const int* starts = startsPos.getData(); + MatrixPtr outMtx = Matrix::create(nullptr, 1, width, false, false); + MatrixPtr dataMtx = Matrix::create(nullptr, 1, width, false, false); + for (size_t i = 0; i < height; ++i) { + int sequenceLength = starts[i + 1] - starts[i]; + if (0 == sequenceLength) { + // empty sequence + continue; + } + outMtx->setData(dst + starts[i] * width, sequenceLength, width); + dataMtx->setData(src + i * width); + if (mode == 0) { + // plain average + outMtx->addBias(*dataMtx, 1.0f / sequenceLength); + } else if (mode == 1) { + // sum instead of average + outMtx->addBias(*dataMtx, 1.0f); + } else if (mode == 2) { + // divide by square root of sequenceLength + outMtx->addBias(*dataMtx, 1.0f / std::sqrt(sequenceLength)); + } else { + LOG(FATAL) << "should not reach here"; + } + } +} + /* this = scaleAB*(a*b) + scaleT*this*/ void CpuMatrix::mul(const Matrix& a, const Matrix& b, diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index dbdb6296145..3252adb19e4 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -461,6 +461,12 @@ public: LOG(FATAL) << "Not implemented"; } + virtual void sequenceAvgBackward(Matrix& a, + const IVector& startsPos, + int mode) { + LOG(FATAL) << "Not implemented"; + } + /** * @code * this = scaleAB*(a*b) + scaleT*this @@ -1203,6 +1209,7 @@ public: void collectSharedBias(Matrix& a, real scale); void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode); + void sequenceAvgBackward(Matrix& a, const IVector& startsPos, int mode); /** * @code @@ -1619,6 +1626,7 @@ public: void collectSharedBias(Matrix& a, real scale); void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode); + void sequenceAvgBackward(Matrix& a, const IVector& startsPos, int mode); /** * @code diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 08b64c1bb6f..dd19fe516fb 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -685,7 +685,7 @@ TEST(SMatrix, topK) { } } -void testMatrixSequenceAvgForward(int batchSize, int inputDim, int mode) { +void testMatrixSequenceAvg(int batchSize, int inputDim, int mode) { MatrixPtr cpuInput = std::make_shared(batchSize, inputDim); MatrixPtr gpuInput = std::make_shared(batchSize, inputDim); cpuInput->randomizeUniform(); @@ -706,15 +706,25 @@ void testMatrixSequenceAvgForward(int batchSize, int inputDim, int mode) { gpuOutput->sequenceAvgForward(*gpuInput, *gpuSequence, mode); TensorCheckErr(*cpuOutput, *gpuOutput); + + MatrixPtr cpuInGrad = std::make_shared(batchSize, inputDim); + MatrixPtr gpuInGrad = std::make_shared(batchSize, inputDim); + cpuInGrad->randomizeUniform(); + gpuInGrad->copyFrom(*cpuInGrad); + + cpuInGrad->sequenceAvgBackward(*cpuOutput, *cpuSequence, mode); + gpuInGrad->sequenceAvgBackward(*gpuOutput, *gpuSequence, mode); + + TensorCheckErr(*cpuInGrad, *gpuInGrad); } -TEST(Matrix, sequenceAvgForward) { +TEST(Matrix, sequenceAvg) { for (auto batchSize : {10, 128, 6000}) { for (auto inputDim : {32, 100, 512}) { for (auto mode : {0, 1, 2}) { VLOG(3) << " batchSize=" << batchSize << " inputDim=" << inputDim << " mode=" << mode; - testMatrixSequenceAvgForward(batchSize, inputDim, mode); + testMatrixSequenceAvg(batchSize, inputDim, mode); } } } -- GitLab