提交 101a9a4e 编写于 作者: H hedaoyuan 提交者: GitHub

Merge pull request #1566 from hedaoyuan/multi-gradient-machine-error

Fix MultiGradientMachine error
...@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const { ...@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
void MultiGradientMachine::eval(Evaluator* evaluator) const { void MultiGradientMachine::eval(Evaluator* evaluator) const {
for (auto& thread : threads_) { for (auto& thread : threads_) {
SetDevice device(thread->getDeviceId()); SetDevice device(thread->getDeviceId());
thread->getGradientMachine()->eval(evaluator); if (thread->hasInputData()) {
thread->getGradientMachine()->eval(evaluator);
}
} }
} }
...@@ -356,14 +358,19 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs, ...@@ -356,14 +358,19 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
REGISTER_TIMER("waitOutArgs"); REGISTER_TIMER("waitOutArgs");
thread->waitOutArgsReady(); thread->waitOutArgsReady();
} }
outArgs_.resize(threads_[0]->getOutArgs().size());
outArgs_.resize(threads_[threads_.size() - 1]->getOutArgs().size());
REGISTER_TIMER("copyOutArgs"); REGISTER_TIMER("copyOutArgs");
for (size_t i = 0; i < outArgs_.size(); ++i) { for (size_t i = 0; i < outArgs_.size(); ++i) {
std::vector<Argument> args; std::vector<Argument> args;
args.reserve(threads_.size()); args.reserve(threads_.size());
for (auto& thread : threads_) { for (auto& thread : threads_) {
args.push_back(thread->getOutArgs()[i]); // If the thread input is empty, then the output is empty.
auto tmp = thread->getOutArgs();
if (tmp.size() > 0) {
args.push_back(tmp[i]);
}
} }
outArgs_[i].concat(args, useGpu_, outArgStream_, passType); outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
} }
...@@ -534,7 +541,7 @@ void TrainerThread::prefetch() { ...@@ -534,7 +541,7 @@ void TrainerThread::prefetch() {
void TrainerThread::forward() { void TrainerThread::forward() {
if (!inArgsCopied_) { if (!inArgsCopied_) {
REGISTER_TIMER("copyInArgs"); REGISTER_TIMER("copyInArgs");
copyInArgs(); batchSize_ = copyInArgs();
} else { } else {
inArgsCopied_ = false; inArgsCopied_ = false;
} }
...@@ -564,7 +571,12 @@ void TrainerThread::forward() { ...@@ -564,7 +571,12 @@ void TrainerThread::forward() {
{ {
REGISTER_TIMER("thread_forward"); REGISTER_TIMER("thread_forward");
gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType()); if (batchSize_ > 0) {
gradientMachine_->forward(
inArgs_, &outArgs_, multiMachine_->getPassType());
} else {
outArgs_.clear();
}
} }
outArgsReadySem_.post(); outArgsReadySem_.post();
} }
...@@ -574,7 +586,13 @@ void TrainerThread::backward() { ...@@ -574,7 +586,13 @@ void TrainerThread::backward() {
if (multiMachine_->isPassGrad()) { if (multiMachine_->isPassGrad()) {
copyOutputGrad(); copyOutputGrad();
} }
gradientMachine_->backward(backwardCallback_); if (batchSize_ > 0) {
gradientMachine_->backward(backwardCallback_);
} else {
for (size_t i = parameters_.size(); i > 0; i--) {
backwardCallback(parameters_[i - 1].get());
}
}
if (multiMachine_->hasNonstaticCpuParamters()) { if (multiMachine_->hasNonstaticCpuParamters()) {
mergeCpuGradients(); mergeCpuGradients();
} }
...@@ -732,7 +750,7 @@ void TrainerThread::notifyValueReady(int paramId) { ...@@ -732,7 +750,7 @@ void TrainerThread::notifyValueReady(int paramId) {
notifyValueDispatch(paramId); notifyValueDispatch(paramId);
} }
void TrainerThread::copyInArgs() { int TrainerThread::copyInArgs() {
const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs(); const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
int numThreads = multiMachine_->getAllThreads().size(); int numThreads = multiMachine_->getAllThreads().size();
int32_t numSequences = fullInArgs[0].getNumSequences(); int32_t numSequences = fullInArgs[0].getNumSequences();
...@@ -748,7 +766,7 @@ void TrainerThread::copyInArgs() { ...@@ -748,7 +766,7 @@ void TrainerThread::copyInArgs() {
} }
if (copySize == 0) { if (copySize == 0) {
return; return 0;
} }
for (size_t i = 0; i < fullInArgs.size(); i++) { for (size_t i = 0; i < fullInArgs.size(); i++) {
...@@ -758,6 +776,7 @@ void TrainerThread::copyInArgs() { ...@@ -758,6 +776,7 @@ void TrainerThread::copyInArgs() {
copySize, copySize,
FLAGS_parallel_nn ? false : multiMachine_->useGpu()); FLAGS_parallel_nn ? false : multiMachine_->useGpu());
} }
return copySize;
} }
void TrainerThread::mergeCpuGradients() { void TrainerThread::mergeCpuGradients() {
......
...@@ -387,6 +387,9 @@ public: ...@@ -387,6 +387,9 @@ public:
/// copy the output gradient from the main GradientMachine. /// copy the output gradient from the main GradientMachine.
void copyOutputGrad(); void copyOutputGrad();
/// Whether the thread has input data.
bool hasInputData() { return batchSize_ != 0; }
protected: protected:
void mergeCpuGradients(); void mergeCpuGradients();
...@@ -407,7 +410,7 @@ protected: ...@@ -407,7 +410,7 @@ protected:
void copyGradToBufferThread(); void copyGradToBufferThread();
void gradCollectThread(); void gradCollectThread();
void copyInArgs(); int copyInArgs();
void forward(); void forward();
void backward(); void backward();
void backwardCallback(Parameter* para); void backwardCallback(Parameter* para);
...@@ -467,6 +470,7 @@ protected: ...@@ -467,6 +470,7 @@ protected:
/// indicate whether inArgs is copied before forward() /// indicate whether inArgs is copied before forward()
bool inArgsCopied_; bool inArgsCopied_;
int batchSize_;
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册