From 56f29658ba935e539ac8977f44a5c942cb09bc29 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 20 Dec 2016 22:12:56 +0800 Subject: [PATCH] Remove not used params in GradientMachine::start --- paddle/gserver/gradientmachines/GradientMachine.h | 6 +----- paddle/gserver/gradientmachines/MultiGradientMachine.cpp | 2 +- paddle/gserver/gradientmachines/MultiNetwork.cpp | 5 ++--- paddle/gserver/gradientmachines/MultiNetwork.h | 2 +- paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp | 6 +----- paddle/gserver/gradientmachines/ParallelNeuralNetwork.h | 2 +- paddle/gserver/tests/test_NetworkCompare.cpp | 2 +- paddle/gserver/tests/test_RecurrentGradientMachine.cpp | 2 +- paddle/trainer/Tester.cpp | 2 +- paddle/trainer/Trainer.cpp | 4 ++-- paddle/trainer/tests/test_Compare.cpp | 2 +- paddle/trainer/tests/test_CompareTwoNets.cpp | 2 +- 12 files changed, 14 insertions(+), 23 deletions(-) diff --git a/paddle/gserver/gradientmachines/GradientMachine.h b/paddle/gserver/gradientmachines/GradientMachine.h index 579eca71d..ad82869ae 100644 --- a/paddle/gserver/gradientmachines/GradientMachine.h +++ b/paddle/gserver/gradientmachines/GradientMachine.h @@ -212,11 +212,7 @@ public: * @note This function will only been implemented and used in a * multithreaded environment. */ - virtual void start(const TrainerConfig& config, - DataProviderPtr dataProvider) { - (void)config; - (void)dataProvider; - } + virtual void start() {} /** * @brief check each work-thread whether is failed/error/finish, diff --git a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp index 88c098b35..95a4c0e16 100644 --- a/paddle/gserver/gradientmachines/MultiGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/MultiGradientMachine.cpp @@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config, TrainerThread::~TrainerThread() { stop(); } void TrainerThread::start() { - gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr); + gradientMachine_->start(); computeThread_.reset(new std::thread([this]() { computeThread(); })); diff --git a/paddle/gserver/gradientmachines/MultiNetwork.cpp b/paddle/gserver/gradientmachines/MultiNetwork.cpp index 6eb3d8db9..f1308f372 100644 --- a/paddle/gserver/gradientmachines/MultiNetwork.cpp +++ b/paddle/gserver/gradientmachines/MultiNetwork.cpp @@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() { } } -void MultiNetwork::start(const TrainerConfig& config, - DataProviderPtr dataProvider) { +void MultiNetwork::start() { for (auto& subNetwork : subNetworks_) { - subNetwork->start(config, dataProvider); + subNetwork->start(); } } diff --git a/paddle/gserver/gradientmachines/MultiNetwork.h b/paddle/gserver/gradientmachines/MultiNetwork.h index 89fbf32b4..f04406b98 100644 --- a/paddle/gserver/gradientmachines/MultiNetwork.h +++ b/paddle/gserver/gradientmachines/MultiNetwork.h @@ -54,7 +54,7 @@ public: return subNetworks_; } - virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider); + virtual void start(); virtual void finish(); diff --git a/paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp b/paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp index 980a5851a..c6e3a3b32 100644 --- a/paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/ParallelNeuralNetwork.cpp @@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector& inArgs, backward(callback); } -void ParallelNeuralNetwork::start(const TrainerConfig& config, - DataProviderPtr dataProvider) { - (void)config; - (void)dataProvider; - +void ParallelNeuralNetwork::start() { for (auto& thread : threads_) { thread->start(); } diff --git a/paddle/gserver/gradientmachines/ParallelNeuralNetwork.h b/paddle/gserver/gradientmachines/ParallelNeuralNetwork.h index 8f445b1de..39f5682a5 100644 --- a/paddle/gserver/gradientmachines/ParallelNeuralNetwork.h +++ b/paddle/gserver/gradientmachines/ParallelNeuralNetwork.h @@ -56,7 +56,7 @@ public: PassType passType, const UpdateCallback &callback = NULL); - virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider); + virtual void start(); void addComputeThread(int deviceId); diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index fc60228f8..0d2610595 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) { parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]); } } - gradientMachine->start(trainer.getConfig(), nullptr); + gradientMachine->start(); gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN); for (size_t i = 0; i < in.outGrads.size(); i++) { // If the all the layers in the config have no parameters, also diff --git a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp index e19cf35cd..150850da4 100644 --- a/paddle/gserver/tests/test_RecurrentGradientMachine.cpp +++ b/paddle/gserver/tests/test_RecurrentGradientMachine.cpp @@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer { public: void startTrain() { GradientMachine& gm = *this->trainerInternal_.getGradientMachine(); - gm.start(this->getConfig(), dataProvider_); + gm.start(); } void finishTrain() { diff --git a/paddle/trainer/Tester.cpp b/paddle/trainer/Tester.cpp index 24fac3e5a..13aa28ae5 100644 --- a/paddle/trainer/Tester.cpp +++ b/paddle/trainer/Tester.cpp @@ -257,7 +257,7 @@ void Tester::test() { CHECK(testDataProvider_) << "TestData is not specified"; testDataProvider_->setSkipShuffle(); testDataProvider_->reset(); - gradientMachine_->start(*config_, testDataProvider_); + gradientMachine_->start(); // For evaluation std::vector modelList; diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 1eec2c432..6c57467cc 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) { } real Trainer::checkGradient() { - trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); + trainerInternal_.getGradientMachine()->start(); std::vector& parameters = trainerInternal_.getGradientMachine()->getNonStaticParameters(); DataBatch dataBatch; @@ -390,7 +390,7 @@ void Trainer::startTrain() { dataProvider_->reset(); } - trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); + trainerInternal_.getGradientMachine()->start(); } void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); } diff --git a/paddle/trainer/tests/test_Compare.cpp b/paddle/trainer/tests/test_Compare.cpp index 72fc76bea..e855a8fe2 100644 --- a/paddle/trainer/tests/test_Compare.cpp +++ b/paddle/trainer/tests/test_Compare.cpp @@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) { trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch); CHECK(dataBatch.getSize()) << "No data from data provider"; vector& inArgs = dataBatch.getStreams(); - trainer.getGradientMachine()->start(trainer.getConfig(), nullptr); + trainer.getGradientMachine()->start(); for (int i = 0; i < 2; ++i) { trainer.getGradientMachine()->forwardBackward( inArgs, &Data.outArgs, PASS_TRAIN); diff --git a/paddle/trainer/tests/test_CompareTwoNets.cpp b/paddle/trainer/tests/test_CompareTwoNets.cpp index 80c61e259..94f65e545 100644 --- a/paddle/trainer/tests/test_CompareTwoNets.cpp +++ b/paddle/trainer/tests/test_CompareTwoNets.cpp @@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) { CHECK(dataBatch.getSize()) << "No data from data provider"; vector& inArgs = dataBatch.getStreams(); - trainer.getGradientMachine()->start(trainer.getConfig(), nullptr); + trainer.getGradientMachine()->start(); trainer.getGradientMachine()->forwardBackward( inArgs, &data.outArgs, PASS_TRAIN); -- GitLab