提交 0a0c55d2 编写于 作者: W wangyanfei01

more friendly test options

上级 c6a0298e
......@@ -90,13 +90,20 @@ void Tester::testOneDataBatch(
testContext_.numSamples += dataBatch.getSize();
}
void Tester::testOnePeriod() {
void Tester::testOnePeriod(bool finishPass) {
DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size();
bool testAllData =
intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod;
int batches =
testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;
(!finishPass && !intconfig_->testBatchesWhileTraining) ||
(finishPass && !intconfig_->testBatchesWhileEnd);
int batches;
if (testAllData) {
batches = std::numeric_limits<int>::max();
} else {
batches = finishPass ?
intconfig_->testBatchesWhileEnd : intconfig_->testBatchesWhileTraining;
}
std::vector<Argument> outArgs;
......@@ -108,7 +115,8 @@ void Tester::testOnePeriod() {
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
if (testAllData) {
if ((!finishPass && !intconfig_->testBatchesWhileTraining) ||
(finishPass && !intconfig_->testBatchesWhileEnd)) {
break;
} else {
num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
......
......@@ -67,7 +67,7 @@ public:
* It is convenience to test small set of data when test data set is large and
* is training at same time.
*/
void testOnePeriod();
void testOnePeriod(bool finishPass = true);
void startTestPeriod();
void finishTestPeriod();
void testOneDataBatch(const DataBatch& dataBatch,
......
......@@ -38,12 +38,17 @@ struct TesterConfig {
/**
* indicate test period
*/
int testPeriod;
int testPeriodWhileTraining;
/**
* indicate whether testing data in one period
* indicate how many batches are used for testing under training
*/
bool testAllDataInOnePeriod;
bool testBatchesWhileTraining;
/**
* indicate how many batches are used for testing at pass end
*/
bool testBatchesWhileEnd;
/**
* indicate whether to save previous batch state
......
......@@ -40,31 +40,28 @@ limitations under the License. */
#include "TrainerConfigHelper.h"
P_DEFINE_string(config, "", "Trainer config file");
P_DEFINE_int32(test_period, 0,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test on all test data");
P_DEFINE_int32(test_batches_while_training, 0,
P_DEFINE_int32(test_period, 0,
"This option was deprecated, use test_period_while_training "
" instead. ");
P_DEFINE_int32(test_period_while_training, 0,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test nothing.");
P_DEFINE_int32(test_batches_while_training, 1000,
"test test_batches_while_training batches if test_period != 0."
" If 0, test on all test data");
P_DEFINE_int32(test_batches_while_end, 0,
"Run test every so many train batches."
" 0 for testing after each pass."
" If not 0, test log_period batches."
" If 0, test on all test data");
"test test_batches_while_end batches at pass end."
" Always run test at pass end."
" If not 0, test test_batches_while_end batches."
" If 0, test on all test data.");
P_DEFINE_bool(test_all_data_in_one_period, false,
"This option was deprecated, use test_batches_while_training "
"and test_batches_while_end instead");
P_DEFINE_bool(local, true, "Train in local mode or not");
P_DEFINE_bool(
test_all_data_in_one_period, false,
"true will test all data in one test peroid."
"Otherwise test (batch_size * log_peroid) data in one test period.");
P_DEFINE_int32(average_test_period, 0,
"Do test on average parameter every so"
" many batches. MUST be devided by FLAGS_log_period."
......@@ -469,9 +466,9 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
FOR_TIMING(globalStat.reset());
}
if (testDataProvider_ && FLAGS_test_period > 0 &&
trainPassContext_.batchId % FLAGS_test_period == 0) {
tester_->testOnePeriod();
if (testDataProvider_ && FLAGS_test_period_while_training > 0 &&
trainPassContext_.batchId % FLAGS_test_period_while_training == 0) {
tester_->testOnePeriod(false);
}
if (FLAGS_saving_period_by_batches > 0 &&
......@@ -480,7 +477,7 @@ void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
0 == FLAGS_trainer_id) {
trainerInternal_.getParameterUpdater()->catchUpWith();
if (testDataProvider_) {
tester_->testOnePeriod();
tester_->testOnePeriod(false);
}
paramUtil_->saveParametersOnePass(
trainPassContext_.passId, trainPassContext_.passInnerId);
......@@ -636,8 +633,19 @@ void Trainer::test() {
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig;
conf->testPeriod = FLAGS_test_period;
conf->testAllDataInOnePeriod = FLAGS_test_all_data_in_one_period;
if (FLAGS_test_period) {
LOG(WARNING)
<< "--test_period was deprecated, use --test_period_while_training"
<< "--test_batches_while_training --test_batches_while_end instead.";
}
if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING)
<< "--test_all_data_in_one_period was deprecated, use"
<< " --test_batches_while_training and --test_batches_while_end instead";
}
conf->testPeriodWhileTraining = FLAGS_test_period_while_training;
conf->testBatchesWhileTraining = FLAGS_test_batches_while_training;
conf->testBatchesWhileEnd = FLAGS_test_batches_while_end;
conf->prevBatchState = FLAGS_prev_batch_state;
conf->logPeriod = FLAGS_log_period;
conf->loadsaveParametersInPserver = FLAGS_loadsave_parameters_in_pserver;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册