diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 6ae60102b3e431727c0954e8b8073bfe0534f8ee..3159026e6b92355ba7480b09535388c969a504e2 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -518,7 +518,7 @@ void TrainerThread::computeThread() { backward(); break; case MultiGradientMachine::TASK_COPY_IN_ARGS: - copyInArgs(); + batchSize_ = copyInArgs(); inArgsCopied_ = true; multiMachine_->waitForCopyInArgs(); break;