提交 1f743d38 编写于 作者: W wangyanfei01

Redesign test_period meaning:

 * always do test on all test data
 * do test at the end of each pass if test_period=0, otherwise do test if test_period batches passed
上级 b62c80f1
......@@ -68,7 +68,7 @@ It looks like there are a lot of arguments. However, most of them are for develo
</tr>
<tr>
<td class="left">test_period_while_training</td>
<td class="left">test_period</td>
<td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td>
</tr>
......@@ -143,13 +143,8 @@ It looks like there are a lot of arguments. However, most of them are for develo
</tr>
<tr>
<td class="left" rowspan = "2">testing during training</td><td class="left">test_batches_while_training</td>
<td class="left"></td><td class="left"></td><td class="left"><</td><td class="left"></td>
</tr>
<tr>
<td class="left" rowspan = "2">testing during training</td><td class="left">test_batches_while_end</td>
<td class="left"></td><td class="left"></td><td class="left"><</td><td class="left"></td>
<td class="left" rowspan = "2">testing during training</td><td class="left">test_period</td>
<td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td>
</tr>
<tr>
......
......@@ -109,8 +109,8 @@
- Load parameter from this pass to test.
- type: int32 (default: -1).
* `--test_period_while_training`
- Run test every test_period_while_training batches while doing training. If not 0, test test_batches_while_training batches, if 0, test nothing.
* `--test_period`
- if equal 0, do test on all test data at the end of each pass while if equal non-zero, do test on all test data once each test_period batches passed while training is going on.
- type: int32 (default: 0).
* `--test_wait`
......@@ -121,14 +121,6 @@
- File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path.
- type: string (default: "", null).
* `--test_batches_while_training`
- Test test_batches_while_training batches if test_batches_while_training != 0 while doing training. If 0, test on all test data.
- type: bool (default: 1000).
* `--test_batches_while_end`
- Test test_batches_while_end batches if test_batches_while_end != 0 at pass end. If 0, test on all test data.
- type: bool (default: 0).
* `--predict_output_dir`
- Directory that saves the layer output. It is configured in Outputs() in network config. Default, this argument is null, meaning save nothing. Specify this directory if you want to save feature map of some layers in testing mode. Note that, layer outputs are values after activation function.
- type: string (default: "", null).
......
......@@ -10,9 +10,7 @@ paddle train \
--config=network_config \
--save_dir=output \
--trainer_count=COUNT \ #(default:1)
--test_period_while_training=M \ #(default:0)
--test_batches_while_training=BATCHES \#(default:1000)
--test_batches_while_end=BATCHES \ #(default:0)
--test_period=M \ #(default:0)
--num_passes=N \ #(defalut:100)
--log_period=K \ #(default:100)
--dot_period=1000 \ #(default:1)
......
......@@ -90,20 +90,11 @@ void Tester::testOneDataBatch(
testContext_.numSamples += dataBatch.getSize();
}
void Tester::testOnePeriod(bool finishPass) {
void Tester::testOnePeriod() {
DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size();
bool testAllData =
(!finishPass && !intconfig_->testBatchesWhileTraining) ||
(finishPass && !intconfig_->testBatchesWhileEnd);
int batches;
if (testAllData) {
batches = std::numeric_limits<int>::max();
} else {
batches = finishPass ?
intconfig_->testBatchesWhileEnd : intconfig_->testBatchesWhileTraining;
}
int batches = std::numeric_limits<int>::max();
std::vector<Argument> outArgs;
......@@ -115,12 +106,7 @@ void Tester::testOnePeriod(bool finishPass) {
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
if ((!finishPass && !intconfig_->testBatchesWhileTraining) ||
(finishPass && !intconfig_->testBatchesWhileEnd)) {
break;
} else {
num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
}
break;
}
testOneDataBatch(dataBatch, &outArgs);
}
......
......@@ -67,7 +67,7 @@ public:
* It is convenience to test small set of data when test data set is large and
* is training at same time.
*/
void testOnePeriod(bool finishPass = true);
void testOnePeriod();
void startTestPeriod();
void finishTestPeriod();
void testOneDataBatch(const DataBatch& dataBatch,
......
......@@ -38,17 +38,7 @@ struct TesterConfig {
/**
* indicate test period
*/
int testPeriodWhileTraining;
/**
* indicate how many batches are used for testing under training
*/
bool testBatchesWhileTraining;
/**
* indicate how many batches are used for testing at pass end
*/
bool testBatchesWhileEnd;
int testPeriod;
/**
* indicate whether to save previous batch state
......
......@@ -42,24 +42,13 @@ limitations under the License. */
P_DEFINE_string(config, "", "Trainer config file");
P_DEFINE_int32(test_period, 0,
"This option was deprecated, use test_period_while_training "
" instead. ");
P_DEFINE_int32(test_period_while_training, 0,
"Run test every test_period_while_training batches."
" If not 0, test test_batches_while_training batches."
" If 0, test nothing.");
P_DEFINE_int32(test_batches_while_training, 1000,
"test test_batches_while_training batches if "
"test_batches_while_training != 0."
" If 0, test on all test data");
P_DEFINE_int32(test_batches_while_end, 0,
"test test_batches_while_end batches at pass end."
" Always run test at pass end."
" If not 0, test test_batches_while_end batches."
" If 0, test on all test data.");
"if equal 0, do test on all test data at the end of "
"each pass while if equal non-zero, do test on all test "
"data once each test_period batches passed while "
"training is going on");
P_DEFINE_bool(test_all_data_in_one_period, false,
"This option was deprecated, use test_batches_while_training "
"and test_batches_while_end instead");
"This option was deprecated, since we will always do "
"test on all test set ");
P_DEFINE_bool(local, true, "Train in local mode or not");
......@@ -467,9 +456,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
FOR_TIMING(globalStat.reset());
}
if (testDataProvider_ && FLAGS_test_period_while_training > 0 &&
trainPassContext_.batchId % FLAGS_test_period_while_training == 0) {
tester_->testOnePeriod(false);
if (testDataProvider_ && FLAGS_test_period > 0 &&
trainPassContext_.batchId % FLAGS_test_period == 0) {
tester_->testOnePeriod();
}
if (FLAGS_saving_period_by_batches > 0 &&
......@@ -478,7 +467,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
0 == FLAGS_trainer_id) {
trainerInternal_.getParameterUpdater()->catchUpWith();
if (testDataProvider_) {
tester_->testOnePeriod(false);
tester_->testOnePeriod();
}
paramUtil_->saveParametersOnePass(
trainPassContext_.passId, trainPassContext_.passInnerId);
......@@ -636,17 +625,18 @@ std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig;
if (FLAGS_test_period) {
LOG(WARNING)
<< "--test_period was deprecated, use --test_period_while_training"
<< "--test_batches_while_training --test_batches_while_end instead.";
<< "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of "
<< "each pass while if equal non-zero, do test on all test "
<< "data once each test_period batches passed while "
<< "training is going on";
}
if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING)
<< "--test_all_data_in_one_period was deprecated, use"
<< " --test_batches_while_training and --test_batches_while_end instead";
<< "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set ";
}
conf->testPeriodWhileTraining = FLAGS_test_period_while_training;
conf->testBatchesWhileTraining = FLAGS_test_batches_while_training;
conf->testBatchesWhileEnd = FLAGS_test_batches_while_end;
conf->testPeriod = FLAGS_test_period;
conf->prevBatchState = FLAGS_prev_batch_state;
conf->logPeriod = FLAGS_log_period;
conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册