提交 cbe734b3 编写于 作者: E emailweixu 提交者: Yu Yang

Python trainer api (#193)

* Python trainer API and demo

* Adding missing PaddleAPIPrivate.h

* Adding api_train.sh

* More comments

* Bump up patch version to 0b3
上级 46bd5f53
......@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
project(paddle CXX C)
set(PADDLE_MAJOR_VERSION 0)
set(PADDLE_MINOR_VERSION 8)
set(PADDLE_PATCH_VERSION 0b2)
set(PADDLE_PATCH_VERSION 0b3)
set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
......
......@@ -27,6 +27,7 @@ function(generate_python_api target_name)
COMMAND swig -python -c++ -outcurrentdir -I../ api/Paddle.swig
&& mv ${PROJ_ROOT}/paddle/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py
DEPENDS ${PROJ_ROOT}/paddle/api/Paddle.swig
${PROJ_ROOT}/paddle/api/PaddleAPI.h
WORKING_DIRECTORY ${PROJ_ROOT}/paddle
COMMENT "Generate Python API from swig")
add_custom_target(${target_name} ALL DEPENDS
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import random
from paddle.trainer.config_parser import parse_config
from py_paddle import swig_paddle as api
from py_paddle import DataProviderConverter
from paddle.trainer.PyDataProvider2 \
import integer_value, integer_value_sequence, sparse_binary_vector
def parse_arguments():
    """Parse the command-line options of the trainer API demo.

    Recognised options:
        --train_data     training data file (required)
        --test_data      test data file (optional)
        --config         trainer config file name (required)
        --dict_file      dictionary file (required)
        --seq            non-zero to train a sequence model (default 1)
        --use_gpu        non-zero to train on GPU (default 0)
        --trainer_count  number of training threads (default 1)
        --num_passes     number of training passes (default 5)

    :return: the parsed options
    :rtype: argparse.Namespace
    """
    parser = argparse.ArgumentParser()
    # main() loads the training set unconditionally, so let argparse reject a
    # missing --train_data up front instead of failing later in open(None).
    parser.add_argument("--train_data",
                        type=str, required=True, help="train data file")
    parser.add_argument("--test_data", type=str, help="test data file")
    parser.add_argument("--config",
                        type=str, required=True, help="config file name")
    parser.add_argument("--dict_file", required=True, help="dictionary file")
    parser.add_argument("--seq",
                        default=1, type=int,
                        help="whether use sequence training")
    parser.add_argument("--use_gpu", default=0, type=int,
                        help="whether use GPU for training")
    parser.add_argument("--trainer_count", default=1, type=int,
                        help="Number of threads for training")
    parser.add_argument("--num_passes", default=5, type=int,
                        help="Number of training passes")
    return parser.parse_args()
# Word id used for every token that is missing from the dictionary.
UNK_IDX = 0


def load_data(file_name, word_dict):
    """Yield ``(word_ids, label)`` samples read from a tab-separated file.

    Every non-blank line must look like ``<label><TAB><comment>``. The
    comment is split on whitespace and each token is mapped to its id in
    *word_dict*; unknown tokens map to ``UNK_IDX``.

    :param file_name: path of the data file
    :param word_dict: mapping from word to integer id
    :return: generator of ``(list of int, int)`` tuples
    :raises ValueError: if a non-blank line has no tab or a non-integer label
    """
    with open(file_name, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no sample.
                continue
            # Split on the first tab only, so a tab inside the comment does
            # not break the two-way unpacking.
            label, comment = line.split('\t', 1)
            word_slot = [word_dict.get(w, UNK_IDX) for w in comment.split()]
            yield word_slot, int(label)
def load_dict(dict_file):
    """Build the word -> id mapping from a dictionary file.

    The first whitespace-separated token of each line is the word and the
    zero-based line number is its id. If a word occurs on several lines the
    last occurrence wins.

    :param dict_file: path of the dictionary file
    :return: dict mapping word to integer id
    """
    with open(dict_file, 'r') as f:
        return {
            line.strip().split()[0]: line_no
            for line_no, line in enumerate(f)
        }
def main():
    """Train, and optionally test, a model through the Python trainer API.

    Steps: parse the command line, initialise Paddle, load the dictionary and
    the data sets into memory, parse the trainer config, build a
    GradientMachine and a Trainer from it, then run ``--num_passes`` training
    passes. When a non-empty test set was given, a test period runs after
    every training pass.
    """
    options = parse_arguments()
    api.initPaddle("--use_gpu=%s" % options.use_gpu,
                   "--trainer_count=%s" % options.trainer_count)

    word_dict = load_dict(options.dict_file)
    train_dataset = list(load_data(options.train_data, word_dict))
    if options.test_data:
        test_dataset = list(load_data(options.test_data, word_dict))
    else:
        test_dataset = None

    trainer_config = parse_config(options.config,
                                  "dict_file=%s" % options.dict_file)
    # No need to have data provider for trainer: batches are fed explicitly
    # through trainOneDataBatch() / testOneDataBatch() below.
    trainer_config.ClearField('data_config')
    trainer_config.ClearField('test_data_config')

    # create a GradientMachine from the model configuration
    model = api.GradientMachine.createFromConfigProto(
        trainer_config.model_config)

    # create a trainer for the gradient machine
    trainer = api.Trainer.create(trainer_config, model)

    # create a data converter which converts data to PaddlePaddle
    # internal format: word ids as a sequence (--seq != 0) or as a sparse
    # binary vector (--seq == 0), followed by the two-class label.
    input_types = [
        integer_value_sequence(len(word_dict)) if options.seq
        else sparse_binary_vector(len(word_dict)),
        integer_value(2)]
    converter = DataProviderConverter(input_types)

    batch_size = trainer_config.opt_config.batch_size

    trainer.startTrain()
    for train_pass in xrange(options.num_passes):
        trainer.startTrainPass()
        random.shuffle(train_dataset)
        for pos in xrange(0, len(train_dataset), batch_size):
            # Slice the list directly: islice() on a list would walk it from
            # the beginning for every batch.
            batch = train_dataset[pos:pos + batch_size]
            trainer.trainOneDataBatch(len(batch), converter(batch))
        trainer.finishTrainPass()
        # Truthiness on purpose: an empty test set is skipped as well.
        if test_dataset:
            trainer.startTestPeriod()
            for pos in xrange(0, len(test_dataset), batch_size):
                batch = test_dataset[pos:pos + batch_size]
                trainer.testOneDataBatch(len(batch), converter(batch))
            trainer.finishTestPeriod()
    trainer.finishTrain()


if __name__ == '__main__':
    main()
#!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Abort on the first failing command. pipefail makes the pipeline below
# report python's exit status rather than tee's, so a failed training run is
# not reported as success.
set -e
set -o pipefail

# Note: when using trainer_config.emb.py, trainer_config.cnn.py
# or trainer_config.lstm.py, change --seq=0 below to --seq=1
# because they are sequence models.
python api_train.py \
    --config=trainer_config.lr.py \
    --trainer_count=2 \
    --num_passes=15 \
    --use_gpu=0 \
    --seq=0 \
    --train_data=data/train.txt \
    --test_data=data/test.txt \
    --dict_file=data/dict.txt \
    2>&1 | tee 'train.log'
......@@ -24,7 +24,7 @@ paddle train \
--config=$cfg \
--save_dir=./output \
--trainer_count=4 \
--log_period=20 \
--log_period=100 \
--num_passes=15 \
--use_gpu=false \
--show_parameter_stats_period=100 \
......
......@@ -16,7 +16,7 @@
from paddle.trainer_config_helpers import *
dict_file = "./data/dict.txt"
dict_file = get_config_arg('dict_file', str, "./data/dict.txt")
word_dict = dict()
with open(dict_file, 'r') as f:
for i, line in enumerate(f):
......@@ -63,7 +63,6 @@ if not is_predict:
label = data_layer(name="label", size=2)
# Define cross-entropy classification loss and error.
classification_cost(input=output, label=label)
cls = classification_cost(input=output, label=label)
outputs(cls)
else:
......
......@@ -46,8 +46,8 @@ class SentimentPrediction():
conf = parse_config(train_conf, "is_predict=1")
self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
self.network.loadParameters(self.model_dir)
slots = [integer_value_sequence(self.dict_dim)]
self.converter = DataProviderConverter(slots)
input_types = [integer_value_sequence(self.dict_dim)]
self.converter = DataProviderConverter(input_types)
def load_dict(self):
"""
......
......@@ -14,27 +14,10 @@ limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/parameter/Argument.h"
struct ArgumentsPrivate {
std::vector<paddle::Argument> outputs;
inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
if (idx < outputs.size()) {
return outputs[idx];
} else {
RangeError e;
throw e;
}
}
template <typename T>
std::shared_ptr<T>& cast(void* rawPtr) const {
return *(std::shared_ptr<T>*)(rawPtr);
}
};
size_t Arguments::getSlotNum() const { return m->outputs.size(); }
Arguments* Arguments::createArguments(size_t slotNum) {
......
......@@ -40,6 +40,8 @@ configure_file(
generate_python_api(python_swig_sources)
file(GLOB PY_PADDLE_PYTHON_FILES ${PROJ_ROOT}/paddle/py_paddle/*.py)
# TODO(yuyang18) : make wheel name calculated by cmake
add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel
......@@ -55,6 +57,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
paddle_trainer
paddle_api
paddle_cuda
${PY_PADDLE_PYTHON_FILES}
)
install(DIRECTORY ${PROJ_ROOT}/paddle/dist/
......
......@@ -14,17 +14,9 @@ limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/trainer/Trainer.h"
struct TrainerConfigPrivate {
std::shared_ptr<paddle::TrainerConfig> conf;
TrainerConfigPrivate() : conf(std::make_shared<paddle::TrainerConfig>()) {}
};
struct ModelConfigPrivate {
std::shared_ptr<paddle::TrainerConfig> conf;
};
struct ParameterConfigPrivate {
paddle::ParameterPtr parameter;
paddle::ParameterConfig config;
......@@ -39,19 +31,6 @@ struct ParameterConfigPrivate {
}
};
struct OptimizationConfigPrivate {
std::shared_ptr<paddle::TrainerConfig> trainer_config;
paddle::OptimizationConfig config;
paddle::OptimizationConfig& getConfig() {
if (trainer_config != nullptr) {
return *trainer_config->mutable_opt_config();
} else {
return config;
}
}
};
TrainerConfig::TrainerConfig() : m(new TrainerConfigPrivate()) {}
TrainerConfig::~TrainerConfig() { delete m; }
......@@ -59,10 +38,19 @@ TrainerConfig::~TrainerConfig() { delete m; }
TrainerConfig* TrainerConfig::createFromTrainerConfigFile(
const std::string& confPath) {
LOG(INFO) << "load trainer config from " << confPath;
paddle::TrainerConfigHelper helper(confPath);
//! TODO(yuyang18): Make TrainerConfigPrivate to TrainerConfigHelper
auto conf = std::make_shared<paddle::TrainerConfigHelper>(confPath);
auto retv = new TrainerConfig();
*retv->m->conf = helper.getConfig();
retv->m->conf = conf;
return retv;
}
/**
 * Create a TrainerConfig from a serialized paddle::TrainerConfig message.
 *
 * A TrainerConfigHelper is built around an empty proto and the proto is then
 * filled in place by parsing str. CHECK aborts if str cannot be parsed.
 *
 * @param str serialized paddle::TrainerConfig protobuf
 * @return a TrainerConfig allocated with new
 */
TrainerConfig* TrainerConfig::createFromProtoString(
    const std::string& str) {
  auto retv = new TrainerConfig();
  paddle::TrainerConfig trainerConfigProto;
  auto conf = std::make_shared<paddle::TrainerConfigHelper>(trainerConfigProto);
  // Parse directly into the helper's own config.
  CHECK(conf->getMutableConfig().ParseFromString(str));
  retv->m->conf = conf;
  return retv;
}
......@@ -76,10 +64,6 @@ ModelConfig* TrainerConfig::getModelConfig() const {
return retv;
}
void* ModelConfig::getPaddleModelConfig() const {
return m->conf->mutable_model_config();
}
ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {}
ParameterConfig::~ParameterConfig() {
......@@ -132,8 +116,6 @@ OptimizationConfig* TrainerConfig::getOptimizationConfig() const {
return opt_config;
}
void* OptimizationConfig::getRawPtr() { return &m->getConfig(); }
OptimizationConfig* OptimizationConfig::createFromProtoString(
const std::string& str) {
auto conf = new OptimizationConfig();
......
......@@ -14,30 +14,22 @@ limitations under the License. */
#include "PaddleAPI.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "PaddleAPIPrivate.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "Internal.h"
std::vector<int> GradientMachine::defaultParamTypes = {
PARAMETER_VALUE, PARAMETER_GRADIENT, PARAMETER_MOMENTUM};
struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine;
template <typename T>
inline T& cast(void* ptr) {
return *(T*)(ptr);
}
};
GradientMachine::GradientMachine() : m(new GradientMachinePrivate()) {}
GradientMachine::~GradientMachine() { delete m; }
GradientMachine* GradientMachine::createFromPaddleModelPtr(
void* confPtr, GradientMatchineCreateMode mode,
const void* confPtr, GradientMatchineCreateMode mode,
const std::vector<int>& types) {
auto& conf = *(paddle::ModelConfig*)(confPtr);
auto& conf = *(const paddle::ModelConfig*)(confPtr);
std::vector<ParameterType> realTypes;
staticCastVector(&realTypes, types);
auto machineRawPtr = paddle::GradientMachine::create(conf, mode, realTypes);
......@@ -66,7 +58,7 @@ GradientMachine* GradientMachine::createByConfigProtoStr(
GradientMachine* GradientMachine::createByModelConfig(
ModelConfig* conf, GradientMatchineCreateMode mode,
const std::vector<int>& types) {
auto confPtr = (paddle::ModelConfig*)conf->getPaddleModelConfig();
auto confPtr = &conf->m->conf->getModelConfig();
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
}
......
......@@ -446,7 +446,6 @@ struct OptimizationConfigPrivate;
class OptimizationConfig {
DISABLE_COPY_AND_ASSIGN(OptimizationConfig);
OptimizationConfig();
void* getRawPtr();
public:
static OptimizationConfig* createFromProtoString(const std::string& str);
......@@ -462,6 +461,7 @@ private:
friend class TrainerConfig;
friend class ParameterOptimizer;
friend class Trainer;
};
struct ParameterPrivate;
......@@ -515,8 +515,6 @@ public:
virtual ~ModelConfig();
private:
void* getPaddleModelConfig() const;
ModelConfigPrivate* m;
friend class TrainerConfig;
friend struct TrainerConfigPrivate;
......@@ -539,6 +537,7 @@ public:
static TrainerConfig* createFromTrainerConfigFile(
const std::string& configPath);
static TrainerConfig* createFromProtoString(const std::string& str);
ModelConfig* getModelConfig() const;
......@@ -546,6 +545,7 @@ public:
private:
TrainerConfigPrivate* m;
friend class Trainer;
};
/**
......@@ -700,11 +700,12 @@ private:
GradientMachinePrivate* m;
static GradientMachine* createFromPaddleModelPtr(
void* confPtr, GradientMatchineCreateMode mode,
const void* confPtr, GradientMatchineCreateMode mode,
const std::vector<int>& types);
// Not to use c++ 11 init-list, so we use static var as function default arg.
static std::vector<int> defaultParamTypes;
friend class Trainer;
};
struct TrainerPrivate;
......@@ -712,6 +713,7 @@ class Trainer {
private:
TrainerPrivate* m;
Trainer();
Trainer(TrainerConfig* optConfig, GradientMachine* gm);
DISABLE_COPY_AND_ASSIGN(Trainer);
public:
......@@ -720,38 +722,42 @@ public:
/// Create A Trainer By TrainerConfig. using paddle command line.
static Trainer* createByCommandLine() throw(IOError);
/// Start Train.
static Trainer* create(TrainerConfig* optConfig, GradientMachine* gm)
throw(IOError);
/// Start training
void startTrain();
/// Finish training
void finishTrain();
/// Start Pass.
/// Start a pass.
void startTrainPass();
void finishTrainPass();
void setBatchSize(size_t batchSize);
/// Finish a pass
void finishTrainPass();
/**
* Train one batch,
*
* @param batchSize -1 will use command line or batch size set before,
* otherwise use this batchSize for train.
*
* @return true if all batch finished.
*/
bool trainOneBatch(size_t batchSize = -1UL);
bool trainOneBatch(size_t batchSize);
bool prepareBatchData(size_t batchSize = -1UL);
void trainOneDataBatch(size_t batchSize, const Arguments& args);
void finishTrainOneBatch();
void startTestPeriod();
void testOneDataBatch(size_t batchSize, const Arguments& args);
void finishTestPeriod();
void forwardOneBatch() throw(UnsupportError);
void forwardOneBatch(size_t batchSize);
Arguments* getNetworkOutput();
Arguments* getForwardOutput();
Matrix* getLayerOutput(const std::string& layerName);
};
/// The N-Best results generated from one input sequence.
/// the N-Best results generated from one input sequence.
class ISequenceResults {
public:
virtual ~ISequenceResults();
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h"
#pragma once
/// Implementation state held behind a GradientMachine's private pointer.
struct GradientMachinePrivate {
  /// The wrapped paddle gradient machine.
  std::shared_ptr<paddle::GradientMachine> machine;

  /// View an untyped pointer as a reference to T.
  /// No check is made that ptr really points to a T.
  template <typename T>
  inline T& cast(void* ptr) {
    T* typedPtr = static_cast<T*>(ptr);
    return *typedPtr;
  }
};
/// Implementation state held behind an OptimizationConfig's private pointer.
struct OptimizationConfigPrivate {
  /// Trainer config helper this optimization config belongs to, if any.
  std::shared_ptr<paddle::TrainerConfigHelper> trainer_config;
  /// Stand-alone config, used only when trainer_config is null.
  paddle::OptimizationConfig config;

  /// Return the trainer config's optimization config when a trainer config
  /// is attached, otherwise the stand-alone one.
  const paddle::OptimizationConfig& getConfig() {
    if (trainer_config == nullptr) {
      return config;
    }
    return trainer_config->getOptConfig();
  }
};
/// Implementation state held behind a TrainerConfig's private pointer.
struct TrainerConfigPrivate {
  /// Wrapped trainer config helper; left null by the constructor and assigned
  /// by the TrainerConfig factory functions.
  std::shared_ptr<paddle::TrainerConfigHelper> conf;
  TrainerConfigPrivate() {}
};
/// Implementation state held behind a ModelConfig's private pointer.
struct ModelConfigPrivate {
  /// Trainer config helper whose model config this object exposes.
  std::shared_ptr<paddle::TrainerConfigHelper> conf;
};
/// Implementation state held behind an Arguments object's private pointer.
struct ArgumentsPrivate {
  /// One paddle::Argument per slot.
  std::vector<paddle::Argument> outputs;

  /// Bounds-checked slot access.
  /// @throw RangeError when idx is not a valid index into outputs.
  inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
    if (idx >= outputs.size()) {
      throw RangeError();
    }
    return outputs[idx];
  }

  /// View an untyped pointer as a reference to a std::shared_ptr<T>.
  /// No check is made that rawPtr really points to one.
  template <typename T>
  std::shared_ptr<T>& cast(void* rawPtr) const {
    auto* sharedPtr = static_cast<std::shared_ptr<T>*>(rawPtr);
    return *sharedPtr;
  }
};
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/parameter/ParameterOptimizer.h"
#include "Internal.h"
#include <algorithm>
......@@ -60,10 +61,9 @@ ParameterOptimizer::~ParameterOptimizer() {
ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) {
CHECK(config != nullptr);
auto opt_config_ptr = (paddle::OptimizationConfig*)config->getRawPtr();
auto retOptimizer = new ParameterOptimizer();
retOptimizer->m->optimizer.reset(
paddle::ParameterOptimizer::create(*opt_config_ptr, false));
paddle::ParameterOptimizer::create(config->m->getConfig(), false));
return retOptimizer;
}
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include <stdlib.h>
#include <memory>
......@@ -30,31 +31,17 @@ P_DECLARE_string(config);
P_DECLARE_string(init_model_path);
P_DECLARE_int32(start_pass);
struct TrainPassContext {
int64_t batchId;
int32_t batchSize;
real avgTestCost;
int64_t numAvgTests;
int passInnerId;
paddle::DataBatch data;
std::vector<paddle::Argument> forwardOutput;
};
struct TrainerPrivate : public paddle::Trainer {
void startTrain();
void finishTrain();
void startTrainPass();
void finishTrainPass();
bool _trainOneBatch();
bool _prepareBatchData();
void _forwardOneBatch() throw(UnsupportError);
bool _trainOneBatch(size_t batchSize);
bool forwardOneBatch(size_t batchSize);
void forwardOneDataBatch(const std::vector<paddle::Argument>& inArgs);
void setBatchSize(size_t batchSize);
std::vector<paddle::Argument>& getForwardOutput();
void startTestPeriod();
void finishTestPeriod();
void testOneDataBatch(const paddle::DataBatch& dataBatch);
TrainerPrivate() : paddle::Trainer() {}
TrainPassContext trainPassContext;
};
Trainer::Trainer() : m(new TrainerPrivate()) {
......@@ -75,61 +62,76 @@ Trainer* Trainer::createByCommandLine() throw(IOError) {
}
}
void Trainer::startTrain() { m->startTrain(); }
Trainer::Trainer(TrainerConfig* config, GradientMachine* gm)
: m(new TrainerPrivate()) {
m->init(config->m->conf, /* testing= */false, gm ? gm->m->machine : nullptr);
}
void TrainerPrivate::startTrain() {
srand(this->config_->getConfig().start_pass() + 1);
this->dataProvider_->reset();
this->trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
Trainer* Trainer::create(TrainerConfig* config, GradientMachine* gm)
throw(IOError)
{
auto retv = new Trainer(config, gm);
if (retv->m->getConfig().IsInitialized()) {
return retv;
} else {
retv->m->getConfig().CheckInitialized();
throw IOError();
}
}
void Trainer::finishTrain() { m->finishTrain(); }
void Trainer::startTrain() { m->startTrain(); }
void TrainerPrivate::finishTrain() {
this->trainerInternal_.getGradientMachine()->finish();
}
void Trainer::finishTrain() { m->finishTrain(); }
void Trainer::startTrainPass() { m->startTrainPass(); }
void TrainerPrivate::startTrainPass() {
this->stats_.reset();
this->trainPassContext.batchId = 0;
this->trainPassContext.batchSize = this->config_->getOptConfig().batch_size();
this->trainPassContext.avgTestCost = 0;
this->trainPassContext.numAvgTests = 0;
this->trainPassContext.passInnerId = 0;
this->trainerInternal_.getParameterUpdater()->startPass();
this->evaluator_->start();
}
void Trainer::finishTrainPass() { m->finishTrainPass(); }
void TrainerPrivate::finishTrainPass() {
this->trainerInternal_.getGradientMachine()->onPassEnd();
this->trainerInternal_.getParameterUpdater()->finishPass();
evaluator_->finish();
void Trainer::trainOneDataBatch(size_t batchSize, const Arguments& inArgs) {
paddle::DataBatch dataBatch;
dataBatch.getStreams() = inArgs.m->outputs;
dataBatch.setSize(batchSize);
m->trainOneDataBatch(dataBatch);
}
void Trainer::setBatchSize(size_t batchSize) {
this->m->trainPassContext.batchSize = batchSize;
bool Trainer::trainOneBatch(size_t batchSize) {
return m->_trainOneBatch(batchSize);
}
bool Trainer::trainOneBatch(size_t batchSize) {
if (batchSize == -1UL) {
this->setBatchSize(batchSize);
bool TrainerPrivate::_trainOneBatch(size_t batchSize) {
paddle::DataBatch dataBatch;
CHECK(dataProvider_) << "data_provider is not specified";
int num = dataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
return false;
}
return m->_trainOneBatch();
trainOneDataBatch(dataBatch);
return false;
}
bool TrainerPrivate::_trainOneBatch() {
if (this->_prepareBatchData()) {
return true;
void TrainerPrivate::startTestPeriod() {
if (!tester_) {
createTester();
}
this->trainerInternal_.trainOneBatch(this->trainPassContext.batchId,
this->trainPassContext.data);
return false;
tester_->startTestPeriod();
}
void Trainer::startTestPeriod() { m->startTestPeriod(); }
void TrainerPrivate::testOneDataBatch(const paddle::DataBatch& dataBatch) {
tester_->testOneDataBatch(dataBatch, &forwardOutput_);
}
void Trainer::testOneDataBatch(size_t batchSize, const Arguments& args) {
paddle::DataBatch dataBatch;
dataBatch.getStreams() = args.m->outputs;
dataBatch.setSize(batchSize);
m->testOneDataBatch(dataBatch);
}
void TrainerPrivate::finishTestPeriod() { tester_->finishTestPeriod(); }
void Trainer::finishTestPeriod() { m->finishTestPeriod(); }
Matrix* Trainer::getLayerOutput(const std::string& layerName) {
auto nn = std::dynamic_pointer_cast<paddle::NeuralNetwork>(
this->m->getGradientMachine());
......@@ -138,46 +140,37 @@ Matrix* Trainer::getLayerOutput(const std::string& layerName) {
return Matrix::createByPaddleMatrixPtr(&m);
}
bool Trainer::prepareBatchData(size_t batchSize) {
if (batchSize != -1UL) {
this->setBatchSize(batchSize);
void Trainer::forwardOneBatch(size_t batchSize) { m->forwardOneBatch(batchSize); }
bool TrainerPrivate::forwardOneBatch(size_t batchSize) {
CHECK(dataProvider_) << "data_provider is not specified";
paddle::DataBatch dataBatch;
int num = dataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
return false;
}
return this->m->_prepareBatchData();
}
bool TrainerPrivate::_prepareBatchData() {
int num = dataProvider_->getNextBatch(this->trainPassContext.batchSize,
&this->trainPassContext.data);
return num == 0;
forwardOneDataBatch(dataBatch.getStreams());
return true;
}
void Trainer::finishTrainOneBatch() { ++m->trainPassContext.batchId; }
void TrainerPrivate::forwardOneDataBatch(
const std::vector<paddle::Argument>& inArgs) {
void Trainer::forwardOneBatch() throw(UnsupportError) { m->_forwardOneBatch(); }
void TrainerPrivate::_forwardOneBatch() throw(UnsupportError) {
auto& dataBatch = this->trainPassContext.data;
int64_t actualBatchSize = dataBatch.getSize();
if (actualBatchSize == 0) {
return;
}
const std::vector<paddle::Argument>& inArgs = dataBatch.getStreams();
std::vector<paddle::Argument>& outArgs = this->trainPassContext.forwardOutput;
outArgs.clear();
paddle::PassType passType =
this->trainerInternal_.getParameterUpdater()->startBatch(actualBatchSize);
std::vector<paddle::Argument>& outArgs = forwardOutput_;
if (config_->getOptConfig().use_sparse_remote_updater()) {
this->trainerInternal_.getGradientMachine()->prefetch(inArgs);
this->trainerInternal_.getParameterUpdater()->getParametersRemote();
trainerInternal_.getGradientMachine()->prefetch(inArgs);
trainerInternal_.getParameterUpdater()->getParametersRemote();
}
this->trainerInternal_.getGradientMachine()->forward(
inArgs, &outArgs, passType);
trainerInternal_.getGradientMachine()->forward(
inArgs, &outArgs, paddle::PASS_TEST);
}
Arguments* Trainer::getForwardOutput() {
return Arguments::createByPaddleArgumentVector(&m->getForwardOutput());
}
Arguments* Trainer::getNetworkOutput() {
return Arguments::createByPaddleArgumentVector(
&m->trainPassContext.forwardOutput);
std::vector<paddle::Argument>& TrainerPrivate::getForwardOutput() {
return forwardOutput_;
}
......@@ -30,7 +30,7 @@ source .test_env/bin/activate
pip --timeout 600 install ../../dist/*.whl
test_list="testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py"
test_list="testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py testTrainer.py"
export PYTHONPATH=$PWD/../../../python/
......
......@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from py_paddle import swig_paddle, DataProviderWrapperConverter
from py_paddle import swig_paddle
import paddle.trainer.config_parser
from paddle.trainer.PyDataProviderWrapper import DenseSlot, IndexSlot
import numpy
import util
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.config_parser import parse_config
from paddle.trainer.config_parser import logger
from py_paddle import swig_paddle
import util
def main():
    """Exercise the swig Trainer API for two passes.

    Each pass trains on batches of 1000 samples from
    util.loadMNISTTrainData until it reports the end of data, logs the
    average training cost, then runs a test period over batches from the
    same loader and logs the average test cost.
    """
    config = parse_config("./testTrainConfig.py", "")
    network = swig_paddle.GradientMachine.createFromConfigProto(
        config.model_config)
    trainer = swig_paddle.Trainer.create(config, network)

    def run_batches(process_batch):
        # Feed batches to process_batch until the loader signals the end;
        # return the accumulated cost and the number of samples seen.
        total_cost = 0
        total_samples = 0
        while True:
            batch_size = 1000
            data, at_end = util.loadMNISTTrainData(batch_size)
            if at_end:
                return total_cost, total_samples
            process_batch(batch_size, data)
            outputs = trainer.getForwardOutput()
            total_cost += sum(outputs[0]['value'])
            total_samples += batch_size

    trainer.startTrain()
    for pass_id in xrange(2):
        trainer.startTrainPass()
        cost, num = run_batches(trainer.trainOneDataBatch)
        trainer.finishTrainPass()
        logger.info('train cost=%f' % (cost / num))

        trainer.startTestPeriod()
        cost, num = run_batches(trainer.testOneDataBatch)
        trainer.finishTestPeriod()
        logger.info('test cost=%f' % (cost / num))

    trainer.finishTrain()


if __name__ == '__main__':
    swig_paddle.initPaddle("--use_gpu=0", "--trainer_count=1")
    main()
......@@ -63,7 +63,7 @@ class SparseBinaryScanner(IScanner):
def scan(self, dat):
self.extend_cols(dat)
self.__rows__.append(len(dat) + self.__rows__[-1])
self.__rows__.append(len(self.__cols__))
self.__height__ += 1
def extend_cols(self, dat):
......
......@@ -79,6 +79,20 @@ class __ParameterCallbackWrapper__(swig_paddle.UpdateCallback):
else:
return __ParameterCallbackWrapper__(callback).__disown__()
def __arguments_to_numpy__(i, arg):
    """
    Copy slot i of a swig_paddle.Arguments into numpy objects.

    :param i: slot index
    :param arg: swig_paddle.Arguments to read from
    :return: dict with keys "value" and "id"; an entry is None when the
             slot has no value matrix / no id vector.
    """
    assert isinstance(arg, swig_paddle.Arguments)

    slot_value = arg.getSlotValue(i)
    if slot_value is not None:
        assert isinstance(slot_value, swig_paddle.Matrix)
        slot_value = slot_value.copyToNumpyMat()

    slot_ids = arg.getSlotIds(i)
    if slot_ids is not None:
        assert isinstance(slot_ids, swig_paddle.IVector)
        slot_ids = slot_ids.copyToNumpyArray()

    converted = {"value": slot_value, "id": slot_ids}
    return converted
def __monkeypatch_gradient_machine__():
"""
......@@ -88,20 +102,6 @@ def __monkeypatch_gradient_machine__():
swig_paddle.GradientMachine.loadFromConfigFile = \
staticmethod(loadGradientMachine)
def __arguments_to_numpy__(i, arg):
assert isinstance(arg, swig_paddle.Arguments)
value = arg.getSlotValue(i)
if value is not None:
assert isinstance(value, swig_paddle.Matrix)
value = value.copyToNumpyMat()
ids = arg.getSlotIds(i)
if ids is not None:
assert isinstance(ids, swig_paddle.IVector)
ids = ids.copyToNumpyArray()
return {
"value": value,
"id": ids
}
def __matrix_to_numpy__(m):
if isinstance(m, swig_paddle.Matrix):
......@@ -126,7 +126,7 @@ def __monkeypatch_gradient_machine__():
:type paramTypes: list of int
:return: paddle.GradientMachine
"""
assert isinstance(protoObj, paddle.proto.ModelConfig_pb2.ModelConfig)
assert isinstance(protoObj, paddle.proto.ModelConfig)
return swig_paddle.GradientMachine.createByConfigProtoStr(
protoObj.SerializeToString(), createMode, paramTypes)
......@@ -460,13 +460,29 @@ def __monkey_patch_protobuf_objects__():
"""
assert isinstance(protoObj,
paddle.proto.TrainerConfig_pb2.OptimizationConfig)
paddle.proto.OptimizationConfig)
return swig_paddle.OptimizationConfig.createFromProtoString(
protoObj.SerializeToString())
swig_paddle.OptimizationConfig.createFromProto = staticmethod(
OptimizationConfig_createFromProto)
def TrainerConfig_createFromProto(protoObj):
    """
    Create a new paddle.TrainerConfig from
    proto.TrainerConfig

    The proto object is serialized to a string and handed to
    swig_paddle.TrainerConfig.createFromProtoString.

    :param protoObj: proto.TrainerConfig
    :return: paddle.TrainerConfig
    """
    assert isinstance(protoObj,
                      paddle.proto.TrainerConfig)
    return swig_paddle.TrainerConfig.createFromProtoString(
        protoObj.SerializeToString())

# Expose the helper as a static method on the swig class.
swig_paddle.TrainerConfig.createFromProto = staticmethod(
    TrainerConfig_createFromProto)
def __monkey_patch_parameter__():
def getBufs(self):
......@@ -483,9 +499,66 @@ def __monkey_patch_parameter__():
swig_paddle.Parameter.getBufs = getBufs
def __monkey_patch_trainer__():
    # Keep the raw swig factory reachable as __create__ before
    # Trainer.create is replaced by the proto-aware wrapper below.
    swig_paddle.Trainer.__create__ = staticmethod(swig_paddle.Trainer.create)

    def Trainer_create(config, model=None):
        """
        Create a trainer for model with TrainerConfig trainer_config
        trainer_config.model_config will be ignored when model is supplied.
        Trainer.trainOneBatch() and Trainer.forwardOneBatch() can be used only
        when trainer_config.data_config is set.

        A typical usage for Trainer is:
        .. code-block:: python
           trainer = Trainer.create(trainer_config, model)
           trainer.startTrain()
           for p in xrange(num_passes):
               trainer.startTrainPass()
               while True:
                   data = get_next_batch(batch_size)
                   if not data:
                       break
                   trainer.trainOneDataBatch(batch_size, data)
               trainer.finishTrainPass()
           trainer.finishTrain()

        The trainer will take care of logging, model saving, distributed
        training, etc.

        :param config: trainer configuration
        :type config: paddle.proto.TrainerConfig
        :param model: the model to be trained
        :type model: swig_paddle.GradientMachine
        :return: a trainer
        :rtype: swig_paddle.Trainer
        """
        assert isinstance(config, paddle.proto.TrainerConfig)
        if model is not None:
            assert isinstance(model, swig_paddle.GradientMachine)
        # Convert the proto into a swig TrainerConfig, then call the original
        # swig factory saved above.
        return swig_paddle.Trainer.__create__(
            swig_paddle.TrainerConfig.createFromProto(config), model)

    swig_paddle.Trainer.create = staticmethod(Trainer_create)

    # Same pattern: save the raw swig method, then install a wrapper that
    # returns numpy data instead of a swig Arguments object.
    swig_paddle.Trainer.__getForwardOutput__ = \
        swig_paddle.Trainer.getForwardOutput

    def getForwardOutput(self):
        """
        Get the network outputs from the previous trainOneBatch(),
        trainOneDataBatch(), testOneDataBatch(), or forwardOneBatch() call.

        :return: list of dictionary with keys ['id', 'value'], each value is a
                 numpy.ndarray.
        """
        outArgs = self.__getForwardOutput__()
        return [__arguments_to_numpy__(i, outArgs) for i in xrange(
            outArgs.getSlotNum())]

    swig_paddle.Trainer.getForwardOutput = getForwardOutput
def monkeypatches():
patches = [__monkeypatch_init_paddle__, __monkeypatch_gradient_machine__,
__monkey_patch_protobuf_objects__,
__monkey_patch_parameter__]
__monkey_patch_parameter__, __monkey_patch_trainer__]
for patch in patches:
patch()
......@@ -71,24 +71,36 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper> &config,
parameterUpdater_));
}
/**
 * Begin a test period.
 *
 * Starts the test evaluator, zeroes the accumulated cost and sample count,
 * and calls apply() on the parameter updater (finishTestPeriod() later calls
 * restore()). When prevBatchState is configured, the gradient machine's
 * current state is saved into trainState and testState is installed.
 */
