From 71a316ea1f79fe0fef451d98ee7e89e6abcdca7c Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Dec 2016 23:12:22 +0800 Subject: [PATCH] Remove unused cost parameter in ParameterUpdater --- paddle/parameter/ParameterUpdaterBase.h | 6 +++--- paddle/trainer/ParameterUpdater.h | 8 ++++---- paddle/trainer/RemoteParameterUpdater.cpp | 4 ++-- paddle/trainer/RemoteParameterUpdater.h | 4 ++-- paddle/trainer/ThreadParameterUpdater.cpp | 2 +- paddle/trainer/ThreadParameterUpdater.h | 2 +- paddle/trainer/Trainer.cpp | 2 +- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/parameter/ParameterUpdaterBase.h b/paddle/parameter/ParameterUpdaterBase.h index 5401046f67..d13dbe33cf 100644 --- a/paddle/parameter/ParameterUpdaterBase.h +++ b/paddle/parameter/ParameterUpdaterBase.h @@ -38,7 +38,7 @@ public: virtual void startPass() {} // called by Trainer then finishing a pass, ruturn true if pass accepted - virtual bool finishPass(real cost = 0) { return true; } + virtual bool finishPass() { return true; } // called by Trainer before backward() of a batch // Return the type of pass it needs. This pass type will be passed @@ -112,9 +112,9 @@ public: [&](int tid, size_t numThreads) { updaters_[tid]->startPass(); }); } - virtual bool finishPass(real cost = 0) { + virtual bool finishPass() { syncThreadPool_->execPlusOwner( - [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); }); + [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); }); return true; } diff --git a/paddle/trainer/ParameterUpdater.h b/paddle/trainer/ParameterUpdater.h index e52b5cd318..9e62580ccb 100644 --- a/paddle/trainer/ParameterUpdater.h +++ b/paddle/trainer/ParameterUpdater.h @@ -102,9 +102,9 @@ public: * @param cost sum cost during one pass. * @return true if accept (used for owlqn). */ - virtual bool finishPass(real cost) { + virtual bool finishPass() { optimizer_->finishPass(); - return ParameterUpdater::finishPass(cost); + return ParameterUpdater::finishPass(); } /** @@ -220,9 +220,9 @@ public: averager_->startPass(); SgdLocalUpdater::startPass(); } - virtual bool finishPass(real cost) { + virtual bool finishPass() { averager_->finishPass(); - return SgdLocalUpdater::finishPass(cost); + return SgdLocalUpdater::finishPass(); } /// apply the averaged parameter to PARAMETER_VALUE diff --git a/paddle/trainer/RemoteParameterUpdater.cpp b/paddle/trainer/RemoteParameterUpdater.cpp index 974e78fa17..6ee2ed9158 100644 --- a/paddle/trainer/RemoteParameterUpdater.cpp +++ b/paddle/trainer/RemoteParameterUpdater.cpp @@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() { } } -bool RemoteParameterUpdater::finishPass(real cost) { +bool RemoteParameterUpdater::finishPass() { if (localUpdater_) { localUpdater_->finishPass(); } @@ -711,7 +711,7 @@ void SparseRemoteParameterUpdater::startPass() { } } -bool SparseRemoteParameterUpdater::finishPass(real cost) { +bool SparseRemoteParameterUpdater::finishPass() { if (config_.algorithm() == TrainAlgorithm::SGD) { parameterClient_->waitPassFinish(); } else { diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index 66055c778e..8c5d5bb66b 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -90,7 +90,7 @@ public: */ virtual void finishBatch(real cost); virtual void startPass(); - virtual bool finishPass(real cost); + virtual bool finishPass(); #ifndef PADDLE_DISABLE_TIMER virtual void setForwardbackwardTime(uint64_t delta) { @@ -281,7 +281,7 @@ public: /// send all sparse related parameters to all pservers virtual void finishBatch(real cost); virtual void startPass(); - virtual bool finishPass(real cost); + virtual bool finishPass(); virtual void apply(); virtual void restore(); diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index 049022b1f1..36d42ed7e9 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() { } } -bool SgdThreadUpdater::finishPass(real cost) { +bool SgdThreadUpdater::finishPass() { catchUpWith(); for (auto& para : parameters_) { diff --git a/paddle/trainer/ThreadParameterUpdater.h b/paddle/trainer/ThreadParameterUpdater.h index d01ac689f9..61f337ecb3 100644 --- a/paddle/trainer/ThreadParameterUpdater.h +++ b/paddle/trainer/ThreadParameterUpdater.h @@ -47,7 +47,7 @@ public: virtual void startPass(); // Use the finishPass() function of the base optimizer. - virtual bool finishPass(real cost); + virtual bool finishPass(); virtual void init(std::vector& parameters); virtual PassType startBatch(int64_t batchSize); diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 1eec2c432d..031e3b7cf1 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) { trainerInternal_.getGradientMachine()->onPassEnd(); - bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost); + bool accepted = trainerInternal_.getParameterUpdater()->finishPass(); globalStat.setThreadInfo(true); globalStat.printAllStatus(); -- GitLab