diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp index 06658a28413827840c017a6fcba998edb86d2a6c..3bd8d7cb4c7c63fc4fcdc931cf6935b02dbf0824 100644 --- a/paddle/gserver/layers/FactorizationMachineLayer.cpp +++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp @@ -104,15 +104,21 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) { CpuSparseMatrix* tmpIn_s = dynamic_cast(tmpIn.get()); tmpIn_s->copyFrom(*inputV_s); tmpIn_s->rowScale(0, *inputV_s, *oGrad); - latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); + latentVectors_->getWGrad()->mul(*tmpIn_s->getTranspose(), *tmpMul_, 1, 1); tmpIn_s->rowScale(0, *x2_s, *oGrad); + + MatrixPtr ones = Matrix::create(1, inputV->getHeight(), false, useGpu_); + ones->zeroMem(); + ones->add(-1); + tmpSum->mul(*ones, *tmpIn_s, 1, 0); } else { tmpIn->rowScale(0, *inputV, *oGrad); latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); tmpIn->rowScale(0, *x2_, *oGrad); + + tmpSum->sumCols(*tmpIn, -1, 0); } - tmpSum->sumCols(*tmpIn, -1, 0); latentVectors_->getWGrad()->addRowScale( 0, *latentVectors_->getW(), *tmpSum_T);