提交 3100406c 编写于 作者: Y Yu Yang

Simplify the testOnePeriod method.

上级 239cadef
...@@ -17,22 +17,22 @@ limitations under the License. */ ...@@ -17,22 +17,22 @@ limitations under the License. */
#include <fenv.h> #include <fenv.h>
#include <stdio.h> #include <stdio.h>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <sstream> #include <iostream>
#include <limits> #include <limits>
#include <sstream>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h" #include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h" #include "paddle/utils/Util.h"
#include "paddle/utils/GlobalConstants.h"
#include "TesterConfig.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h" #include "paddle/gserver/layers/ValidationLayer.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "TesterConfig.h"
namespace paddle { namespace paddle {
...@@ -66,6 +66,7 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config, ...@@ -66,6 +66,7 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config,
} }
void Tester::startTestPeriod() { void Tester::startTestPeriod() {
testDataProvider_->reset();
testEvaluator_->start(); testEvaluator_->start();
testContext_.cost = 0; testContext_.cost = 0;
testContext_.numSamples = 0; testContext_.numSamples = 0;
...@@ -87,27 +88,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch, ...@@ -87,27 +88,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
void Tester::testOnePeriod() { void Tester::testOnePeriod() {
DataBatch dataBatch; DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size(); int64_t batchSize = config_->getOptConfig().batch_size();
int batches = std::numeric_limits<int>::max();
std::vector<Argument> outArgs; std::vector<Argument> outArgs;
startTestPeriod(); startTestPeriod();
for (int i = 0; i < batches; ++i) { while (testDataProvider_->getNextBatch(batchSize, &dataBatch) != 0) {
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
testDataProvider_->reset();
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
break;
}
testOneDataBatch(dataBatch, &outArgs); testOneDataBatch(dataBatch, &outArgs);
} }
finishTestPeriod(); finishTestPeriod();
} }
void Tester::finishTestPeriod() { void Tester::finishTestPeriod() {
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
testEvaluator_->finish(); testEvaluator_->finish();
CHECK_GT(testContext_.numSamples, 0) CHECK_GT(testContext_.numSamples, 0)
<< "There is no samples in your test batch. Possibly " << "There is no samples in your test batch. Possibly "
......
...@@ -17,34 +17,36 @@ limitations under the License. */ ...@@ -17,34 +17,36 @@ limitations under the License. */
#include <fenv.h> #include <fenv.h>
#include <stdio.h> #include <stdio.h>
#include <iostream>
#include <iomanip> #include <iomanip>
#include <sstream> #include <iostream>
#include <limits> #include <limits>
#include <sstream>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h" #include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h" #include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h" #include "paddle/utils/Util.h"
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "RemoteParameterUpdater.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "TesterConfig.h" #include "TesterConfig.h"
#include "ThreadParameterUpdater.h" #include "ThreadParameterUpdater.h"
#include "RemoteParameterUpdater.h"
#include "TrainerConfigHelper.h" #include "TrainerConfigHelper.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
P_DEFINE_string(config, "", "Trainer config file"); P_DEFINE_string(config, "", "Trainer config file");
P_DEFINE_int32(test_period, 0, P_DEFINE_int32(test_period,
0,
"if equal 0, do test on all test data at the end of " "if equal 0, do test on all test data at the end of "
"each pass. While if equal non-zero, do test on all test " "each pass. While if equal non-zero, do test on all test "
"data every test_period batches"); "data every test_period batches");
P_DEFINE_bool(test_all_data_in_one_period, false, P_DEFINE_bool(test_all_data_in_one_period,
false,
"This option was deprecated, since we will always do " "This option was deprecated, since we will always do "
"test on all test set "); "test on all test set ");
...@@ -392,10 +394,6 @@ void Trainer::startTrain() { ...@@ -392,10 +394,6 @@ void Trainer::startTrain() {
dataProvider_->reset(); dataProvider_->reset();
} }
if (this->testDataProvider_) {
this->testDataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
} }
...@@ -630,15 +628,13 @@ void Trainer::test() { tester_->test(); } ...@@ -630,15 +628,13 @@ void Trainer::test() { tester_->test(); }
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() { std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig; TesterConfig* conf = new TesterConfig;
if (FLAGS_test_period) { if (FLAGS_test_period) {
LOG(WARNING) LOG(WARNING) << "The meaning of --test_period is changed: "
<< "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of " << "if equal 0, do test on all test data at the end of "
<< "each pass. While if equal non-zero, do test on all test " << "each pass. While if equal non-zero, do test on all test "
<< "data every test_period batches "; << "data every test_period batches ";
} }
if (FLAGS_test_all_data_in_one_period) { if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING) LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since "
<< "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set "; << "we will always do test on all test set ";
} }
conf->testPeriod = FLAGS_test_period; conf->testPeriod = FLAGS_test_period;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册