diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp
index e5c9d1a90d5bc6d6ee22b0f5f4fd366a38eecdc2..06658a28413827840c017a6fcba998edb86d2a6c 100644
--- a/paddle/gserver/layers/FactorizationMachineLayer.cpp
+++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp
@@ -62,7 +62,12 @@ void FactorizationMachineLayer::forward(PassType passType) {
   outV->sumRows(*tmpOut_, 0.5, 0);
 
   x2_ = inputV->clone(0, 0, useGpu_);
-  inputV->square2(*x2_);
+  if (dynamic_cast<CpuSparseMatrix*>(x2_.get())) {
+    x2_->copyFrom(*inputV);
+    (dynamic_cast<CpuSparseMatrix*>(x2_.get()))->square2();
+  } else {
+    inputV->square2(*x2_);
+  }
   latentVectors_->getW()->square2(*v2_);
   tmpOut_->mul(*x2_, *v2_);
   outV->sumRows(*tmpOut_, -0.5, 1.0);
@@ -93,11 +98,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
   /* Calculate the gradients of the latentVectors_ matrix */
   if (latentVectors_->getWGrad()) {
     MatrixPtr tmpIn = inputV->clone(0, 0, useGpu_);
-    tmpIn->rowScale(0, *inputV, *oGrad);
-
-    latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
+    if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
+      CpuSparseMatrix* inputV_s = dynamic_cast<CpuSparseMatrix*>(inputV.get());
+      CpuSparseMatrix* x2_s = dynamic_cast<CpuSparseMatrix*>(x2_.get());
+      CpuSparseMatrix* tmpIn_s = dynamic_cast<CpuSparseMatrix*>(tmpIn.get());
+      tmpIn_s->copyFrom(*inputV_s);
+      tmpIn_s->rowScale(0, *inputV_s, *oGrad);
+      latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
+      tmpIn_s->rowScale(0, *x2_s, *oGrad);
+    } else {
+      tmpIn->rowScale(0, *inputV, *oGrad);
+      latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
+      tmpIn->rowScale(0, *x2_, *oGrad);
+    }
-    tmpIn->rowScale(0, *x2_, *oGrad);
 
     tmpSum->sumCols(*tmpIn, -1, 0);
     latentVectors_->getWGrad()->addRowScale(
         0, *latentVectors_->getW(), *tmpSum_T);