Commit adc58397 authored by Yu Yang, committed by GitHub

Merge pull request #969 from reyoung/feature/clean_gradient_machine_start

Remove unused params in GradientMachine::start
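In short, the base-class hook goes from taking two ignored arguments to taking none. A minimal compilable sketch of the before/after; `TrainerConfig` and `DataProviderPtr` are stand-ins here, not the real Paddle declarations:

```cpp
#include <memory>

struct TrainerConfig {};
struct DataProvider {};
using DataProviderPtr = std::shared_ptr<DataProvider>;

struct GradientMachineBefore {
  // Before: every caller had to supply arguments the base class ignored.
  virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider) {
    (void)config;        // silenced "unused parameter" warnings...
    (void)dataProvider;  // ...a hint that the parameters were dead weight
  }
  virtual ~GradientMachineBefore() = default;
};

struct GradientMachineAfter {
  // After: no parameters; implementations that need a config or a data
  // provider obtain them through their own constructors instead.
  virtual void start() {}
  virtual ~GradientMachineAfter() = default;
};
```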
@@ -212,11 +212,7 @@ public:
    * @note This function will only been implemented and used in a
    *       multithreaded environment.
    */
-  virtual void start(const TrainerConfig& config,
-                     DataProviderPtr dataProvider) {
-    (void)config;
-    (void)dataProvider;
-  }
+  virtual void start() {}

   /**
    * @brief check each work-thread whether is failed/error/finish,
...
@@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
 TrainerThread::~TrainerThread() { stop(); }

 void TrainerThread::start() {
-  gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr);
+  gradientMachine_->start();

   computeThread_.reset(new std::thread([this]() { computeThread(); }));
...
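The call site removed above deserves a note: `*(TrainerConfig*)nullptr` forms a reference from a dereferenced null pointer, which is undefined behavior in C++ even when the callee never reads the parameter, so dropping the parameters also removes a latent bug. A small stand-alone illustration with stand-in types, not the actual Paddle code:

```cpp
struct TrainerConfig {};

// The callee ignores its parameter entirely...
void start(const TrainerConfig& config) { (void)config; }

int main() {
  // ...yet the old-style call would still be undefined behavior: forming
  // the reference dereferences a null pointer before start() even runs.
  // start(*(TrainerConfig*)nullptr);   // UB: never do this
  TrainerConfig cfg;
  start(cfg);  // well-defined; removing the parameter makes even this moot
}
```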
@@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
   }
 }

-void MultiNetwork::start(const TrainerConfig& config,
-                         DataProviderPtr dataProvider) {
+void MultiNetwork::start() {
   for (auto& subNetwork : subNetworks_) {
-    subNetwork->start(config, dataProvider);
+    subNetwork->start();
   }
 }
...
@@ -54,7 +54,7 @@ public:
     return subNetworks_;
   }

-  virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider);
+  virtual void start();

   virtual void finish();
...
@@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
   backward(callback);
 }

-void ParallelNeuralNetwork::start(const TrainerConfig& config,
-                                  DataProviderPtr dataProvider) {
-  (void)config;
-  (void)dataProvider;
+void ParallelNeuralNetwork::start() {
   for (auto& thread : threads_) {
     thread->start();
   }
...
@@ -56,7 +56,7 @@ public:
                        PassType passType,
                        const UpdateCallback &callback = NULL);

-  virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider);
+  virtual void start();

   void addComputeThread(int deviceId);
...
@@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
       parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
     }
   }
-  gradientMachine->start(trainer.getConfig(), nullptr);
+  gradientMachine->start();
   gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
   for (size_t i = 0; i < in.outGrads.size(); i++) {
     // If the all the layers in the config have no parameters, also
...
@@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
 public:
   void startTrain() {
     GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
-    gm.start(this->getConfig(), dataProvider_);
+    gm.start();
   }

   void finishTrain() {
...
@@ -257,7 +257,7 @@ void Tester::test() {
   CHECK(testDataProvider_) << "TestData is not specified";
   testDataProvider_->setSkipShuffle();
   testDataProvider_->reset();
-  gradientMachine_->start(*config_, testDataProvider_);
+  gradientMachine_->start();

   // For evaluation
   std::vector<std::string> modelList;
...
@@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
 }

 real Trainer::checkGradient() {
-  trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
+  trainerInternal_.getGradientMachine()->start();
   std::vector<ParameterPtr>& parameters =
       trainerInternal_.getGradientMachine()->getNonStaticParameters();
   DataBatch dataBatch;
@@ -390,7 +390,7 @@ void Trainer::startTrain() {
     dataProvider_->reset();
   }

-  trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
+  trainerInternal_.getGradientMachine()->start();
 }

 void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }
...
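Every call site in this diff pairs start() with a matching finish(); Trainer::startTrain/finishTrain above is the clearest example. As an aside, not part of this PR: the zero-argument signature makes it easy to wrap that pairing in an RAII guard, sketched here with a stand-in Machine type:

```cpp
#include <iostream>

// Stand-in for GradientMachine: start() launches workers, finish() joins them.
struct Machine {
  void start() { std::cout << "start: launch worker threads\n"; }
  void finish() { std::cout << "finish: join worker threads\n"; }
};

// Calls start() now and guarantees finish() on scope exit, even on exceptions.
class StartGuard {
 public:
  explicit StartGuard(Machine& m) : m_(m) { m_.start(); }
  ~StartGuard() { m_.finish(); }
  StartGuard(const StartGuard&) = delete;
  StartGuard& operator=(const StartGuard&) = delete;

 private:
  Machine& m_;
};

int main() {
  Machine gm;
  StartGuard guard(gm);
  // forward/backward passes would run here
}
```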
@@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
   trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
   CHECK(dataBatch.getSize()) << "No data from data provider";
   vector<Argument>& inArgs = dataBatch.getStreams();
-  trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
+  trainer.getGradientMachine()->start();
   for (int i = 0; i < 2; ++i) {
     trainer.getGradientMachine()->forwardBackward(
         inArgs, &Data.outArgs, PASS_TRAIN);
...
@@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) {
   CHECK(dataBatch.getSize()) << "No data from data provider";
   vector<Argument>& inArgs = dataBatch.getStreams();
-  trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
+  trainer.getGradientMachine()->start();
   trainer.getGradientMachine()->forwardBackward(
       inArgs, &data.outArgs, PASS_TRAIN);