From 1f743d381cfb5dc80c25d4787ea1a703eb3ba07d Mon Sep 17 00:00:00 2001 From: wangyanfei01 Date: Mon, 28 Nov 2016 19:55:34 +0800 Subject: [PATCH] Redesign test_period meaning: * always do test on all test data * do test at the end of each pass if test_period=0, otherwise do test if test_period batches passed --- doc/ui/cmd_argument/argument_outline.md | 11 ++---- doc/ui/cmd_argument/detail_introduction.md | 12 +----- doc/ui/cmd_argument/use_case.md | 4 +- paddle/trainer/Tester.cpp | 20 ++-------- paddle/trainer/Tester.h | 2 +- paddle/trainer/TesterConfig.h | 12 +----- paddle/trainer/Trainer.cpp | 46 +++++++++------------- 7 files changed, 29 insertions(+), 78 deletions(-) diff --git a/doc/ui/cmd_argument/argument_outline.md b/doc/ui/cmd_argument/argument_outline.md index bafa5dfef2c..013edbc9047 100644 --- a/doc/ui/cmd_argument/argument_outline.md +++ b/doc/ui/cmd_argument/argument_outline.md @@ -68,7 +68,7 @@ It looks like there are a lot of arguments. However, most of them are for develo -test_period_while_training +test_period √√ @@ -143,13 +143,8 @@ It looks like there are a lot of arguments. However, most of them are for develo -testing during trainingtest_batches_while_training -√√√< - - - -testing during trainingtest_batches_while_end -√√√< +testing during trainingtest_period +√√ diff --git a/doc/ui/cmd_argument/detail_introduction.md b/doc/ui/cmd_argument/detail_introduction.md index 1f7e406a53e..823a2266191 100644 --- a/doc/ui/cmd_argument/detail_introduction.md +++ b/doc/ui/cmd_argument/detail_introduction.md @@ -109,8 +109,8 @@ - Load parameter from this pass to test. - type: int32 (default: -1). -* `--test_period_while_training` - - Run test every test_period_while_training batches while doing training. If not 0, test test_batches_while_training batches, if 0, test nothing. +* `--test_period` + - if equal 0, do test on all test data at the end of each pass while if equal non-zero, do test on all test data once each test_period batches passed while training is going on. - type: int32 (default: 0). * `--test_wait` @@ -121,14 +121,6 @@ - File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path. - type: string (default: "", null). -* `--test_batches_while_training` - - Test test_batches_while_training batches if test_batches_while_training != 0 while doing training. If 0, test on all test data. - - type: bool (default: 1000). - -* `--test_batches_while_end` - - Test test_batches_while_end batches if test_batches_while_end != 0 at pass end. If 0, test on all test data. - - type: bool (default: 0). - * `--predict_output_dir` - Directory that saves the layer output. It is configured in Outputs() in network config. Default, this argument is null, meaning save nothing. Specify this directory if you want to save feature map of some layers in testing mode. Note that, layer outputs are values after activation function. - type: string (default: "", null). diff --git a/doc/ui/cmd_argument/use_case.md b/doc/ui/cmd_argument/use_case.md index b243560106d..4d7bb33f36f 100644 --- a/doc/ui/cmd_argument/use_case.md +++ b/doc/ui/cmd_argument/use_case.md @@ -10,9 +10,7 @@ paddle train \ --config=network_config \ --save_dir=output \ --trainer_count=COUNT \ #(default:1) - --test_period_while_training=M \ #(default:0) - --test_batches_while_training=BATCHES \#(default:1000) - --test_batches_while_end=BATCHES \ #(default:0) + --test_period=M \ #(default:0) --num_passes=N \ #(defalut:100) --log_period=K \ #(default:100) --dot_period=1000 \ #(default:1) diff --git a/paddle/trainer/Tester.cpp b/paddle/trainer/Tester.cpp index f57e09d40af..217b8c60be5 100644 --- a/paddle/trainer/Tester.cpp +++ b/paddle/trainer/Tester.cpp @@ -90,20 +90,11 @@ void Tester::testOneDataBatch( testContext_.numSamples += dataBatch.getSize(); } -void Tester::testOnePeriod(bool finishPass) { +void Tester::testOnePeriod() { DataBatch dataBatch; int64_t batchSize = config_->getOptConfig().batch_size(); - bool testAllData = - (!finishPass && !intconfig_->testBatchesWhileTraining) || - (finishPass && !intconfig_->testBatchesWhileEnd); - int batches; - if (testAllData) { - batches = std::numeric_limits::max(); - } else { - batches = finishPass ? - intconfig_->testBatchesWhileEnd : intconfig_->testBatchesWhileTraining; - } + int batches = std::numeric_limits::max(); std::vector outArgs; @@ -115,12 +106,7 @@ void Tester::testOnePeriod(bool finishPass) { if (intconfig_->prevBatchState) { gradientMachine_->resetState(); } - if ((!finishPass && !intconfig_->testBatchesWhileTraining) || - (finishPass && !intconfig_->testBatchesWhileEnd)) { - break; - } else { - num = testDataProvider_->getNextBatch(batchSize, &dataBatch); - } + break; } testOneDataBatch(dataBatch, &outArgs); } diff --git a/paddle/trainer/Tester.h b/paddle/trainer/Tester.h index 21e11422aa4..671ffc5220e 100644 --- a/paddle/trainer/Tester.h +++ b/paddle/trainer/Tester.h @@ -67,7 +67,7 @@ public: * It is convenience to test small set of data when test data set is large and * is training at same time. */ - void testOnePeriod(bool finishPass = true); + void testOnePeriod(); void startTestPeriod(); void finishTestPeriod(); void testOneDataBatch(const DataBatch& dataBatch, diff --git a/paddle/trainer/TesterConfig.h b/paddle/trainer/TesterConfig.h index b7b550dec7c..8392bbcda51 100644 --- a/paddle/trainer/TesterConfig.h +++ b/paddle/trainer/TesterConfig.h @@ -38,17 +38,7 @@ struct TesterConfig { /** * indicate test period */ - int testPeriodWhileTraining; - - /** - * indicate how many batches are used for testing under training - */ - bool testBatchesWhileTraining; - - /** - * indicate how many batches are used for testing at pass end - */ - bool testBatchesWhileEnd; + int testPeriod; /** * indicate whether to save previous batch state diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 507a080cc4f..aca896770ae 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -42,24 +42,13 @@ limitations under the License. */ P_DEFINE_string(config, "", "Trainer config file"); P_DEFINE_int32(test_period, 0, - "This option was deprecated, use test_period_while_training " - " instead. "); -P_DEFINE_int32(test_period_while_training, 0, - "Run test every test_period_while_training batches." - " If not 0, test test_batches_while_training batches." - " If 0, test nothing."); -P_DEFINE_int32(test_batches_while_training, 1000, - "test test_batches_while_training batches if " - "test_batches_while_training != 0." - " If 0, test on all test data"); -P_DEFINE_int32(test_batches_while_end, 0, - "test test_batches_while_end batches at pass end." - " Always run test at pass end." - " If not 0, test test_batches_while_end batches." - " If 0, test on all test data."); + "if equal 0, do test on all test data at the end of " + "each pass while if equal non-zero, do test on all test " + "data once each test_period batches passed while " + "training is going on"); P_DEFINE_bool(test_all_data_in_one_period, false, - "This option was deprecated, use test_batches_while_training " - "and test_batches_while_end instead"); + "This option was deprecated, since we will always do " + "test on all test set "); P_DEFINE_bool(local, true, "Train in local mode or not"); @@ -467,9 +456,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) { FOR_TIMING(globalStat.reset()); } - if (testDataProvider_ && FLAGS_test_period_while_training > 0 && - trainPassContext_.batchId % FLAGS_test_period_while_training == 0) { - tester_->testOnePeriod(false); + if (testDataProvider_ && FLAGS_test_period > 0 && + trainPassContext_.batchId % FLAGS_test_period == 0) { + tester_->testOnePeriod(); } if (FLAGS_saving_period_by_batches > 0 && @@ -478,7 +467,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) { 0 == FLAGS_trainer_id) { trainerInternal_.getParameterUpdater()->catchUpWith(); if (testDataProvider_) { - tester_->testOnePeriod(false); + tester_->testOnePeriod(); } paramUtil_->saveParametersOnePass( trainPassContext_.passId, trainPassContext_.passInnerId); @@ -636,17 +625,18 @@ std::unique_ptr Trainer::createTesterConfig() { TesterConfig* conf = new TesterConfig; if (FLAGS_test_period) { LOG(WARNING) - << "--test_period was deprecated, use --test_period_while_training" - << "--test_batches_while_training --test_batches_while_end instead."; + << "The meaning of --test_period is changed: " + << "if equal 0, do test on all test data at the end of " + << "each pass while if equal non-zero, do test on all test " + << "data once each test_period batches passed while " + << "training is going on"; } if (FLAGS_test_all_data_in_one_period) { LOG(WARNING) - << "--test_all_data_in_one_period was deprecated, use" - << " --test_batches_while_training and --test_batches_while_end instead"; + << "--test_all_data_in_one_period was deprecated, since " + << "we will always do test on all test set "; } - conf->testPeriodWhileTraining = FLAGS_test_period_while_training; - conf->testBatchesWhileTraining = FLAGS_test_batches_while_training; - conf->testBatchesWhileEnd = FLAGS_test_batches_while_end; + conf->testPeriod = FLAGS_test_period; conf->prevBatchState = FLAGS_prev_batch_state; conf->logPeriod = FLAGS_log_period; conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver; -- GitLab