From 4172fc09c39b61c3cb1933687680bab15153b59f Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Wed, 1 Nov 2017 21:51:23 +0800 Subject: [PATCH] Add sparse input support for factorization machine layer --- paddle/gserver/layers/FactorizationMachineLayer.cpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp index 06658a284..3bd8d7cb4 100644 --- a/paddle/gserver/layers/FactorizationMachineLayer.cpp +++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp @@ -104,15 +104,21 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) { CpuSparseMatrix* tmpIn_s = dynamic_cast(tmpIn.get()); tmpIn_s->copyFrom(*inputV_s); tmpIn_s->rowScale(0, *inputV_s, *oGrad); - latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); + latentVectors_->getWGrad()->mul(*tmpIn_s->getTranspose(), *tmpMul_, 1, 1); tmpIn_s->rowScale(0, *x2_s, *oGrad); + + MatrixPtr ones = Matrix::create(1, inputV->getHeight(), false, useGpu_); + ones->zeroMem(); + ones->add(-1); + tmpSum->mul(*ones, *tmpIn_s, 1, 0); } else { tmpIn->rowScale(0, *inputV, *oGrad); latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); tmpIn->rowScale(0, *x2_, *oGrad); + + tmpSum->sumCols(*tmpIn, -1, 0); } - tmpSum->sumCols(*tmpIn, -1, 0); latentVectors_->getWGrad()->addRowScale( 0, *latentVectors_->getW(), *tmpSum_T); -- GitLab