diff --git a/paddle/gserver/layers/FactorizationMachineLayer.cpp b/paddle/gserver/layers/FactorizationMachineLayer.cpp
index b665fb6dfc4a0971a0bc67ffe3ddfc60ac6bbbe3..be26b9ba88c279036f73b0a0baaff164755fe067 100644
--- a/paddle/gserver/layers/FactorizationMachineLayer.cpp
+++ b/paddle/gserver/layers/FactorizationMachineLayer.cpp
@@ -58,16 +58,22 @@ void FactorizationMachineLayer::forward(PassType passType) {
       inputMulFactor_, batchSize, factorSize_, false, useGpu_);
   Matrix::resizeOrCreate(tmpOut_, batchSize, factorSize_, false, useGpu_);
 
-  REGISTER_TIMER_INFO("InputMulFactorTimer", getName().c_str());
+  REGISTER_TIMER_INFO("FmInputMulFactorTimer", getName().c_str());
   inputMulFactor_->mul(*inputV, *latentVectors_->getW());
   inputMulFactor_->square2(*tmpOut_);
   outV->sumRows(*tmpOut_, 0.5, 0);
 
-  inputSquare_ = inputV->clone(0, 0, useGpu_);
-  if (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get())) {
+  if (dynamic_cast<CpuSparseMatrix*>(inputV.get())) {
+    Matrix::resizeOrCreateSparseMatrix(inputSquare_,
+                                       inputV->getHeight(),
+                                       inputV->getWidth(),
+                                       inputV->getElementCnt(),
+                                       inputV->getValueType());
     inputSquare_->copyFrom(*inputV);
     (dynamic_cast<CpuSparseMatrix*>(inputSquare_.get()))->square2();
   } else {
+    Matrix::resizeOrCreate(
+        inputSquare_, inputV->getHeight(), inputV->getWidth(), false, useGpu_);
     inputV->square2(*inputSquare_);
   }
   latentVectors_->getW()->square2(*latentVectorsSquare_);
@@ -75,7 +81,7 @@ void FactorizationMachineLayer::forward(PassType passType) {
   outV->sumRows(*tmpOut_, -0.5, 1.0);
 
   /* activation */ {
-    REGISTER_TIMER_INFO("FmAtvTimer", getName().c_str());
+    REGISTER_TIMER_INFO("FmFwAtvTimer", getName().c_str());
     forwardActivation();
   }
 }