diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 9abda18d544fe252296cfa2e3a4484f1ca3f1e37..8ef5e9d0c116dd088b5c5c318dfb47c245b471fa 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -166,12 +166,16 @@ MultiGradientMachine::MultiGradientMachine(const ModelConfig& config, outArgStream_ = HPPL_STREAM_1; + start(); +} + +void MultiGradientMachine::start() { for (auto& thread : threads_) { thread->start(); } } -MultiGradientMachine::~MultiGradientMachine() { +void MultiGradientMachine::finish() { for (auto& thread : threads_) { thread->stop(); } @@ -445,7 +449,7 @@ TrainerThread::TrainerThread(const ModelConfig& config, gradStream_ = HPPL_STREAM_2; valueStream_ = HPPL_STREAM_3; - stopping_ = false; + stopping_ = true; updateCounter_ = 0; parameterUpdated_ = false; } @@ -453,6 +457,10 @@ TrainerThread::TrainerThread(const ModelConfig& config, TrainerThread::~TrainerThread() { stop(); } void TrainerThread::start() { + if (!stopping_) return; + + stopping_ = false; + gradientMachine_->start(); computeThread_.reset(new std::thread([this]() { computeThread(); })); diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.h b/paddle/gserver/gradientmachines/MultiGradientMachine.h index c005c0ed67fed287e780772683c6b0af9c7c73b7..5e7622f929fd57de6e38855528a752b5586c4cd1 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.h +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.h @@ -176,7 +176,9 @@ public: explicit MultiGradientMachine(const ModelConfig& config, bool useGpu); - virtual ~MultiGradientMachine(); + virtual void start(); + + virtual void finish(); virtual void prefetch(const std::vector& inArgs);