提交 74a699a7 编写于 作者: W wangmeng28

change clone to resizeOrCreate in fm layer

上级 13ec6f99
......@@ -58,16 +58,22 @@ void FactorizationMachineLayer::forward(PassType passType) {
inputMulFactor_, batchSize, factorSize_, false, useGpu_);
Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
REGISTER_TIMER_INFO("InputMulFactorTimer", getName().c_str());
REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str());
inputMulFactor_->mul(*inputV, *latentVectors_->getW());
inputMulFactor_->square2(*tmpOut_);
outV->sumRows(*tmpOut_, 0.5, 0);
inputSquare_ = inputV->clone(0, 0, useGpu_);
if (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get())) {
if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
Matrix::resizeOrCreateSparseMatrix(inputSquare_,
inputV->getHeight(),
inputV->getWidth(),
inputV->getElementCnt(),
inputV->getValueType());
inputSquare_->copyFrom(*inputV);
(dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
} else {
Matrix::resizeOrCreate(
inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
inputV->square2(*inputSquare_);
}
latentVectors_->getW()->square2(*latentVectorsSquare_);
......@@ -75,7 +81,7 @@ void FactorizationMachineLayer::forward(PassType passType) {
outV->sumRows(*tmpOut_, -0.5, 1.0);
/* activation */ {
REGISTER_TIMER_INFO("FmAtvTimer", getName().c_str());
REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str());
forwardActivation();
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册