void Tester::startTestPeriod() {
  testEvaluator_->start();
  testContext_.cost = 0;
  testContext_.numSamples = 0;

  parameterUpdater_->apply();
  if (intconfig_->prevBatchState) {
    gradientMachine_->getState(*intconfig_->trainState);
    gradientMachine_->setState(*intconfig_->testState);
  }
}
void Tester::testOneDataBatch(
const DataBatch& dataBatch, std::vector<Argument>* outArgs) {
testContext_.cost += forwardOneBatch(
dataBatch, testEvaluator_.get(), outArgs);
testContext_.numSamples += dataBatch.getSize();
}
void Tester::testOnePeriod() {
DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size();
testEvaluator_->start();
real cost = 0;
int64_t numSamples = 0;
bool testAllData =
intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod;
int batches =
testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;
parameterUpdater_->apply();
if (intconfig_->prevBatchState) {
gradientMachine_->getState(*intconfig_->trainState);
gradientMachine_->setState(*intconfig_->testState);
}
std::vector<Argument> outArgs;
startTestPeriod();
for (int i = 0; i < batches; ++i) {
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
......@@ -102,13 +114,17 @@ void Tester::testOnePeriod() {
num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
}
}
cost += testOneBatch(dataBatch, testEvaluator_.get());
numSamples += num;
testOneDataBatch(dataBatch, &outArgs);
}
}
void Tester::finishTestPeriod() {
testEvaluator_->finish();
CHECK_GT(numSamples, 0) << "There is no samples in your test batch. Possibly "
"wrong implementation of DataProvidor.reset()";
LOG(INFO) << " Test samples=" << numSamples << " cost=" << cost / numSamples
CHECK_GT(testContext_.numSamples, 0)
<< "There is no samples in your test batch. Possibly "
"wrong implementation of DataProvidor.reset()";
LOG(INFO) << " Test samples=" << testContext_.numSamples
<< " cost=" << testContext_.cost / testContext_.numSamples
<< " Eval: " << *testEvaluator_;
parameterUpdater_->restore();
if (intconfig_->prevBatchState) {
......@@ -128,9 +144,11 @@ int64_t Tester::testOneBatchById(int64_t batchId) {
return 0;
}
std::vector<Argument> outArgs;
stats_ += std::pair<int64_t, real>{
actualBatchSize,
testOneBatch(dataBatch, testEvaluator_.get())};
forwardOneBatch(dataBatch, testEvaluator_.get(), &outArgs)};
if (((batchId + 1) % intconfig_->logPeriod) == 0) {
LOG(INFO) << " Batch=" << batchId + 1 << " " << stats_.getStats(false);
......@@ -139,7 +157,10 @@ int64_t Tester::testOneBatchById(int64_t batchId) {
return actualBatchSize;
}
real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) {
real Tester::forwardOneBatch(const DataBatch &dataBatch,
Evaluator *evaluator,
std::vector<Argument>* pOutArgs) {
auto& outArgs = *pOutArgs;
const std::vector<Argument>& inArgs = dataBatch.getStreams();
if (intconfig_->loadsaveParametersInPserver) {
REGISTER_TIMER("prefetch");
......@@ -148,12 +169,11 @@ real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) {
true /*after apply*/);
}
std::vector<Argument> outArgs;
gradientMachine_->forward(inArgs, &outArgs, PASS_TEST);
// write features if set this flag and outArgs is not empty
std::string featFile = intconfig_->featFile;
// Both conditions must hold: a feature file is configured AND the forward
// pass produced outputs (the body reads outArgs.size() and its elements).
if (!featFile.empty() && !outArgs.empty()) {
size_t numOutputs = outArgs.size();
std::vector<MatrixPtr> featMatrices;
featMatrices.resize(numOutputs);
......
......@@ -68,6 +68,10 @@ public:
* is training at same time.
*/
void testOnePeriod();
void startTestPeriod();
void finishTestPeriod();
void testOneDataBatch(const DataBatch& dataBatch,
std::vector<Argument>* outArgs);
/**
* Test for given data batch.
......@@ -75,7 +79,9 @@ public:
* @param evaluator Evaluator
* @return cost
*/
real testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator);
real forwardOneBatch(const DataBatch& dataBatch,
Evaluator* evaluator,
std::vector<Argument>* outArgs);
/**
......@@ -99,6 +105,10 @@ protected:
std::ofstream os_;
std::vector<MatrixPtr> cpuMat_;
std::vector<IVectorPtr> cpuVec_;
// Running totals for the current test period: zeroed in startTestPeriod(),
// accumulated in testOneDataBatch() and reported in finishTestPeriod().
struct {
// Sum of dataBatch.getSize() over the batches tested so far.
int64_t numSamples;
// Sum of the costs returned by forwardOneBatch() for those batches.
real cost;
} testContext_;
private:
/**
......
......@@ -196,7 +196,8 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
if (!dataProvider_ && config_->hasDataConfig()) {
dataProvider_.reset(DataProvider::create(*config_, *config_, gpuData));
}
if (dataProvider_) {
if (!testDataProvider_) {
// No evaluator_ if there is testDataProvider but no dataProvider.
evaluator_.reset(trainerInternal_.getGradientMachine()->makeEvaluator());
currentEvaluator_.reset(
trainerInternal_.getGradientMachine()->makeEvaluator());
......@@ -215,10 +216,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
DataProvider::create(config_->getTestDataConfig(), *config_, gpuData));
}
if (testDataProvider_) {
tester_.reset(new Tester(config_, createTesterConfig(),
trainerInternal_.getGradientMachine(),
trainerInternal_.getParameterUpdater(),
testDataProvider_));
createTester();
}
if (!testing &&
......@@ -258,34 +256,25 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
}
}
// set the current evaluator and the evaluator
trainerInternal_.setCurrentEvaluator(currentEvaluator_.get());
trainerInternal_.setEvaluator(evaluator_.get());
}
/**
 * Train for numPasses passes.
 *
 * @param numPasses number of passes to run
 */
void Trainer::train(size_t numPasses) {
  // startTrain() seeds rand(), resets the data providers and starts the
  // gradient machine, so none of that is repeated here.
  startTrain();
  for (size_t i = 0; i < numPasses; ++i) {
    if (IGradientMachineMode::trainWholeDataInOneBatch(mode_)) {
      trainOnePassBatch(config_->getConfig().start_pass() + i);
    } else {
      // The pass id is tracked in trainPassContext_, so no argument is passed.
      trainOnePass();
    }
    // Rewind the training data between passes, but not after the last one.
    if (i < numPasses - 1) {
      dataProvider_->reset();
    }
  }
  // finishTrain() calls finish() on the gradient machine.
  finishTrain();
}
......@@ -387,13 +376,30 @@ real Trainer::checkGradient() {
return maxDiff;
}
void Trainer::trainOnePass(int passId) {
this->stats_->reset();
int64_t batchId = 0;
int32_t batchSize = config_->getOptConfig().batch_size();
real avgTestCost = 0;
int64_t numAvgTests = 0;
int passInnerId = 1;
void Trainer::startTrain() {
trainPassContext_.passId = config_->getConfig().start_pass();
srand(config_->getConfig().start_pass() + 1);
if (dataProvider_) {
dataProvider_->reset();
}
if (this->testDataProvider_) {
this->testDataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
}
/**
 * End training by calling finish() on the gradient machine; the counterpart
 * of startTrain().
 */
void Trainer::finishTrain() {
  const auto& machine = trainerInternal_.getGradientMachine();
  machine->finish();
}
void Trainer::startTrainPass() {
stats_->reset();
trainPassContext_.batchId = 0;
trainPassContext_.avgTestCost = 0;
trainPassContext_.numAvgTests = 0;
trainPassContext_.passInnerId = 1;
trainerInternal_.getParameterUpdater()->startPass();
evaluator_->start();
......@@ -401,81 +407,83 @@ void Trainer::trainOnePass(int passId) {
trainerInternal_.getGradientMachine()->resetState();
trainerInternal_.getGradientMachine()->getState(testState_);
}
while (true) {
DataBatch dataBatch;
int num = 0;
{
REGISTER_TIMER("getTrainBatch");
num = dataProvider_->getNextBatch(batchSize, &dataBatch);
}
if (num == 0) break;
}
if (averageEvaluator_) {
int64_t mod = batchId % FLAGS_average_test_period;
if (mod >= FLAGS_average_test_period - FLAGS_log_period) {
if (mod == FLAGS_average_test_period - FLAGS_log_period) {
averageEvaluator_->start();
}
trainerInternal_.getParameterUpdater()->apply();
if (FLAGS_prev_batch_state) {
trainerInternal_.getGradientMachine()->getState(trainState_);
}
avgTestCost +=
tester_->testOneBatch(dataBatch, averageEvaluator_.get());
if (FLAGS_prev_batch_state) {
trainerInternal_.getGradientMachine()->setState(trainState_);
}
numAvgTests += num;
trainerInternal_.getParameterUpdater()->restore();
void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
int num = dataBatch.getSize();
if (averageEvaluator_) {
int64_t mod = trainPassContext_.batchId % FLAGS_average_test_period;
if (mod >= FLAGS_average_test_period - FLAGS_log_period) {
if (mod == FLAGS_average_test_period - FLAGS_log_period) {
averageEvaluator_->start();
}
trainerInternal_.getParameterUpdater()->apply();
if (FLAGS_prev_batch_state) {
trainerInternal_.getGradientMachine()->getState(trainState_);
}
trainPassContext_.avgTestCost +=
tester_->forwardOneBatch(
dataBatch, averageEvaluator_.get(), &forwardOutput_);
if (FLAGS_prev_batch_state) {
trainerInternal_.getGradientMachine()->setState(trainState_);
}
trainPassContext_.numAvgTests += num;
trainerInternal_.getParameterUpdater()->restore();
}
{
REGISTER_TIMER("TrainBatch");
trainerInternal_.trainOneBatch(batchId, dataBatch);
}
}
{
REGISTER_TIMER("TrainBatch");
trainerInternal_.trainOneBatch(
trainPassContext_.batchId, dataBatch, &forwardOutput_);
}
if (averageEvaluator_ &&
batchId % FLAGS_average_test_period == FLAGS_average_test_period - 1) {
averageEvaluator_->finish();
LOG(INFO) << " Averaged parameter:"
<< " cost=" << avgTestCost / numAvgTests
<< " Eval: " << *averageEvaluator_;
numAvgTests = 0;
avgTestCost = 0;
}
if (averageEvaluator_ &&
trainPassContext_.batchId % FLAGS_average_test_period
== FLAGS_average_test_period - 1) {
averageEvaluator_->finish();
LOG(INFO) << " Averaged parameter:"
<< " cost=" << trainPassContext_.avgTestCost
/ trainPassContext_.numAvgTests
<< " Eval: " << *averageEvaluator_;
trainPassContext_.numAvgTests = 0;
trainPassContext_.avgTestCost = 0;
}
++batchId;
++trainPassContext_.batchId;
if (batchId % FLAGS_log_period == 0) {
FOR_TIMING(globalStat.setThreadInfo(true));
FOR_TIMING(globalStat.printAllStatus());
FOR_TIMING(globalStat.reset());
}
if (trainPassContext_.batchId % FLAGS_log_period == 0) {
FOR_TIMING(globalStat.setThreadInfo(true));
FOR_TIMING(globalStat.printAllStatus());
FOR_TIMING(globalStat.reset());
}
if (testDataProvider_ && FLAGS_test_period > 0 &&
batchId % FLAGS_test_period == 0) {
tester_->testOnePeriod();
}
if (testDataProvider_ && FLAGS_test_period > 0 &&
trainPassContext_.batchId % FLAGS_test_period == 0) {
tester_->testOnePeriod();
}
if (FLAGS_saving_period_by_batches > 0 &&
batchId > FLAGS_saving_period_by_batches * passInnerId &&
0 == FLAGS_trainer_id) {
trainerInternal_.getParameterUpdater()->catchUpWith();
if (testDataProvider_) {
tester_->testOnePeriod();
}
paramUtil_->saveParametersOnePass(passId, passInnerId);
++passInnerId;
if (FLAGS_saving_period_by_batches > 0 &&
trainPassContext_.batchId
> FLAGS_saving_period_by_batches * trainPassContext_.passInnerId &&
0 == FLAGS_trainer_id) {
trainerInternal_.getParameterUpdater()->catchUpWith();
if (testDataProvider_) {
tester_->testOnePeriod();
}
paramUtil_->saveParametersOnePass(
trainPassContext_.passId, trainPassContext_.passInnerId);
++trainPassContext_.passInnerId;
}
}
if (batchId == 0) {
void Trainer::finishTrainPass() {
if (trainPassContext_.batchId == 0) {
// This means no more data from DataProvider
return;
}
trainerInternal_.finishTrainPass(passId, batchId);
trainerInternal_.finishTrainPass(
trainPassContext_.passId, trainPassContext_.batchId);
FOR_TIMING(globalStat.setThreadInfo(true));
FOR_TIMING(globalStat.printAllStatus());
......@@ -485,9 +493,30 @@ void Trainer::trainOnePass(int passId) {
tester_->testOnePeriod();
}
if (passId % FLAGS_saving_period == 0 && FLAGS_trainer_id == 0) {
paramUtil_->saveParametersOnePass(passId);
if (trainPassContext_.passId % FLAGS_saving_period == 0
&& FLAGS_trainer_id == 0) {
paramUtil_->saveParametersOnePass(trainPassContext_.passId);
}
++trainPassContext_.passId;
}
void Trainer::trainOnePass() {
startTrainPass();
size_t batchSize = config_->getOptConfig().batch_size();
while (true) {
DataBatch dataBatch;
int num = 0;
{
REGISTER_TIMER("getTrainBatch");
num = dataProvider_->getNextBatch(batchSize, &dataBatch);
}
if (num == 0) break;
CHECK_EQ(num, dataBatch.getSize());
trainOneDataBatch(dataBatch);
}
finishTrainPass();
}
void Trainer::trainOnePassBatch(int passId) {
......@@ -582,6 +611,13 @@ void Trainer::clearGradient() {
// Returns the batch_size value from the optimization config.
int Trainer::getBatchSize() { return config_->getOptConfig().batch_size(); }
void Trainer::createTester() {
tester_.reset(new paddle::Tester(config_, createTesterConfig(),
trainerInternal_.getGradientMachine(),
trainerInternal_.getParameterUpdater(),
testDataProvider_));
}
// Delegates to tester_->test(); tester_ must already have been created.
void Trainer::test() {
tester_->test();
}
......
......@@ -94,6 +94,11 @@ public:
*/
real checkGradient();
void startTrain();
void finishTrain();
void startTrainPass();
void finishTrainPass();
void trainOneDataBatch(DataBatch& dataBatch);
/**
* given a dataBatch and the current parameter value
......@@ -144,11 +149,11 @@ public:
protected:
/**
* Train one pass of data. passId starts from 0
* Train one pass of data.
*
* SGD Method.
*/
void trainOnePass(int passId);
void trainOnePass();
/**
* Train one pass in one batch.
......@@ -161,6 +166,8 @@ protected:
*/
void clearGradient();
void createTester();
private:
std::unique_ptr<TesterConfig> createTesterConfig();
......@@ -173,6 +180,17 @@ protected:
MachineState trainState_;
MachineState testState_;
// Bookkeeping carried across startTrainPass() / trainOneDataBatch() /
// finishTrainPass().
struct TrainPassContext {
// Batches trained in the current pass; reset to 0 by startTrainPass() and
// incremented once per trainOneDataBatch().
int64_t batchId;
// Cost accumulated from forwardOneBatch() with the averaged-parameter
// evaluator; reset to 0 after it is logged.
real avgTestCost;
// Number of samples contributing to avgTestCost; reset together with it.
int64_t numAvgTests;
// Current pass id; initialised from start_pass in startTrain() and
// incremented at the end of finishTrainPass().
int passId;
// Index of the next intra-pass parameter save; set to 1 by startTrainPass()
// and incremented after each save triggered inside trainOneDataBatch().
int passInnerId;
};
std::vector<paddle::Argument> forwardOutput_;
TrainPassContext trainPassContext_;
std::unique_ptr<Evaluator> evaluator_;
std::unique_ptr<Evaluator> currentEvaluator_;
std::unique_ptr<Evaluator> averageEvaluator_;
......
......@@ -55,6 +55,8 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config,
gradientMachine_ = gradientMachine;
if (!gradientMachine) {
CHECK(config_->getConfig().has_model_config())
<< "Missing model_config in trainer_config";
gradientMachine_.reset(GradientMachine::create(
config_->getConfig().model_config(), intconfig_->mode,
parameterUpdater_->getParameterTypes()));
......@@ -62,7 +64,8 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config,
}
void TrainerInternal::trainOneBatch(int64_t batchId,
const DataBatch& dataBatch) {
const DataBatch& dataBatch,
std::vector<Argument>* outArgs) {
// true means updating parameter whenever gradient is ready during backward()
bool doPipelineUpdate =
(intconfig_->mode != GradientMachine::kSgdSparseCpuTraining) &&
......@@ -84,7 +87,6 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
}
const std::vector<Argument>& inArgs = dataBatch.getStreams();
std::vector<Argument> outArgs;
PassType passType = parameterUpdater_->startBatch(actualBatchSize);
......@@ -114,7 +116,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
timer.start();
#endif
REGISTER_TIMER("forwardBackward");
forwardBackwardBatch(inArgs, outArgs, passType, updateCallback,
forwardBackwardBatch(inArgs, *outArgs, passType, updateCallback,
doPipelineUpdate);
#ifndef PADDLE_DISABLE_TIMER
timer.stop();
......@@ -132,7 +134,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
real cost = 0;
{
REGISTER_TIMER("sumCost");
cost = Argument::sumCosts(outArgs);
cost = Argument::sumCosts(*outArgs);
}
if (batchId % intconfig_->log_period == 0) {
......
......@@ -81,7 +81,9 @@ public:
* @param batchId current batch id
* @param dataBatch data for the batch
*/
void trainOneBatch(int64_t batchId, const DataBatch& dataBatch);
void trainOneBatch(int64_t batchId,
const DataBatch& dataBatch,
std::vector<Argument>* outArgs);
/**
* showParameterStats
......
......@@ -130,7 +130,7 @@ message OptimizationConfig {
};
message TrainerConfig {
required ModelConfig model_config = 1;
optional ModelConfig model_config = 1;
optional DataConfig data_config = 2;
required OptimizationConfig opt_config = 3;
optional DataConfig test_data_config = 4;
......
......@@ -12,3 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.proto.TrainerConfig_pb2 import OptimizationConfig, TrainerConfig
from paddle.proto.ModelConfig_pb2 import ModelConfig
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册