提交 d9062cd9 编写于 作者: W wangmeng28

Add sparse matrix support in factorization machine layer

上级 601c1a35
...@@ -62,7 +62,12 @@ void FactorizationMachineLayer::forward(PassType passType) { ...@@ -62,7 +62,12 @@ void FactorizationMachineLayer::forward(PassType passType) {
outV->sumRows(*tmpOut_, 0.5, 0); outV->sumRows(*tmpOut_, 0.5, 0);
x2_ = inputV->clone(0, 0, useGpu_); x2_ = inputV->clone(0, 0, useGpu_);
if (dynamic_cast<CpuSparseMatrix*>(x2_.get())) {
x2_->copyFrom(*inputV);
(dynamic_cast<CpuSparseMatrix*>(x2_.get()))->square2();
} else {
inputV->square2(*x2_); inputV->square2(*x2_);
}
latentVectors_->getW()->square2(*v2_); latentVectors_->getW()->square2(*v2_);
tmpOut_->mul(*x2_, *v2_); tmpOut_->mul(*x2_, *v2_);
outV->sumRows(*tmpOut_, -0.5, 1.0); outV->sumRows(*tmpOut_, -0.5, 1.0);
...@@ -93,11 +98,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) { ...@@ -93,11 +98,20 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
/* Calculate the gradients of the latentVectors_ matrix */ /* Calculate the gradients of the latentVectors_ matrix */
if (latentVectors_->getWGrad()) { if (latentVectors_->getWGrad()) {
MatrixPtr tmpIn = inputV->clone(0, 0, useGpu_); MatrixPtr tmpIn = inputV->clone(0, 0, useGpu_);
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
CpuSparseMatrix* inputV_s = dynamic_cast<CpuSparseMatrix*>(inputV.get());
CpuSparseMatrix* x2_s = dynamic_cast<CpuSparseMatrix*>(x2_.get());
CpuSparseMatrix* tmpIn_s = dynamic_cast<CpuSparseMatrix*>(tmpIn.get());
tmpIn_s->copyFrom(*inputV_s);
tmpIn_s->rowScale(0, *inputV_s, *oGrad);
latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
tmpIn_s->rowScale(0, *x2_s, *oGrad);
} else {
tmpIn->rowScale(0, *inputV, *oGrad); tmpIn->rowScale(0, *inputV, *oGrad);
latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1); latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
tmpIn->rowScale(0, *x2_, *oGrad); tmpIn->rowScale(0, *x2_, *oGrad);
}
tmpSum->sumCols(*tmpIn, -1, 0); tmpSum->sumCols(*tmpIn, -1, 0);
latentVectors_->getWGrad()->addRowScale( latentVectors_->getWGrad()->addRowScale(
0, *latentVectors_->getW(), *tmpSum_T); 0, *latentVectors_->getW(), *tmpSum_T);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册