提交 9741ade8 编写于 作者: W wangmeng28

Change pow to square in factorization machine layer

上级 0574915e
......@@ -57,12 +57,12 @@ void FactorizationMachineLayer::forward(PassType passType) {
REGISTER_TIMER_INFO("FwMulTimer", getName().c_str());
tmpMul_->mul(*inputV, *latentVectors_->getW());
tmpOut_->pow2(*tmpMul_, 2);
tmpMul_->square2(*tmpOut_);
outV->sumRows(*tmpOut_, 0.5, 0);
x2_ = inputV->clone(0, 0, useGpu_);
x2_->pow2(*inputV, 2);
v2_->pow2(*latentVectors_->getW(), 2);
inputV->square2(*x2_);
latentVectors_->getW()->square2(*v2_);
tmpOut_->mul(*x2_, *v2_);
outV->sumRows(*tmpOut_, -0.5, 1.0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册