diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 123273f916f5d33e2543d9f5f28573c3b5761e28..4654d0206413ec198da62af12e294cd5b442e735 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const { void MultiGradientMachine::eval(Evaluator* evaluator) const { for (auto& thread : threads_) { SetDevice device(thread->getDeviceId()); - thread->getGradientMachine()->eval(evaluator); + if (thread->hasInputData()) { + thread->getGradientMachine()->eval(evaluator); + } } } @@ -356,14 +358,19 @@ void MultiGradientMachine::getOutArgs(std::vector* outArgs, REGISTER_TIMER("waitOutArgs"); thread->waitOutArgsReady(); } - outArgs_.resize(threads_[0]->getOutArgs().size()); + + outArgs_.resize(threads_[threads_.size() - 1]->getOutArgs().size()); REGISTER_TIMER("copyOutArgs"); for (size_t i = 0; i < outArgs_.size(); ++i) { std::vector args; args.reserve(threads_.size()); for (auto& thread : threads_) { - args.push_back(thread->getOutArgs()[i]); + // If the thread input is empty, then the output is empty. + auto tmp = thread->getOutArgs(); + if (tmp.size() > 0) { + args.push_back(tmp[i]); + } } outArgs_[i].concat(args, useGpu_, outArgStream_, passType); } @@ -534,7 +541,7 @@ void TrainerThread::prefetch() { void TrainerThread::forward() { if (!inArgsCopied_) { REGISTER_TIMER("copyInArgs"); - copyInArgs(); + batchSize_ = copyInArgs(); } else { inArgsCopied_ = false; } @@ -564,7 +571,12 @@ void TrainerThread::forward() { { REGISTER_TIMER("thread_forward"); - gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType()); + if (batchSize_ > 0) { + gradientMachine_->forward( + inArgs_, &outArgs_, multiMachine_->getPassType()); + } else { + outArgs_.clear(); + } } outArgsReadySem_.post(); } @@ -574,7 +586,13 @@ void TrainerThread::backward() { if (multiMachine_->isPassGrad()) { copyOutputGrad(); } - gradientMachine_->backward(backwardCallback_); + if (batchSize_ > 0) { + gradientMachine_->backward(backwardCallback_); + } else { + for (size_t i = parameters_.size(); i > 0; i--) { + backwardCallback(parameters_[i - 1].get()); + } + } if (multiMachine_->hasNonstaticCpuParamters()) { mergeCpuGradients(); } @@ -732,7 +750,7 @@ void TrainerThread::notifyValueReady(int paramId) { notifyValueDispatch(paramId); } -void TrainerThread::copyInArgs() { +int TrainerThread::copyInArgs() { const std::vector& fullInArgs = multiMachine_->getInArgs(); int numThreads = multiMachine_->getAllThreads().size(); int32_t numSequences = fullInArgs[0].getNumSequences(); @@ -748,7 +766,7 @@ void TrainerThread::copyInArgs() { } if (copySize == 0) { - return; + return 0; } for (size_t i = 0; i < fullInArgs.size(); i++) { @@ -758,6 +776,7 @@ void TrainerThread::copyInArgs() { copySize, FLAGS_parallel_nn ? false : multiMachine_->useGpu()); } + return copySize; } void TrainerThread::mergeCpuGradients() { diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index 838a52b5153af63adbce5788824b9f541f22517c..70203bbb97fe79d72fbc6bd2b5d427cb1de7b61f 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -387,6 +387,9 @@ public: /// copy the output gradient from the main GradientMachine. void copyOutputGrad(); + /// Whether the thread has input data. + bool hasInputData() { return batchSize_ != 0; } + protected: void mergeCpuGradients(); @@ -407,7 +410,7 @@ protected: void copyGradToBufferThread(); void gradCollectThread(); - void copyInArgs(); + int copyInArgs(); void forward(); void backward(); void backwardCallback(Parameter* para); @@ -467,6 +470,7 @@ protected: /// indicate whether inArgs is copied before forward() bool inArgsCopied_; + int batchSize_; }; } // namespace paddle