diff --git a/paddle/trainer/Tester.cpp b/paddle/trainer/Tester.cpp index d3b88019faa04b7cebf44dd63678aa9d4ffb5252..f57e09d40af7f4c14d181bce2158b1d00fc41c0b 100644 --- a/paddle/trainer/Tester.cpp +++ b/paddle/trainer/Tester.cpp @@ -90,13 +90,20 @@ void Tester::testOneDataBatch( testContext_.numSamples += dataBatch.getSize(); } -void Tester::testOnePeriod() { +void Tester::testOnePeriod(bool finishPass) { DataBatch dataBatch; int64_t batchSize = config_->getOptConfig().batch_size(); + bool testAllData = - intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod; - int batches = - testAllData ? std::numeric_limits::max() : intconfig_->testPeriod; + (!finishPass && !intconfig_->testBatchesWhileTraining) || + (finishPass && !intconfig_->testBatchesWhileEnd); + int batches; + if (testAllData) { + batches = std::numeric_limits::max(); + } else { + batches = finishPass ? + intconfig_->testBatchesWhileEnd : intconfig_->testBatchesWhileTraining; + } std::vector outArgs; @@ -108,7 +115,8 @@ void Tester::testOnePeriod() { if (intconfig_->prevBatchState) { gradientMachine_->resetState(); } - if (testAllData) { + if ((!finishPass && !intconfig_->testBatchesWhileTraining) || + (finishPass && !intconfig_->testBatchesWhileEnd)) { break; } else { num = testDataProvider_->getNextBatch(batchSize, &dataBatch); diff --git a/paddle/trainer/Tester.h b/paddle/trainer/Tester.h index 671ffc5220ebaf2e009225191f6a77e6fea80d33..21e11422aa494ee3925b6c9e204c45ce3268100b 100644 --- a/paddle/trainer/Tester.h +++ b/paddle/trainer/Tester.h @@ -67,7 +67,7 @@ public: * It is convenience to test small set of data when test data set is large and * is training at same time. */ - void testOnePeriod(); + void testOnePeriod(bool finishPass = true); void startTestPeriod(); void finishTestPeriod(); void testOneDataBatch(const DataBatch& dataBatch, diff --git a/paddle/trainer/TesterConfig.h b/paddle/trainer/TesterConfig.h index d5e644ce6124710c76a463d521c16451e22b5462..b7b550dec7cc90b182ed927c246ae52958fb624f 100644 --- a/paddle/trainer/TesterConfig.h +++ b/paddle/trainer/TesterConfig.h @@ -38,12 +38,17 @@ struct TesterConfig { /** * indicate test period */ - int testPeriod; + int testPeriodWhileTraining; /** - * indicate whether testing data in one period + * indicate how many batches are used for testing under training */ - bool testAllDataInOnePeriod; + bool testBatchesWhileTraining; + + /** + * indicate how many batches are used for testing at pass end + */ + bool testBatchesWhileEnd; /** * indicate whether to save previous batch state diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 32c4bad239ec901862a94e5996e171dc5be3cdc3..477813b4748f5f481e47aa697acd3d7ee081c2ea 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -40,31 +40,28 @@ limitations under the License. */ #include "TrainerConfigHelper.h" P_DEFINE_string(config, "", "Trainer config file"); -P_DEFINE_int32(test_period, 0, - "Run test every so many train batches." - " 0 for testing after each pass." - " If not 0, test log_period batches." - " If 0, test on all test data"); -P_DEFINE_int32(test_batches_while_training, 0, +P_DEFINE_int32(test_period, 0, + "This option was deprecated, use test_period_while_training " + " instead. "); +P_DEFINE_int32(test_period_while_training, 0, "Run test every so many train batches." - " 0 for testing after each pass." " If not 0, test log_period batches." + " If 0, test nothing."); +P_DEFINE_int32(test_batches_while_training, 1000, + "test test_batches_while_training batches if test_period != 0." " If 0, test on all test data"); - P_DEFINE_int32(test_batches_while_end, 0, - "Run test every so many train batches." - " 0 for testing after each pass." - " If not 0, test log_period batches." - " If 0, test on all test data"); + "test test_batches_while_end batches at pass end." + " Always run test at pass end." + " If not 0, test test_batches_while_end batches." + " If 0, test on all test data."); +P_DEFINE_bool(test_all_data_in_one_period, false, + "This option was deprecated, use test_batches_while_training " + "and test_batches_while_end instead"); P_DEFINE_bool(local, true, "Train in local mode or not"); -P_DEFINE_bool( - test_all_data_in_one_period, false, - "true will test all data in one test peroid." - "Otherwise test (batch_size * log_peroid) data in one test period."); - P_DEFINE_int32(average_test_period, 0, "Do test on average parameter every so" " many batches. MUST be devided by FLAGS_log_period." @@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) { FOR_TIMING(globalStat.reset()); } - if (testDataProvider_ && FLAGS_test_period > 0 && - trainPassContext_.batchId % FLAGS_test_period == 0) { - tester_->testOnePeriod(); + if (testDataProvider_ && FLAGS_test_period_while_training > 0 && + trainPassContext_.batchId % FLAGS_test_period_while_training == 0) { + tester_->testOnePeriod(false); } if (FLAGS_saving_period_by_batches > 0 && @@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) { 0 == FLAGS_trainer_id) { trainerInternal_.getParameterUpdater()->catchUpWith(); if (testDataProvider_) { - tester_->testOnePeriod(); + tester_->testOnePeriod(false); } paramUtil_->saveParametersOnePass( trainPassContext_.passId, trainPassContext_.passInnerId); @@ -636,8 +633,19 @@ void Trainer::test() { std::unique_ptr Trainer::createTesterConfig() { TesterConfig* conf = new TesterConfig; - conf->testPeriod = FLAGS_test_period; - conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period; + if (FLAGS_test_period) { + LOG(WARNING) + << "--test_period was deprecated, use --test_period_while_training" + << "--test_batches_while_training --test_batches_while_end instead."; + } + if (FLAGS_test_all_data_in_one_period) { + LOG(WARNING) + << "--test_all_data_in_one_period was deprecated, use" + << " --test_batches_while_training and --test_batches_while_end instead"; + } + conf->testPeriodWhileTraining = FLAGS_test_period_while_training; + conf->testBatchesWhileTraining = FLAGS_test_batches_while_training; + conf->testBatchesWhileEnd = FLAGS_test_batches_while_end; conf->prevBatchState = FLAGS_prev_batch_state; conf->logPeriod = FLAGS_log_period; conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;