diff --git a/paddle/parameter/ParameterUpdaterBase.h b/paddle/parameter/ParameterUpdaterBase.h index 88148d9b769e9b6eca90f9651a121e926543d7c2..b230e170c15f1b004c5357fb7d0ad2204d01f44b 100644 --- a/paddle/parameter/ParameterUpdaterBase.h +++ b/paddle/parameter/ParameterUpdaterBase.h @@ -38,7 +38,7 @@ public: virtual void startPass() {} // called by Trainer then finishing a pass, ruturn true if pass accepted - virtual bool finishPass(real cost = 0) { return true; } + virtual bool finishPass() { return true; } // called by Trainer before backward() of a batch // Return the type of pass it needs. This pass type will be passed @@ -112,9 +112,9 @@ public: [&](int tid, size_t numThreads) { updaters_[tid]->startPass(); }); } - virtual bool finishPass(real cost = 0) { + virtual bool finishPass() { syncThreadPool_->execPlusOwner( - [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); }); + [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); }); return true; } diff --git a/paddle/trainer/ParameterUpdater.h b/paddle/trainer/ParameterUpdater.h index 4dae77567f8f4d097c583567275d4b90122feb6a..c3207e63ce72b73a57c2e40c72c5259f0ae61bc9 100644 --- a/paddle/trainer/ParameterUpdater.h +++ b/paddle/trainer/ParameterUpdater.h @@ -102,9 +102,9 @@ public: * @param cost sum cost during one pass. * @return true if accept (used for owlqn). */ - virtual bool finishPass(real cost) { + virtual bool finishPass() { optimizer_->finishPass(); - return ParameterUpdater::finishPass(cost); + return ParameterUpdater::finishPass(); } /** @@ -220,9 +220,9 @@ public: averager_->startPass(); SgdLocalUpdater::startPass(); } - virtual bool finishPass(real cost) { + virtual bool finishPass() { averager_->finishPass(); - return SgdLocalUpdater::finishPass(cost); + return SgdLocalUpdater::finishPass(); } /// apply the averaged parameter to PARAMETER_VALUE diff --git a/paddle/trainer/RemoteParameterUpdater.cpp b/paddle/trainer/RemoteParameterUpdater.cpp index 630f55d998d9f5b5b2880aa02b025e6e56e1f064..6939738203f41e0c1f7204d54834e34b2cd90682 100644 --- a/paddle/trainer/RemoteParameterUpdater.cpp +++ b/paddle/trainer/RemoteParameterUpdater.cpp @@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() { } } -bool RemoteParameterUpdater::finishPass(real cost) { +bool RemoteParameterUpdater::finishPass() { if (localUpdater_) { localUpdater_->finishPass(); } @@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() { } } -bool SparseRemoteParameterUpdater::finishPass(real cost) { +bool SparseRemoteParameterUpdater::finishPass() { if (config_.algorithm() == TrainAlgorithm::SGD) { parameterClient_->waitPassFinish(); } else { diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index ec6ed443d33db1d695194092b34d6090a4b5ab94..7794b209009a3429e810074b61e1d5bffa8b3a4e 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -90,7 +90,7 @@ public: */ virtual void finishBatch(real cost); virtual void startPass(); - virtual bool finishPass(real cost); + virtual bool finishPass(); #ifndef PADDLE_DISABLE_TIMER virtual void setForwardbackwardTime(uint64_t delta) { @@ -281,7 +281,7 @@ public: /// send all sparse related parameters to all pservers virtual void finishBatch(real cost); virtual void startPass(); - virtual bool finishPass(real cost); + virtual bool finishPass(); virtual void apply(); virtual void restore(); diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index 2a76d5723ccb68896f8ddbfad31a9d7d84adcf55..870d4a4b0246fe244bbd3796ec14449eb181aad2 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() { } } -bool SgdThreadUpdater::finishPass(real cost) { +bool SgdThreadUpdater::finishPass() { catchUpWith(); for (auto& para : parameters_) { diff --git a/paddle/trainer/ThreadParameterUpdater.h b/paddle/trainer/ThreadParameterUpdater.h index 198435c0f30056a9467b8a076c8443ae243e7c3f..880f1f9ddc49a1193ce23901419d988cae84eb88 100644 --- a/paddle/trainer/ThreadParameterUpdater.h +++ b/paddle/trainer/ThreadParameterUpdater.h @@ -47,7 +47,7 @@ public: virtual void startPass(); // Use the finishPass() function of the base optimizer. - virtual bool finishPass(real cost); + virtual bool finishPass(); virtual void init(const std::vector& parameters); virtual PassType startBatch(int64_t batchSize); diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 1eec2c432d235ef484b688db08aae8a39f878a85..031e3b7cf199c197a9cafbf809afe4dfab15b87b 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) { trainerInternal_.getGradientMachine()->onPassEnd(); - bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost); + bool accepted = trainerInternal_.getParameterUpdater()->finishPass(); globalStat.setThreadInfo(true); globalStat.printAllStatus();