提交 adc58397 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #969 from reyoung/feature/clean_gradient_machine_start

Remove not used params in GradientMachine::start
......@@ -212,11 +212,7 @@ public:
* @note This function will only been implemented and used in a
* multithreaded environment.
*/
virtual void start(const TrainerConfig& config,
DataProviderPtr dataProvider) {
(void)config;
(void)dataProvider;
}
virtual void start() {}
/**
* @brief check each work-thread whether is failed/error/finish,
......
......@@ -441,7 +441,7 @@ TrainerThread::TrainerThread(const ModelConfig& config,
TrainerThread::~TrainerThread() { stop(); }
void TrainerThread::start() {
gradientMachine_->start(*(TrainerConfig*)nullptr, (DataProviderPtr) nullptr);
gradientMachine_->start();
computeThread_.reset(new std::thread([this]() { computeThread(); }));
......
......@@ -109,10 +109,9 @@ void MultiNetwork::onPassEnd() {
}
}
void MultiNetwork::start(const TrainerConfig& config,
DataProviderPtr dataProvider) {
void MultiNetwork::start() {
for (auto& subNetwork : subNetworks_) {
subNetwork->start(config, dataProvider);
subNetwork->start();
}
}
......
......@@ -54,7 +54,7 @@ public:
return subNetworks_;
}
virtual void start(const TrainerConfig& config, DataProviderPtr dataProvider);
virtual void start();
virtual void finish();
......
......@@ -131,11 +131,7 @@ void ParallelNeuralNetwork::forwardBackward(const std::vector<Argument>& inArgs,
backward(callback);
}
void ParallelNeuralNetwork::start(const TrainerConfig& config,
DataProviderPtr dataProvider) {
(void)config;
(void)dataProvider;
void ParallelNeuralNetwork::start() {
for (auto& thread : threads_) {
thread->start();
}
......
......@@ -56,7 +56,7 @@ public:
PassType passType,
const UpdateCallback &callback = NULL);
virtual void start(const TrainerConfig &config, DataProviderPtr dataProvider);
virtual void start();
void addComputeThread(int deviceId);
......
......@@ -114,7 +114,7 @@ void calcGradient(DataIn& in, DataOut& out, const std::string& configPath) {
parameters[i]->getBuf(PARAMETER_VALUE)->copyFrom(*in.paraValues[i]);
}
}
gradientMachine->start(trainer.getConfig(), nullptr);
gradientMachine->start();
gradientMachine->forward(in.inArgs, &outArgs, PASS_TRAIN);
for (size_t i = 0; i < in.outGrads.size(); i++) {
// If the all the layers in the config have no parameters, also
......
......@@ -28,7 +28,7 @@ class TrainerForTest : public paddle::Trainer {
public:
void startTrain() {
GradientMachine& gm = *this->trainerInternal_.getGradientMachine();
gm.start(this->getConfig(), dataProvider_);
gm.start();
}
void finishTrain() {
......
......@@ -257,7 +257,7 @@ void Tester::test() {
CHECK(testDataProvider_) << "TestData is not specified";
testDataProvider_->setSkipShuffle();
testDataProvider_->reset();
gradientMachine_->start(*config_, testDataProvider_);
gradientMachine_->start();
// For evaluation
std::vector<std::string> modelList;
......
......@@ -308,7 +308,7 @@ static double genPerturbation(real* d, real* grad, size_t dim) {
}
real Trainer::checkGradient() {
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
trainerInternal_.getGradientMachine()->start();
std::vector<ParameterPtr>& parameters =
trainerInternal_.getGradientMachine()->getNonStaticParameters();
DataBatch dataBatch;
......@@ -390,7 +390,7 @@ void Trainer::startTrain() {
dataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
trainerInternal_.getGradientMachine()->start();
}
void Trainer::finishTrain() { trainerInternal_.getGradientMachine()->finish(); }
......
......@@ -50,7 +50,7 @@ void calcGradient(bool useGpu, comData& Data) {
trainer.getDataProvider()->getNextBatch(batchSize, &dataBatch);
CHECK(dataBatch.getSize()) << "No data from data provider";
vector<Argument>& inArgs = dataBatch.getStreams();
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
trainer.getGradientMachine()->start();
for (int i = 0; i < 2; ++i) {
trainer.getGradientMachine()->forwardBackward(
inArgs, &Data.outArgs, PASS_TRAIN);
......
......@@ -72,7 +72,7 @@ void calcGradient(ComData& data, const string configFile) {
CHECK(dataBatch.getSize()) << "No data from data provider";
vector<Argument>& inArgs = dataBatch.getStreams();
trainer.getGradientMachine()->start(trainer.getConfig(), nullptr);
trainer.getGradientMachine()->start();
trainer.getGradientMachine()->forwardBackward(
inArgs, &data.outArgs, PASS_TRAIN);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册