提交 677c79b6 编写于 作者: Y Yu Yang

Merge branch 'feature/clean_parameter_updater_finish_pass' into feature/mnist_train_api

......@@ -38,7 +38,7 @@ public:
virtual void startPass() {}
// called by Trainer then finishing a pass, ruturn true if pass accepted
virtual bool finishPass(real cost = 0) { return true; }
virtual bool finishPass() { return true; }
// called by Trainer before backward() of a batch
// Return the type of pass it needs. This pass type will be passed
......@@ -112,9 +112,9 @@ public:
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
}
virtual bool finishPass(real cost = 0) {
virtual bool finishPass() {
syncThreadPool_->execPlusOwner(
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); });
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); });
return true;
}
......
......@@ -102,9 +102,9 @@ public:
* @param cost sum cost during one pass.
* @return true if accept (used for owlqn).
*/
virtual bool finishPass(real cost) {
virtual bool finishPass() {
optimizer_->finishPass();
return ParameterUpdater::finishPass(cost);
return ParameterUpdater::finishPass();
}
/**
......@@ -220,9 +220,9 @@ public:
averager_->startPass();
SgdLocalUpdater::startPass();
}
virtual bool finishPass(real cost) {
virtual bool finishPass() {
averager_->finishPass();
return SgdLocalUpdater::finishPass(cost);
return SgdLocalUpdater::finishPass();
}
/// apply the averaged parameter to PARAMETER_VALUE
......
......@@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() {
}
}
bool RemoteParameterUpdater::finishPass(real cost) {
bool RemoteParameterUpdater::finishPass() {
if (localUpdater_) {
localUpdater_->finishPass();
}
......@@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() {
}
}
bool SparseRemoteParameterUpdater::finishPass(real cost) {
bool SparseRemoteParameterUpdater::finishPass() {
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassFinish();
} else {
......
......@@ -90,7 +90,7 @@ public:
*/
virtual void finishBatch(real cost);
virtual void startPass();
virtual bool finishPass(real cost);
virtual bool finishPass();
#ifndef PADDLE_DISABLE_TIMER
virtual void setForwardbackwardTime(uint64_t delta) {
......@@ -281,7 +281,7 @@ public:
/// send all sparse related parameters to all pservers
virtual void finishBatch(real cost);
virtual void startPass();
virtual bool finishPass(real cost);
virtual bool finishPass();
virtual void apply();
virtual void restore();
......
......@@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() {
}
}
bool SgdThreadUpdater::finishPass(real cost) {
bool SgdThreadUpdater::finishPass() {
catchUpWith();
for (auto& para : parameters_) {
......
......@@ -47,7 +47,7 @@ public:
virtual void startPass();
// Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost);
virtual bool finishPass();
virtual void init(const std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize);
......
......@@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) {
trainerInternal_.getGradientMachine()->onPassEnd();
bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost);
bool accepted = trainerInternal_.getParameterUpdater()->finishPass();
globalStat.setThreadInfo(true);
globalStat.printAllStatus();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册