提交 239cadef 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #411 from backyes/bugfix_test_period

Re-design command options for testing for better understanding
...@@ -143,7 +143,7 @@ It looks like there are a lot of arguments. However, most of them are for develo ...@@ -143,7 +143,7 @@ It looks like there are a lot of arguments. However, most of them are for develo
</tr> </tr>
<tr> <tr>
<td class="left" rowspan = "2">testing during training</td><td class="left">test_all_data_in_one_period</td> <td class="left" rowspan = "2">testing during training</td><td class="left">test_period</td>
<td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td> <td class="left"></td><td class="left"></td><td class="left"></td><td class="left"></td>
</tr> </tr>
......
...@@ -31,7 +31,7 @@ ...@@ -31,7 +31,7 @@
- type: string (default: null). - type: string (default: null).
* `--version` * `--version`
- Whether to print version infomatrion. - Whether to print version information.
- type: bool (default: 0). - type: bool (default: 0).
* `--show_layer_stat` * `--show_layer_stat`
...@@ -110,8 +110,8 @@ ...@@ -110,8 +110,8 @@
- type: int32 (default: -1). - type: int32 (default: -1).
* `--test_period` * `--test_period`
- Run testing every test_period train batches. If not set, run testing each pass. - if equal 0, do test on all test data at the end of each pass. While if equal non-zero, do test on all test data every test_period batches.
- type: int32 (default: 1000). - type: int32 (default: 0).
* `--test_wait` * `--test_wait`
- Whether to wait for parameter per pass if not exist. If set test_data_path in submitting environment of cluster, it will launch one process to perfom testing, so we need to set test_wait=1. Note that in the cluster submitting environment, this argument has been set True by default. - Whether to wait for parameter per pass if not exist. If set test_data_path in submitting environment of cluster, it will launch one process to perfom testing, so we need to set test_wait=1. Note that in the cluster submitting environment, this argument has been set True by default.
...@@ -121,10 +121,6 @@ ...@@ -121,10 +121,6 @@
- File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path. - File that saves the model list when testing. It was set automatically when using cluster submitting environment after setting model_path.
- type: string (default: "", null). - type: string (default: "", null).
* `--test_all_data_in_one_period`
- This argument is usually used in testing period during traning. If true, all data will be tested in one test period. Otherwise (batch_size * log_peroid) data will be tested.
- type: bool (default: 0).
* `--predict_output_dir` * `--predict_output_dir`
- Directory that saves the layer output. It is configured in Outputs() in network config. Default, this argument is null, meaning save nothing. Specify this directory if you want to save feature map of some layers in testing mode. Note that, layer outputs are values after activation function. - Directory that saves the layer output. It is configured in Outputs() in network config. Default, this argument is null, meaning save nothing. Specify this directory if you want to save feature map of some layers in testing mode. Note that, layer outputs are values after activation function.
- type: string (default: "", null). - type: string (default: "", null).
......
...@@ -10,9 +10,8 @@ paddle train \ ...@@ -10,9 +10,8 @@ paddle train \
--config=network_config \ --config=network_config \
--save_dir=output \ --save_dir=output \
--trainer_count=COUNT \ #(default:1) --trainer_count=COUNT \ #(default:1)
--test_period=M \ #(default:1000) --test_period=M \ #(default:0)
--test_all_data_in_one_period=true \ #(default:false) --num_passes=N \ #(defalut:100)
--num_passes=N \ #(defalut:100)
--log_period=K \ #(default:100) --log_period=K \ #(default:100)
--dot_period=1000 \ #(default:1) --dot_period=1000 \ #(default:1)
#[--show_parameter_stats_period=100] \ #(default:0) #[--show_parameter_stats_period=100] \ #(default:0)
......
...@@ -87,10 +87,8 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch, ...@@ -87,10 +87,8 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
void Tester::testOnePeriod() { void Tester::testOnePeriod() {
DataBatch dataBatch; DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size(); int64_t batchSize = config_->getOptConfig().batch_size();
bool testAllData =
intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod; int batches = std::numeric_limits<int>::max();
int batches =
testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;
std::vector<Argument> outArgs; std::vector<Argument> outArgs;
...@@ -102,11 +100,7 @@ void Tester::testOnePeriod() { ...@@ -102,11 +100,7 @@ void Tester::testOnePeriod() {
if (intconfig_->prevBatchState) { if (intconfig_->prevBatchState) {
gradientMachine_->resetState(); gradientMachine_->resetState();
} }
if (testAllData) { break;
break;
} else {
num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
}
} }
testOneDataBatch(dataBatch, &outArgs); testOneDataBatch(dataBatch, &outArgs);
} }
......
...@@ -39,11 +39,6 @@ struct TesterConfig { ...@@ -39,11 +39,6 @@ struct TesterConfig {
*/ */
int testPeriod; int testPeriod;
/**
* indicate whether testing data in one period
*/
bool testAllDataInOnePeriod;
/** /**
* indicate whether to save previous batch state * indicate whether to save previous batch state
*/ */
......
...@@ -39,20 +39,16 @@ limitations under the License. */ ...@@ -39,20 +39,16 @@ limitations under the License. */
#include "TrainerConfigHelper.h" #include "TrainerConfigHelper.h"
P_DEFINE_string(config, "", "Trainer config file"); P_DEFINE_string(config, "", "Trainer config file");
P_DEFINE_int32(test_period,
0,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test on all test data");
P_DEFINE_bool(local, true, "Train in local mode or not"); P_DEFINE_int32(test_period, 0,
"if equal 0, do test on all test data at the end of "
"each pass. While if equal non-zero, do test on all test "
"data every test_period batches");
P_DEFINE_bool(test_all_data_in_one_period, false,
"This option was deprecated, since we will always do "
"test on all test set ");
P_DEFINE_bool( P_DEFINE_bool(local, true, "Train in local mode or not");
test_all_data_in_one_period,
false,
"true will test all data in one test peroid."
"Otherwise test (batch_size * log_peroid) data in one test period.");
P_DEFINE_int32(average_test_period, P_DEFINE_int32(average_test_period,
0, 0,
...@@ -633,8 +629,19 @@ void Trainer::test() { tester_->test(); } ...@@ -633,8 +629,19 @@ void Trainer::test() { tester_->test(); }
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() { std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig; TesterConfig* conf = new TesterConfig;
if (FLAGS_test_period) {
LOG(WARNING)
<< "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of "
<< "each pass. While if equal non-zero, do test on all test "
<< "data every test_period batches ";
}
if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING)
<< "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set ";
}
conf->testPeriod = FLAGS_test_period; conf->testPeriod = FLAGS_test_period;
conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period;
conf->prevBatchState = FLAGS_prev_batch_state; conf->prevBatchState = FLAGS_prev_batch_state;
conf->logPeriod = FLAGS_log_period; conf->logPeriod = FLAGS_log_period;
conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver; conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册