diff --git a/paddle/gserver/activations/MKLDNNActivation.cpp b/paddle/gserver/activations/MKLDNNActivation.cpp index 3b32d7e2d83f4c17725d4bbcb3303ed9721b0527..18c5638100065109fba1f0647a1c5f91256f7b9d 100644 --- a/paddle/gserver/activations/MKLDNNActivation.cpp +++ b/paddle/gserver/activations/MKLDNNActivation.cpp @@ -189,27 +189,19 @@ Error __must_check MKLDNNSoftmaxActivation::forward(Argument& act) { Error __must_check MKLDNNSoftmaxActivation::backward(Argument& act) { MatrixPtr outputV = act.value; MatrixPtr outputG = act.grad; - - if (outputG->useGpu()) { - outputG->softmaxBackward(*outputV); - } else { - SetDevice device(act.deviceId); - Matrix::resizeOrCreate(sftMaxDot_, - outputG->getHeight(), - outputG->getWidth(), - /* trans */ false, - useGpu(act.deviceId)); - Matrix::resizeOrCreate(sftMaxSum_, - outputG->getHeight(), - 1, - /* trans */ false, - useGpu(act.deviceId)); - - sftMaxDot_->dotMul(*outputG, *outputV); - sftMaxSum_->colMerge(*sftMaxDot_); - - act.grad->softmaxDerivative(*act.value, *sftMaxSum_); - } + Matrix::resizeOrCreate(sftMaxDot_, + outputG->getHeight(), + outputG->getWidth(), + /* trans */ false, + /* useGpu */ false); + Matrix::resizeOrCreate(sftMaxSum_, + outputG->getHeight(), + 1, + /* trans */ false, + /* useGpu */ false); + sftMaxDot_->dotMul(*outputG, *outputV); + sftMaxSum_->colMerge(*sftMaxDot_); + act.grad->softmaxDerivative(*act.value, *sftMaxSum_); return Error(); }