From 6fed6f2079902c86c43161f916c3450094fde6d0 Mon Sep 17 00:00:00 2001
From: wangmeng28
Date: Mon, 20 Nov 2017 20:44:52 +0800
Subject: [PATCH] Add support of sparse_binary_vector as input for fm layer

---
 .../layers/FactorizationMachineLayer.cpp | 20 +++++++++-----
 .../layers/FactorizationMachineLayer.h   |  1 +
 paddle/math/CpuSparseMatrix.cpp          | 26 ++++++++++++++-----
 3 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp
index f0f1738f305..b665fb6dfc4 100644
--- a/paddle/gserver/layers/FactorizationMachineLayer.cpp
+++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp
@@ -96,15 +96,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
 
   /* Calculate the gradients of the latentVectors_ matrix */
   if (latentVectors_->getWGrad()) {
-    MatrixPtr tmpInput = inputV->clone(0, 0, useGpu_);
     if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
+      Matrix::resizeOrCreateSparseMatrix(tmpInput_,
+                                         inputV->getHeight(),
+                                         inputV->getWidth(),
+                                         inputV->getElementCnt());
+
       CpuSparseMatrix* sparseInputV =
           dynamic_cast<CpuSparseMatrix*>(inputV.get());
       CpuSparseMatrix* sparseInputSquare =
           dynamic_cast<CpuSparseMatrix*>(inputSquare_.get());
       CpuSparseMatrix* sparseTmpInput =
-          dynamic_cast<CpuSparseMatrix*>(tmpInput.get());
+          dynamic_cast<CpuSparseMatrix*>(tmpInput_.get());
       sparseTmpInput->copyFrom(*sparseInputV);
+
       sparseTmpInput->rowScale(0, *sparseInputV, *oGrad);
       latentVectors_->getWGrad()->mul(
           *sparseTmpInput->getTranspose(), *inputMulFactor_, 1, 1);
@@ -115,12 +120,15 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
       negOnes_->add(-1);
       tmpSum_->mul(*negOnes_, *sparseTmpInput, 1, 0);
     } else {
-      tmpInput->rowScale(0, *inputV, *oGrad);
+      Matrix::resizeOrCreate(
+          tmpInput_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
+
+      tmpInput_->rowScale(0, *inputV, *oGrad);
       latentVectors_->getWGrad()->mul(
-          *tmpInput->getTranspose(), *inputMulFactor_, 1, 1);
-      tmpInput->rowScale(0, *inputSquare_, *oGrad);
+          *tmpInput_->getTranspose(), *inputMulFactor_, 1, 1);
+      tmpInput_->rowScale(0, *inputSquare_, *oGrad);
 
-      tmpSum_->sumCols(*tmpInput, -1, 0);
+      tmpSum_->sumCols(*tmpInput_, -1, 0);
     }
 
     latentVectors_->getWGrad()->addRowScale(
diff --git a/paddle/gserver/layers/FactorizationMachineLayer.h b/paddle/gserver/layers/FactorizationMachineLayer.h
index 3bc36daaab3..df20a49934d 100644
--- a/paddle/gserver/layers/FactorizationMachineLayer.h
+++ b/paddle/gserver/layers/FactorizationMachineLayer.h
@@ -61,6 +61,7 @@ private:
   // Store temporary calculation result
   MatrixPtr tmpOut_;
   MatrixPtr tmpSum_;
+  MatrixPtr tmpInput_;
 
   // Negative identity matrix
   MatrixPtr negOnes_;
diff --git a/paddle/math/CpuSparseMatrix.cpp b/paddle/math/CpuSparseMatrix.cpp
index 6a432cd16b7..dc6979cf5a5 100644
--- a/paddle/math/CpuSparseMatrix.cpp
+++ b/paddle/math/CpuSparseMatrix.cpp
@@ -266,13 +266,25 @@ void CpuSparseMatrix::rowScale(size_t cCol, CpuSparseMatrix& b, Matrix& c) {
   CHECK_EQ(width_, b.getWidth());
   real* A = getValue();
   real* B = b.getValue();
-  for (size_t i = 0; i < height_; i++) {
-    size_t start = getRowStartIdx(i);
-    size_t end = getRowStartIdx(i + 1);
-    CHECK_EQ(start, b.getRowStartIdx(i));
-    CHECK_EQ(end, b.getRowStartIdx(i + 1));
-    for (size_t j = start; j < end; j++) {
-      A[j] = B[j] * c.getElement(i, cCol);
+  if (b.getValueType() == FLOAT_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = B[j] * c.getElement(i, cCol);
+      }
+    }
+  } else if (b.getValueType() == NO_VALUE) {
+    for (size_t i = 0; i < height_; i++) {
+      size_t start = getRowStartIdx(i);
+      size_t end = getRowStartIdx(i + 1);
+      CHECK_EQ(start, b.getRowStartIdx(i));
+      CHECK_EQ(end, b.getRowStartIdx(i + 1));
+      for (size_t j = start; j < end; j++) {
+        A[j] = c.getElement(i, cCol);
+      }
     }
   }
 }
--
GitLab
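
Note (not part of the patch): the behavioral core of the CpuSparseMatrix::rowScale change above is that a NO_VALUE (sparse binary) input stores no explicit values, so each non-zero is treated as an implicit 1 and the scaled result is just the row's scale factor, whereas FLOAT_VALUE input multiplies the stored value by that factor. The sketch below is a minimal, self-contained illustration of this semantics on a plain CSR layout; the names (scaleRowsCsr, rowPtr, bValues) are hypothetical and it does not use the Paddle API.

// Illustrative sketch only (not Paddle code): mimics the updated
// rowScale semantics on a plain CSR representation.
#include <cstdio>
#include <vector>

// Scale each non-zero of row i by scale[i].
// If bValues is empty, b is treated as a binary (NO_VALUE) matrix whose
// non-zeros are implicitly 1, so the result is just the row's scale factor.
void scaleRowsCsr(const std::vector<int>& rowPtr,
                  const std::vector<float>& bValues,
                  const std::vector<float>& scale,
                  std::vector<float>& out) {
  size_t height = rowPtr.size() - 1;
  for (size_t i = 0; i < height; ++i) {
    for (int j = rowPtr[i]; j < rowPtr[i + 1]; ++j) {
      // FLOAT_VALUE: out[j] = bValues[j] * scale[i]
      // NO_VALUE   : out[j] = scale[i]   (implicit value of 1)
      out[j] = bValues.empty() ? scale[i] : bValues[j] * scale[i];
    }
  }
}

int main() {
  // 2x3 binary sparse matrix with non-zeros at (0,0), (0,2), (1,1): NO_VALUE case.
  std::vector<int> rowPtr = {0, 2, 3};
  std::vector<float> bValues;  // empty -> binary matrix, values implicitly 1
  std::vector<float> scale = {0.5f, 2.0f};
  std::vector<float> out(3);
  scaleRowsCsr(rowPtr, bValues, scale, out);
  for (float v : out) std::printf("%g ", v);  // prints: 0.5 0.5 2
  std::printf("\n");
  return 0;
}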