Commit 0a0c55d2 authored by wangyanfei01

more friendly test options

Parent c6a0298e
@@ -90,13 +90,20 @@ void Tester::testOneDataBatch(
   testContext_.numSamples += dataBatch.getSize();
 }
-void Tester::testOnePeriod() {
+void Tester::testOnePeriod(bool finishPass) {
   DataBatch dataBatch;
   int64_t batchSize = config_->getOptConfig().batch_size();
   bool testAllData =
-      intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod;
-  int batches =
-      testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;
+      (!finishPass && !intconfig_->testBatchesWhileTraining) ||
+      (finishPass && !intconfig_->testBatchesWhileEnd);
+  int batches;
+  if (testAllData) {
+    batches = std::numeric_limits<int>::max();
+  } else {
+    batches = finishPass ?
+        intconfig_->testBatchesWhileEnd : intconfig_->testBatchesWhileTraining;
+  }
   std::vector<Argument> outArgs;
@@ -108,7 +115,8 @@ void Tester::testOnePeriod() {
     if (intconfig_->prevBatchState) {
      gradientMachine_->resetState();
    }
-    if (testAllData) {
+    if ((!finishPass && !intconfig_->testBatchesWhileTraining) ||
+        (finishPass && !intconfig_->testBatchesWhileEnd)) {
      break;
    } else {
      num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
......
@@ -67,7 +67,7 @@ public:
    * It is convenience to test small set of data when test data set is large and
    * is training at same time.
    */
-  void testOnePeriod();
+  void testOnePeriod(bool finishPass = true);
   void startTestPeriod();
   void finishTestPeriod();
   void testOneDataBatch(const DataBatch& dataBatch,
......
@@ -38,12 +38,17 @@ struct TesterConfig {
   /**
    * indicate test period
    */
-  int testPeriod;
+  int testPeriodWhileTraining;
   /**
-   * indicate whether testing data in one period
+   * indicate how many batches are used for testing under training
    */
-  bool testAllDataInOnePeriod;
+  int testBatchesWhileTraining;
+  /**
+   * indicate how many batches are used for testing at pass end
+   */
+  int testBatchesWhileEnd;
   /**
    * indicate whether to save previous batch state
......
@@ -40,31 +40,28 @@ limitations under the License. */
 #include "TrainerConfigHelper.h"
 P_DEFINE_string(config, "", "Trainer config file");
-P_DEFINE_int32(test_period, 0,
-               "Run test every so many train batches."
-               " 0 for testing after each pass."
-               " If not 0, test log_period batches."
-               " If 0, test on all test data");
-P_DEFINE_int32(test_batches_while_training, 0,
+P_DEFINE_int32(test_period, 0,
+               "This option was deprecated, use test_period_while_training "
+               " instead. ");
+P_DEFINE_int32(test_period_while_training, 0,
                "Run test every so many train batches."
-               " 0 for testing after each pass."
                " If not 0, test log_period batches."
+               " If 0, test nothing.");
+P_DEFINE_int32(test_batches_while_training, 1000,
+               "test test_batches_while_training batches if test_period != 0."
                " If 0, test on all test data");
 P_DEFINE_int32(test_batches_while_end, 0,
-               "Run test every so many train batches."
-               " 0 for testing after each pass."
-               " If not 0, test log_period batches."
-               " If 0, test on all test data");
+               "test test_batches_while_end batches at pass end."
+               " Always run test at pass end."
+               " If not 0, test test_batches_while_end batches."
+               " If 0, test on all test data.");
+P_DEFINE_bool(test_all_data_in_one_period, false,
+              "This option was deprecated, use test_batches_while_training "
+              "and test_batches_while_end instead");
 P_DEFINE_bool(local, true, "Train in local mode or not");
-P_DEFINE_bool(
-    test_all_data_in_one_period, false,
-    "true will test all data in one test peroid."
-    "Otherwise test (batch_size * log_peroid) data in one test period.");
 P_DEFINE_int32(average_test_period, 0,
                "Do test on average parameter every so"
                " many batches. MUST be devided by FLAGS_log_period."
@@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
     FOR_TIMING(globalStat.reset());
   }
-  if (testDataProvider_ && FLAGS_test_period > 0 &&
-      trainPassContext_.batchId % FLAGS_test_period == 0) {
-    tester_->testOnePeriod();
+  if (testDataProvider_ && FLAGS_test_period_while_training > 0 &&
+      trainPassContext_.batchId % FLAGS_test_period_while_training == 0) {
+    tester_->testOnePeriod(false);
   }
   if (FLAGS_saving_period_by_batches > 0 &&
@@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
       0 == FLAGS_trainer_id) {
     trainerInternal_.getParameterUpdater()->catchUpWith();
     if (testDataProvider_) {
-      tester_->testOnePeriod();
+      tester_->testOnePeriod(false);
     }
     paramUtil_->saveParametersOnePass(
         trainPassContext_.passId, trainPassContext_.passInnerId);
@@ -636,8 +633,19 @@ void Trainer::test() {
 std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
   TesterConfig* conf = new TesterConfig;
-  conf->testPeriod = FLAGS_test_period;
-  conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period;
+  if (FLAGS_test_period) {
+    LOG(WARNING)
+        << "--test_period was deprecated, use --test_period_while_training"
+        << " --test_batches_while_training --test_batches_while_end instead.";
+  }
+  if (FLAGS_test_all_data_in_one_period) {
+    LOG(WARNING)
+        << "--test_all_data_in_one_period was deprecated, use"
+        << " --test_batches_while_training and --test_batches_while_end instead";
+  }
+  conf->testPeriodWhileTraining = FLAGS_test_period_while_training;
+  conf->testBatchesWhileTraining = FLAGS_test_batches_while_training;
+  conf->testBatchesWhileEnd = FLAGS_test_batches_while_end;
   conf->prevBatchState = FLAGS_prev_batch_state;
   conf->logPeriod = FLAGS_log_period;
   conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;
......
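
Taken together, the new options work like this: during training, a test period covers test_batches_while_training batches (all test data if that flag is 0), while the end-of-pass test covers test_batches_while_end batches (again, all test data if 0). Below is a minimal standalone sketch of that selection logic; the helper name pickTestBatches is hypothetical and not part of the commit, which keeps this logic inline in Tester::testOnePeriod.

#include <limits>

// Sketch of how the number of test batches is chosen under the new options.
// A requested count of 0 means "test on all data", expressed here as an
// effectively unbounded limit; the real loop stops once the test data
// provider runs out of batches.
int pickTestBatches(bool finishPass,
                    int testBatchesWhileTraining,
                    int testBatchesWhileEnd) {
  int requested = finishPass ? testBatchesWhileEnd : testBatchesWhileTraining;
  return requested == 0 ? std::numeric_limits<int>::max() : requested;
}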