提交 101a9a4e 编写于 作者: H hedaoyuan 提交者: GitHub

Merge pull request #1566 from hedaoyuan/multi-gradient-machine-error

Fix MultiGradientMachine error
......@@ -346,7 +346,9 @@ Evaluator* MultiGradientMachine::makeEvaluator() const {
void MultiGradientMachine::eval(Evaluator* evaluator) const {
for (auto& thread : threads_) {
SetDevice device(thread->getDeviceId());
thread->getGradientMachine()->eval(evaluator);
if (thread->hasInputData()) {
thread->getGradientMachine()->eval(evaluator);
}
}
}
......@@ -356,14 +358,19 @@ void MultiGradientMachine::getOutArgs(std::vector<Argument>* outArgs,
REGISTER_TIMER("waitOutArgs");
thread->waitOutArgsReady();
}
outArgs_.resize(threads_[0]->getOutArgs().size());
outArgs_.resize(threads_[threads_.size() - 1]->getOutArgs().size());
REGISTER_TIMER("copyOutArgs");
for (size_t i = 0; i < outArgs_.size(); ++i) {
std::vector<Argument> args;
args.reserve(threads_.size());
for (auto& thread : threads_) {
args.push_back(thread->getOutArgs()[i]);
// If the thread input is empty, then the output is empty.
auto tmp = thread->getOutArgs();
if (tmp.size() > 0) {
args.push_back(tmp[i]);
}
}
outArgs_[i].concat(args, useGpu_, outArgStream_, passType);
}
......@@ -534,7 +541,7 @@ void TrainerThread::prefetch() {
void TrainerThread::forward() {
if (!inArgsCopied_) {
REGISTER_TIMER("copyInArgs");
copyInArgs();
batchSize_ = copyInArgs();
} else {
inArgsCopied_ = false;
}
......@@ -564,7 +571,12 @@ void TrainerThread::forward() {
{
REGISTER_TIMER("thread_forward");
gradientMachine_->forward(inArgs_, &outArgs_, multiMachine_->getPassType());
if (batchSize_ > 0) {
gradientMachine_->forward(
inArgs_, &outArgs_, multiMachine_->getPassType());
} else {
outArgs_.clear();
}
}
outArgsReadySem_.post();
}
......@@ -574,7 +586,13 @@ void TrainerThread::backward() {
if (multiMachine_->isPassGrad()) {
copyOutputGrad();
}
gradientMachine_->backward(backwardCallback_);
if (batchSize_ > 0) {
gradientMachine_->backward(backwardCallback_);
} else {
for (size_t i = parameters_.size(); i > 0; i--) {
backwardCallback(parameters_[i - 1].get());
}
}
if (multiMachine_->hasNonstaticCpuParamters()) {
mergeCpuGradients();
}
......@@ -732,7 +750,7 @@ void TrainerThread::notifyValueReady(int paramId) {
notifyValueDispatch(paramId);
}
void TrainerThread::copyInArgs() {
int TrainerThread::copyInArgs() {
const std::vector<Argument>& fullInArgs = multiMachine_->getInArgs();
int numThreads = multiMachine_->getAllThreads().size();
int32_t numSequences = fullInArgs[0].getNumSequences();
......@@ -748,7 +766,7 @@ void TrainerThread::copyInArgs() {
}
if (copySize == 0) {
return;
return 0;
}
for (size_t i = 0; i < fullInArgs.size(); i++) {
......@@ -758,6 +776,7 @@ void TrainerThread::copyInArgs() {
copySize,
FLAGS_parallel_nn ? false : multiMachine_->useGpu());
}
return copySize;
}
void TrainerThread::mergeCpuGradients() {
......
......@@ -387,6 +387,9 @@ public:
/// copy the output gradient from the main GradientMachine.
void copyOutputGrad();
/// Whether the thread has input data.
bool hasInputData() { return batchSize_ != 0; }
protected:
void mergeCpuGradients();
......@@ -407,7 +410,7 @@ protected:
void copyGradToBufferThread();
void gradCollectThread();
void copyInArgs();
int copyInArgs();
void forward();
void backward();
void backwardCallback(Parameter* para);
......@@ -467,6 +470,7 @@ protected:
/// indicate whether inArgs is copied before forward()
bool inArgsCopied_;
int batchSize_;
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册