diff --git a/paddle/trainer/Tester.cpp b/paddle/trainer/Tester.cpp index db7916cf1413d19dd82062886d7adeeb1a3aa61f..d941a574c0fd443c1f3e01799b5d3d5af1b03c85 100644 --- a/paddle/trainer/Tester.cpp +++ b/paddle/trainer/Tester.cpp @@ -17,22 +17,22 @@ limitations under the License. */ #include #include -#include #include -#include +#include #include +#include #include +#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/PythonUtil.h" #include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" -#include "paddle/utils/GlobalConstants.h" +#include "TesterConfig.h" +#include "paddle/gserver/gradientmachines/GradientMachineMode.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/layers/ValidationLayer.h" -#include "paddle/gserver/gradientmachines/GradientMachineMode.h" -#include "TesterConfig.h" namespace paddle { @@ -66,6 +66,7 @@ Tester::Tester(const std::shared_ptr& config, } void Tester::startTestPeriod() { + testDataProvider_->reset(); testEvaluator_->start(); testContext_.cost = 0; testContext_.numSamples = 0; @@ -87,27 +88,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch, void Tester::testOnePeriod() { DataBatch dataBatch; int64_t batchSize = config_->getOptConfig().batch_size(); - - int batches = std::numeric_limits::max(); - std::vector outArgs; - startTestPeriod(); - for (int i = 0; i < batches; ++i) { - int num = testDataProvider_->getNextBatch(batchSize, &dataBatch); - if (num == 0) { - testDataProvider_->reset(); - if (intconfig_->prevBatchState) { - gradientMachine_->resetState(); - } - break; - } + while (testDataProvider_->getNextBatch(batchSize, &dataBatch) != 0) { testOneDataBatch(dataBatch, &outArgs); } finishTestPeriod(); } void Tester::finishTestPeriod() { + if (intconfig_->prevBatchState) { + gradientMachine_->resetState(); + } testEvaluator_->finish(); CHECK_GT(testContext_.numSamples, 0) << "There is no samples in your test batch. Possibly " diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 99bf45d247a9c61819ba92147616cb7ff75f89fb..9c83c207ede99bddeeab5f56d90d357ee8b56edd 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -17,36 +17,38 @@ limitations under the License. */ #include #include -#include #include -#include +#include #include +#include #include +#include "paddle/utils/Excepts.h" +#include "paddle/utils/GlobalConstants.h" #include "paddle/utils/PythonUtil.h" #include "paddle/utils/Stat.h" #include "paddle/utils/Util.h" -#include "paddle/utils/Excepts.h" -#include "paddle/utils/GlobalConstants.h" -#include "paddle/gserver/gradientmachines/NeuralNetwork.h" -#include "paddle/gserver/gradientmachines/GradientMachineMode.h" -#include "paddle/gserver/layers/ValidationLayer.h" +#include "RemoteParameterUpdater.h" #include "TesterConfig.h" #include "ThreadParameterUpdater.h" -#include "RemoteParameterUpdater.h" #include "TrainerConfigHelper.h" +#include "paddle/gserver/gradientmachines/GradientMachineMode.h" +#include "paddle/gserver/gradientmachines/NeuralNetwork.h" +#include "paddle/gserver/layers/ValidationLayer.h" P_DEFINE_string(config, "", "Trainer config file"); -P_DEFINE_int32(test_period, 0, +P_DEFINE_int32(test_period, + 0, "if equal 0, do test on all test data at the end of " "each pass. While if equal non-zero, do test on all test " "data every test_period batches"); -P_DEFINE_bool(test_all_data_in_one_period, false, - "This option was deprecated, since we will always do " - "test on all test set "); +P_DEFINE_bool(test_all_data_in_one_period, + false, + "This option was deprecated, since we will always do " + "test on all test set "); P_DEFINE_bool(local, true, "Train in local mode or not"); @@ -392,10 +394,6 @@ void Trainer::startTrain() { dataProvider_->reset(); } - if (this->testDataProvider_) { - this->testDataProvider_->reset(); - } - trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); } @@ -630,16 +628,14 @@ void Trainer::test() { tester_->test(); } std::unique_ptr Trainer::createTesterConfig() { TesterConfig* conf = new TesterConfig; if (FLAGS_test_period) { - LOG(WARNING) - << "The meaning of --test_period is changed: " - << "if equal 0, do test on all test data at the end of " - << "each pass. While if equal non-zero, do test on all test " - << "data every test_period batches "; + LOG(WARNING) << "The meaning of --test_period is changed: " + << "if equal 0, do test on all test data at the end of " + << "each pass. While if equal non-zero, do test on all test " + << "data every test_period batches "; } if (FLAGS_test_all_data_in_one_period) { - LOG(WARNING) - << "--test_all_data_in_one_period was deprecated, since " - << "we will always do test on all test set "; + LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since " + << "we will always do test on all test set "; } conf->testPeriod = FLAGS_test_period; conf->prevBatchState = FLAGS_prev_batch_state;