提交 74a699a7 编写于 作者: W wangmeng28

change clone to resizeOrCreate in fm layer

上级 13ec6f99
...@@ -58,16 +58,22 @@ void FactorizationMachineLayer::forward(PassType passType) { ...@@ -58,16 +58,22 @@ void FactorizationMachineLayer::forward(PassType passType) {
inputMulFactor_, batchSize, factorSize_, false, useGpu_); inputMulFactor_, batchSize, factorSize_, false, useGpu_);
Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_); Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
REGISTER_TIMER_INFO("InputMulFactorTimer", getName().c_str()); REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str());
inputMulFactor_->mul(*inputV, *latentVectors_->getW()); inputMulFactor_->mul(*inputV, *latentVectors_->getW());
inputMulFactor_->square2(*tmpOut_); inputMulFactor_->square2(*tmpOut_);
outV->sumRows(*tmpOut_, 0.5, 0); outV->sumRows(*tmpOut_, 0.5, 0);
inputSquare_ = inputV->clone(0, 0, useGpu_); if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
if (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get())) { Matrix::resizeOrCreateSparseMatrix(inputSquare_,
inputV->getHeight(),
inputV->getWidth(),
inputV->getElementCnt(),
inputV->getValueType());
inputSquare_->copyFrom(*inputV); inputSquare_->copyFrom(*inputV);
(dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2(); (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
} else { } else {
Matrix::resizeOrCreate(
inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
inputV->square2(*inputSquare_); inputV->square2(*inputSquare_);
} }
latentVectors_->getW()->square2(*latentVectorsSquare_); latentVectors_->getW()->square2(*latentVectorsSquare_);
...@@ -75,7 +81,7 @@ void FactorizationMachineLayer::forward(PassType passType) { ...@@ -75,7 +81,7 @@ void FactorizationMachineLayer::forward(PassType passType) {
outV->sumRows(*tmpOut_, -0.5, 1.0); outV->sumRows(*tmpOut_, -0.5, 1.0);
/* activation */ { /* activation */ {
REGISTER_TIMER_INFO("FmAtvTimer", getName().c_str()); REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str());
forwardActivation(); forwardActivation();
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册