From d9062cd9ee1297547c16d57c0d5024ceb3555d2f Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Thu, 26 Oct 2017 00:43:47 +0800 Subject: [PATCH] Add sparse matrix support in factorization machine layer --- .../layers/FactorizationMachineLayer.cpp | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp index e5c9d1a90d5..06658a28413 100644 --- a/paddle/gserver/layers/FactorizationMachineLayer.cpp +++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp @@ -62,7 +62,12 @@ void FactorizationMachineLayer::forward(PassType passType) { outV->sumRows(*tmpOut_, 0.5, 0); x2_ = inputV->clone(0, 0, useGpu_); - inputV->square2(*x2_); + if (dynamic_cast(x2_.get())) { + x2_->copyFrom(*inputV); + (dynamic_cast(x2_.get()))->square2(); + } else { + inputV->square2(*x2_); + } latentVectors_->getW()->square2(*v2_); tmpOut_->mul(*x2_, *v2_); outV->sumRows(*tmpOut_, -0.5, 1.0); @@ -93,11 +98,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) { /* Calculate the gradients of the latentVectors_ matrix */ if (latentVectors_->getWGrad()) { MatrixPtr tmpIn = inputV->clone(0, 0, useGpu_); - tmpIn->rowScale(0, *inputV, *oGrad); - - latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); + if (dynamic_cast(inputV.get())) { + CpuSparseMatrix* inputV_s = dynamic_cast(inputV.get()); + CpuSparseMatrix* x2_s = dynamic_cast(x2_.get()); + CpuSparseMatrix* tmpIn_s = dynamic_cast(tmpIn.get()); + tmpIn_s->copyFrom(*inputV_s); + tmpIn_s->rowScale(0, *inputV_s, *oGrad); + latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); + tmpIn_s->rowScale(0, *x2_s, *oGrad); + } else { + tmpIn->rowScale(0, *inputV, *oGrad); + latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); + tmpIn->rowScale(0, *x2_, *oGrad); + } - tmpIn->rowScale(0, *x2_, *oGrad); tmpSum->sumCols(*tmpIn, -1, 0); latentVectors_->getWGrad()->addRowScale( 0, *latentVectors_->getW(), *tmpSum_T); -- GitLab