提交 3100406c 编写于 作者: Y Yu Yang

Simplify the testOnePeriod method.

上级 239cadef
......@@ -17,22 +17,22 @@ limitations under the License. */
#include <fenv.h>
#include <stdio.h>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <google/protobuf/text_format.h>
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/GlobalConstants.h"
#include "TesterConfig.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "TesterConfig.h"
namespace paddle {
......@@ -66,6 +66,7 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper>& config,
}
void Tester::startTestPeriod() {
testDataProvider_->reset();
testEvaluator_->start();
testContext_.cost = 0;
testContext_.numSamples = 0;
......@@ -87,27 +88,18 @@ void Tester::testOneDataBatch(const DataBatch& dataBatch,
void Tester::testOnePeriod() {
DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size();
int batches = std::numeric_limits<int>::max();
std::vector<Argument> outArgs;
startTestPeriod();
for (int i = 0; i < batches; ++i) {
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
testDataProvider_->reset();
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
break;
}
while (testDataProvider_->getNextBatch(batchSize, &dataBatch) != 0) {
testOneDataBatch(dataBatch, &outArgs);
}
finishTestPeriod();
}
void Tester::finishTestPeriod() {
if (intconfig_->prevBatchState) {
gradientMachine_->resetState();
}
testEvaluator_->finish();
CHECK_GT(testContext_.numSamples, 0)
<< "There is no samples in your test batch. Possibly "
......
......@@ -17,34 +17,36 @@ limitations under the License. */
#include <fenv.h>
#include <stdio.h>
#include <iostream>
#include <iomanip>
#include <sstream>
#include <iostream>
#include <limits>
#include <sstream>
#include <google/protobuf/text_format.h>
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/PythonUtil.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/Excepts.h"
#include "paddle/utils/GlobalConstants.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/layers/ValidationLayer.h"
#include "RemoteParameterUpdater.h"
#include "TesterConfig.h"
#include "ThreadParameterUpdater.h"
#include "RemoteParameterUpdater.h"
#include "TrainerConfigHelper.h"
#include "paddle/gserver/gradientmachines/GradientMachineMode.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "paddle/gserver/layers/ValidationLayer.h"
P_DEFINE_string(config, "", "Trainer config file");
P_DEFINE_int32(test_period, 0,
P_DEFINE_int32(test_period,
0,
"if equal 0, do test on all test data at the end of "
"each pass. While if equal non-zero, do test on all test "
"data every test_period batches");
P_DEFINE_bool(test_all_data_in_one_period, false,
P_DEFINE_bool(test_all_data_in_one_period,
false,
"This option was deprecated, since we will always do "
"test on all test set ");
......@@ -392,10 +394,6 @@ void Trainer::startTrain() {
dataProvider_->reset();
}
if (this->testDataProvider_) {
this->testDataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
}
......@@ -630,15 +628,13 @@ void Trainer::test() { tester_->test(); }
std::unique_ptr<TesterConfig> Trainer::createTesterConfig() {
TesterConfig* conf = new TesterConfig;
if (FLAGS_test_period) {
LOG(WARNING)
<< "The meaning of --test_period is changed: "
LOG(WARNING) << "The meaning of --test_period is changed: "
<< "if equal 0, do test on all test data at the end of "
<< "each pass. While if equal non-zero, do test on all test "
<< "data every test_period batches ";
}
if (FLAGS_test_all_data_in_one_period) {
LOG(WARNING)
<< "--test_all_data_in_one_period was deprecated, since "
LOG(WARNING) << "--test_all_data_in_one_period was deprecated, since "
<< "we will always do test on all test set ";
}
conf->testPeriod = FLAGS_test_period;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册