提交 80cf42c8 编写于 作者: G gongweibao

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into recordio

...@@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, ...@@ -166,11 +166,21 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config,
outArgStream_ = HPPL_STREAM_1; outArgStream_ = HPPL_STREAM_1;
start();
}
void MultiGradientMachine::start() {
for (auto& thread : threads_) { for (auto& thread : threads_) {
thread->start(); thread->start();
} }
} }
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
std::vector<const std::vector<ParameterPtr>*> std::vector<const std::vector<ParameterPtr>*>
MultiGradientMachine::getSlaveParameters() { MultiGradientMachine::getSlaveParameters() {
std::vector<const std::vector<ParameterPtr>*> vec; std::vector<const std::vector<ParameterPtr>*> vec;
...@@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() { ...@@ -326,12 +336,6 @@ void MultiGradientMachine::onPassEnd() {
} }
} }
void MultiGradientMachine::finish() {
for (auto& thread : threads_) {
thread->stop();
}
}
Evaluator* MultiGradientMachine::makeEvaluator() const { Evaluator* MultiGradientMachine::makeEvaluator() const {
return threads_[0]->getGradientMachine()->makeEvaluator(); return threads_[0]->getGradientMachine()->makeEvaluator();
} }
...@@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config, ...@@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
gradStream_ = HPPL_STREAM_2; gradStream_ = HPPL_STREAM_2;
valueStream_ = HPPL_STREAM_3; valueStream_ = HPPL_STREAM_3;
stopping_ = false; stopping_ = true;
updateCounter_ = 0; updateCounter_ = 0;
parameterUpdated_ = false; parameterUpdated_ = false;
} }
...@@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config, ...@@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config,
TrainerThread::~TrainerThread() { stop(); } TrainerThread::~TrainerThread() { stop(); }
void TrainerThread::start() { void TrainerThread::start() {
if (!stopping_) return;
stopping_ = false;
gradientMachine_->start(); gradientMachine_->start();
computeThread_.reset(new std::thread([this]() { computeThread(); })); computeThread_.reset(new std::thread([this]() { computeThread(); }));
......
...@@ -176,6 +176,10 @@ public: ...@@ -176,6 +176,10 @@ public:
explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); explicit MultiGradientMachine(const ModelConfig& config, bool useGpu);
virtual void start();
virtual void finish();
virtual void prefetch(const std::vector<Argument>& inArgs); virtual void prefetch(const std::vector<Argument>& inArgs);
virtual void forward(const std::vector<Argument>& inArgs, virtual void forward(const std::vector<Argument>& inArgs,
...@@ -193,8 +197,6 @@ public: ...@@ -193,8 +197,6 @@ public:
virtual void onPassEnd(); virtual void onPassEnd();
virtual void finish();
virtual Evaluator* makeEvaluator() const; virtual Evaluator* makeEvaluator() const;
virtual void eval(Evaluator* evaluator) const; virtual void eval(Evaluator* evaluator) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册