diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index e7db487e7e896183003eb744aa8304f52b09d76c..faadca69abd76e96b31103da8af53eb9e9cddc43 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -358,6 +358,7 @@ void MultiGradientMachine::getOutArgs(std::vector* outArgs, REGISTER_TIMER("waitOutArgs"); thread->waitOutArgsReady(); } + // outArgs_.size() only need to be calculated once. static int size = threads_[threads_.size() - 1]->getOutArgs().size(); outArgs_.resize(size); @@ -574,9 +575,9 @@ void TrainerThread::forward() { REGISTER_TIMER("thread_forward"); if (batchSize_ > 0) { gradientMachine_->forward( - inArgs_, &outArgs_, multiMachine_->getPassType()); + inArgs_, &outArgs_, multiMachine_->getPassType()); } else { - outArgs_.clear(); + outArgs_.clear(); } } outArgsReadySem_.post(); diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index 31bb28b6fedc901e502371b3801e1a5afe32a932..70203bbb97fe79d72fbc6bd2b5d427cb1de7b61f 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -470,7 +470,6 @@ protected: /// indicate whether inArgs is copied before forward() bool inArgsCopied_; - int batchSize_; };