提交 4172fc09 编写于 作者: W wangmeng28

Add sparse input support for factorization machine layer

上级 d9062cd9
......@@ -104,15 +104,21 @@ void FactorizationMachineLayer::backward(const UpdateCallback& callback) {
CpuSparseMatrix* tmpIn_s = dynamic_cast<CpuSparseMatrix*>(tmpIn.get());
tmpIn_s->copyFrom(*inputV_s);
tmpIn_s->rowScale(0, *inputV_s, *oGrad);
latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
latentVectors_->getWGrad()->mul(*tmpIn_s->getTranspose(), *tmpMul_, 1, 1);
tmpIn_s->rowScale(0, *x2_s, *oGrad);
MatrixPtr ones = Matrix::create(1, inputV->getHeight(), false, useGpu_);
ones->zeroMem();
ones->add(-1);
tmpSum->mul(*ones, *tmpIn_s, 1, 0);
} else {
tmpIn->rowScale(0, *inputV, *oGrad);
latentVectors_->getWGrad()->mul(*tmpIn->getTranspose(), *tmpMul_, 1, 1);
tmpIn->rowScale(0, *x2_, *oGrad);
}
tmpSum->sumCols(*tmpIn, -1, 0);
}
latentVectors_->getWGrad()->addRowScale(
0, *latentVectors_->getW(), *tmpSum_T);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册