diff --git a/doc/ui/cmd_argument/argument_outline.md b/doc/ui/cmd_argument/argument_outline.md index bafa5dfef2c63ad5ed87211f1a2c0ec6cbcbbe05..013edbc9047817d7f6b82c4d5188412bd2ce41d6 100644 --- a/doc/ui/cmd_argument/argument_outline.md +++ b/doc/ui/cmd_argument/argument_outline.md @@ -68,7 +68,7 @@ It looks like there are a lot of arguments. However, most of them are for develo -test_period_while_training +test_period √√ @@ -143,13 +143,8 @@ It looks like there are a lot of arguments. However, most of them are for develo -testing during trainingtest_batches_while_training -√√√< - - - -testing during trainingtest_batches_while_end -√√√< +testing during trainingtest_period +√√ diff --git a/doc/ui/cmd_argument/detail_introduction.md b/doc/ui/cmd_argument/detail_introduction.md index 1f7e406a53ed4b1cf76a3aa7460326ee634d8114..823a2266191b5f8e8b932b6ccef7af721601ed78 100644 --- a/doc/ui/cmd_argument/detail_introduction.md +++ b/doc/ui/cmd_argument/detail_introduction.md @@ -109,8 +109,8 @@ - Load parameter from this pass to test. - type: int32 (default: -1). -* `--test_period_while_training` - - Run test every test_period_while_training batches while doing training. If not 0, test test_batches_while_training batches, if 0, test nothing. +* `--test_period` + - if equal 0, do test on all test data at the end of each pass while if equal non-zero, do test on all test data once each test_period batches passed while training is going on. - type: int32 (default: 0). * `--test_wait` @@ -121,14 +121,6 @@ - File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path. - type: string (default: "", null). -* `--test_batches_while_training` - - Test test_batches_while_training batches if test_batches_while_training != 0 while doing training. If 0, test on all test data. - - type: bool (default: 1000). - -* `--test_batches_while_end` - - Test test_batches_while_end batches if test_batches_while_end != 0 at pass end. If 0, test on all test data. - - type: bool (default: 0). - * `--predict_output_dir` - Directory that saves the layer output. It is configured in Outputs() in network config. Default, this argument is null, meaning save nothing. Specify this directory if you want to save feature map of some layers in testing mode. Note that, layer outputs are values after activation function. - type: string (default: "", null). diff --git a/doc/ui/cmd_argument/use_case.md b/doc/ui/cmd_argument/use_case.md index b243560106df1281c9a0094f4ae6a0156292b36c..4d7bb33f36fe258ee24796eedc9296065923e58f 100644 --- a/doc/ui/cmd_argument/use_case.md +++ b/doc/ui/cmd_argument/use_case.md @@ -10,9 +10,7 @@ paddle train \ --config=network_config \ --save_dir=output \ --trainer_count=COUNT \ #(default:1) - --test_period_while_training=M \ #(default:0) - --test_batches_while_training=BATCHES \#(default:1000) - --test_batches_while_end=BATCHES \ #(default:0) + --test_period=M \ #(default:0) --num_passes=N \ #(defalut:100) --log_period=K \ #(default:100) --dot_period=1000 \ #(default:1) diff --git a/paddle/trainer/Tester.cpp b/paddle/trainer/Tester.cpp index f57e09d40af7f4c14d181bce2158b1d00fc41c0b..217b8c60be5855a6013d2dc39a9b2b669f43a224 100644 --- a/paddle/trainer/Tester.cpp +++ b/paddle/trainer/Tester.cpp @@ -90,20 +90,11 @@ void Tester::testOneDataBatch( testContext_.numSamples += dataBatch.getSize(); } -void Tester::testOnePeriod(bool finishPass) { +void Tester::testOnePeriod() { DataBatch dataBatch; int64_t batchSize = config_->getOptConfig().batch_size(); - bool testAllData = - (!finishPass && !intconfig_->testBatchesWhileTraining) || - (finishPass && !intconfig_->testBatchesWhileEnd); - int batches; - if (testAllData) { - batches = std::numeric_limits::max(); - } else { - batches = finishPass ? - intconfig_->testBatchesWhileEnd : intconfig_->testBatchesWhileTraining; - } + int batches = std::numeric_limits::max(); std::vector outArgs; @@ -115,12 +106,7 @@ void Tester::testOnePeriod(bool finishPass) { if (intconfig_->prevBatchState) { gradientMachine_->resetState(); } - if ((!finishPass && !intconfig_->testBatchesWhileTraining) || - (finishPass && !intconfig_->testBatchesWhileEnd)) { - break; - } else { - num = testDataProvider_->getNextBatch(batchSize, &dataBatch); - } + break; } testOneDataBatch(dataBatch, &outArgs); } diff --git a/paddle/trainer/Tester.h b/paddle/trainer/Tester.h index 21e11422aa494ee3925b6c9e204c45ce3268100b..671ffc5220ebaf2e009225191f6a77e6fea80d33 100644 --- a/paddle/trainer/Tester.h +++ b/paddle/trainer/Tester.h @@ -67,7 +67,7 @@ public: * It is convenience to test small set of data when test data set is large and * is training at same time. */ - void testOnePeriod(bool finishPass = true); + void testOnePeriod(); void startTestPeriod(); void finishTestPeriod(); void testOneDataBatch(const DataBatch& dataBatch, diff --git a/paddle/trainer/TesterConfig.h b/paddle/trainer/TesterConfig.h index b7b550dec7cc90b182ed927c246ae52958fb624f..8392bbcda512058fe830de99275e23c11bef76cc 100644 --- a/paddle/trainer/TesterConfig.h +++ b/paddle/trainer/TesterConfig.h @@ -38,17 +38,7 @@ struct TesterConfig { /** * indicate test period */ - int testPeriodWhileTraining; - - /** - * indicate how many batches are used for testing under training - */ - bool testBatchesWhileTraining; - - /** - * indicate how many batches are used for testing at pass end - */ - bool testBatchesWhileEnd; + int testPeriod; /** * indicate whether to save previous batch state diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 507a080cc4fbe58d81523509b54e8a2f7ae3be74..aca896770ae4ed9341518770374695bb2ea24a38 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -42,24 +42,13 @@ limitations under the License. */ P_DEFINE_string(config, "", "Trainer config file"); P_DEFINE_int32(test_period, 0, - "This option was deprecated, use test_period_while_training " - " instead. "); -P_DEFINE_int32(test_period_while_training, 0, - "Run test every test_period_while_training batches." - " If not 0, test test_batches_while_training batches." - " If 0, test nothing."); -P_DEFINE_int32(test_batches_while_training, 1000, - "test test_batches_while_training batches if " - "test_batches_while_training != 0." - " If 0, test on all test data"); -P_DEFINE_int32(test_batches_while_end, 0, - "test test_batches_while_end batches at pass end." - " Always run test at pass end." - " If not 0, test test_batches_while_end batches." - " If 0, test on all test data."); + "if equal 0, do test on all test data at the end of " + "each pass while if equal non-zero, do test on all test " + "data once each test_period batches passed while " + "training is going on"); P_DEFINE_bool(test_all_data_in_one_period, false, - "This option was deprecated, use test_batches_while_training " - "and test_batches_while_end instead"); + "This option was deprecated, since we will always do " + "test on all test set "); P_DEFINE_bool(local, true, "Train in local mode or not"); @@ -467,9 +456,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) { FOR_TIMING(globalStat.reset()); } - if (testDataProvider_ && FLAGS_test_period_while_training > 0 && - trainPassContext_.batchId % FLAGS_test_period_while_training == 0) { - tester_->testOnePeriod(false); + if (testDataProvider_ && FLAGS_test_period > 0 && + trainPassContext_.batchId % FLAGS_test_period == 0) { + tester_->testOnePeriod(); } if (FLAGS_saving_period_by_batches > 0 && @@ -478,7 +467,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) { 0 == FLAGS_trainer_id) { trainerInternal_.getParameterUpdater()->catchUpWith(); if (testDataProvider_) { - tester_->testOnePeriod(false); + tester_->testOnePeriod(); } paramUtil_->saveParametersOnePass( trainPassContext_.passId, trainPassContext_.passInnerId); @@ -636,17 +625,18 @@ std::unique_ptr Trainer::createTesterConfig() { TesterConfig* conf = new TesterConfig; if (FLAGS_test_period) { LOG(WARNING) - << "--test_period was deprecated, use --test_period_while_training" - << "--test_batches_while_training --test_batches_while_end instead."; + << "The meaning of --test_period is changed: " + << "if equal 0, do test on all test data at the end of " + << "each pass while if equal non-zero, do test on all test " + << "data once each test_period batches passed while " + << "training is going on"; } if (FLAGS_test_all_data_in_one_period) { LOG(WARNING) - << "--test_all_data_in_one_period was deprecated, use" - << " --test_batches_while_training and --test_batches_while_end instead"; + << "--test_all_data_in_one_period was deprecated, since " + << "we will always do test on all test set "; } - conf->testPeriodWhileTraining = FLAGS_test_period_while_training; - conf->testBatchesWhileTraining = FLAGS_test_batches_while_training; - conf->testBatchesWhileEnd = FLAGS_test_batches_while_end; + conf->testPeriod = FLAGS_test_period; conf->prevBatchState = FLAGS_prev_batch_state; conf->logPeriod = FLAGS_log_period; conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;