Commit cbe734b3 authored by emailweixu, committed by Yu Yang

Python trainer api (#193)

* Python trainer API and demo

* Adding missing PaddleAPIPrivate.h

* Adding api_train.sh

* More comments

* Bump up patch version to 0b3
Parent 46bd5f53
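In short, this change lets a Python program drive PaddlePaddle training directly through the swig API instead of the paddle train command line. A minimal sketch of the flow, mirroring the api_train.py demo added below (the config filename and the make_batches() helper are placeholders for illustration, not part of this commit):

from paddle.trainer.config_parser import parse_config
from py_paddle import swig_paddle as api

api.initPaddle("--use_gpu=0", "--trainer_count=1")

# Parse the trainer config and build a GradientMachine and a Trainer from it.
trainer_config = parse_config("trainer_config.lr.py", "")
model = api.GradientMachine.createFromConfigProto(trainer_config.model_config)
trainer = api.Trainer.create(trainer_config, model)

trainer.startTrain()
for train_pass in xrange(2):  # two passes, for illustration
    trainer.startTrainPass()
    for size, batch in make_batches():
        # make_batches() is a hypothetical helper yielding (batch_size, converted
        # data) pairs, e.g. produced with DataProviderConverter as in api_train.py.
        trainer.trainOneDataBatch(size, batch)
    trainer.finishTrainPass()
trainer.finishTrain()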
...@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8) ...@@ -3,7 +3,7 @@ cmake_minimum_required(VERSION 2.8)
project(paddle CXX C) project(paddle CXX C)
set(PADDLE_MAJOR_VERSION 0) set(PADDLE_MAJOR_VERSION 0)
set(PADDLE_MINOR_VERSION 8) set(PADDLE_MINOR_VERSION 8)
set(PADDLE_PATCH_VERSION 0b2) set(PADDLE_PATCH_VERSION 0b3)
set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION}) set(PADDLE_VERSION ${PADDLE_MAJOR_VERSION}.${PADDLE_MINOR_VERSION}.${PADDLE_PATCH_VERSION})
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake")
......
...@@ -27,6 +27,7 @@ function(generate_python_api target_name) ...@@ -27,6 +27,7 @@ function(generate_python_api target_name)
COMMAND swig -python -c++ -outcurrentdir -I../ api/Paddle.swig COMMAND swig -python -c++ -outcurrentdir -I../ api/Paddle.swig
&& mv ${PROJ_ROOT}/paddle/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py && mv ${PROJ_ROOT}/paddle/swig_paddle.py ${PROJ_ROOT}/paddle/py_paddle/swig_paddle.py
DEPENDS ${PROJ_ROOT}/paddle/api/Paddle.swig DEPENDS ${PROJ_ROOT}/paddle/api/Paddle.swig
${PROJ_ROOT}/paddle/api/PaddleAPI.h
WORKING_DIRECTORY ${PROJ_ROOT}/paddle WORKING_DIRECTORY ${PROJ_ROOT}/paddle
COMMENT "Generate Python API from swig") COMMENT "Generate Python API from swig")
add_custom_target(${target_name} ALL DEPENDS add_custom_target(${target_name} ALL DEPENDS
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import itertools
import random
from paddle.trainer.config_parser import parse_config
from py_paddle import swig_paddle as api
from py_paddle import DataProviderConverter
from paddle.trainer.PyDataProvider2 \
import integer_value, integer_value_sequence, sparse_binary_vector
def parse_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("--train_data",
type=str, required=False, help="train data file")
parser.add_argument("--test_data", type=str, help="test data file")
parser.add_argument("--config",
type=str, required=True, help="config file name")
parser.add_argument("--dict_file", required=True, help="dictionary file")
parser.add_argument("--seq",
default=1, type=int,
help="whether use sequence training")
parser.add_argument("--use_gpu", default=0, type=int,
help="whether use GPU for training")
parser.add_argument("--trainer_count", default=1, type=int,
help="Number of threads for training")
parser.add_argument("--num_passes", default=5, type=int,
help="Number of training passes")
return parser.parse_args()
UNK_IDX = 0
def load_data(file_name, word_dict):
with open(file_name, 'r') as f:
for line in f:
label, comment = line.strip().split('\t')
words = comment.split()
word_slot = [word_dict.get(w, UNK_IDX) for w in words]
yield word_slot, int(label)
def load_dict(dict_file):
word_dict = dict()
with open(dict_file, 'r') as f:
for i, line in enumerate(f):
w = line.strip().split()[0]
word_dict[w] = i
return word_dict
def main():
options = parse_arguments()
api.initPaddle("--use_gpu=%s" % options.use_gpu,
"--trainer_count=%s" % options.trainer_count)
word_dict = load_dict(options.dict_file)
train_dataset = list(load_data(options.train_data, word_dict))
if options.test_data:
test_dataset = list(load_data(options.test_data, word_dict))
else:
test_dataset = None
trainer_config = parse_config(options.config,
"dict_file=%s" % options.dict_file)
# No need to have data provider for trainer
trainer_config.ClearField('data_config')
trainer_config.ClearField('test_data_config')
# create a GradientMachine from the model configuration
model = api.GradientMachine.createFromConfigProto(
trainer_config.model_config)
# create a trainer for the gradient machine
trainer = api.Trainer.create(trainer_config, model)
# create a data converter which converts data to PaddlePaddle
# internal format
input_types = [
integer_value_sequence(len(word_dict)) if options.seq
else sparse_binary_vector(len(word_dict)),
integer_value(2)]
converter = DataProviderConverter(input_types)
batch_size = trainer_config.opt_config.batch_size
trainer.startTrain()
for train_pass in xrange(options.num_passes):
trainer.startTrainPass()
random.shuffle(train_dataset)
for pos in xrange(0, len(train_dataset), batch_size):
batch = itertools.islice(train_dataset, pos, pos + batch_size)
size = min(batch_size, len(train_dataset) - pos)
trainer.trainOneDataBatch(size, converter(batch))
trainer.finishTrainPass()
if test_dataset:
trainer.startTestPeriod()
for pos in xrange(0, len(test_dataset), batch_size):
batch = itertools.islice(test_dataset, pos, pos + batch_size)
size = min(batch_size, len(test_dataset) - pos)
trainer.testOneDataBatch(size, converter(batch))
trainer.finishTestPeriod()
trainer.finishTrain()
if __name__ == '__main__':
main()
#!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Note: if using trainer_config.emb.py, trainer_config.cnn.py
# or trainer_config.lstm.py, change --seq=0 below to --seq=1,
# because they are sequence models.
python api_train.py \
--config=trainer_config.lr.py \
--trainer_count=2 \
--num_passes=15 \
--use_gpu=0 \
--seq=0 \
--train_data=data/train.txt \
--test_data=data/test.txt \
--dict_file=data/dict.txt \
2>&1 | tee 'train.log'
...@@ -24,7 +24,7 @@ paddle train \ ...@@ -24,7 +24,7 @@ paddle train \
--config=$cfg \ --config=$cfg \
--save_dir=./output \ --save_dir=./output \
--trainer_count=4 \ --trainer_count=4 \
--log_period=20 \ --log_period=100 \
--num_passes=15 \ --num_passes=15 \
--use_gpu=false \ --use_gpu=false \
--show_parameter_stats_period=100 \ --show_parameter_stats_period=100 \
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
from paddle.trainer_config_helpers import * from paddle.trainer_config_helpers import *
dict_file = "./data/dict.txt" dict_file = get_config_arg('dict_file', str, "./data/dict.txt")
word_dict = dict() word_dict = dict()
with open(dict_file, 'r') as f: with open(dict_file, 'r') as f:
for i, line in enumerate(f): for i, line in enumerate(f):
...@@ -63,7 +63,6 @@ if not is_predict: ...@@ -63,7 +63,6 @@ if not is_predict:
label = data_layer(name="label", size=2) label = data_layer(name="label", size=2)
# Define cross-entropy classification loss and error. # Define cross-entropy classification loss and error.
classification_cost(input=output, label=label)
cls = classification_cost(input=output, label=label) cls = classification_cost(input=output, label=label)
outputs(cls) outputs(cls)
else: else:
......
...@@ -46,8 +46,8 @@ class SentimentPrediction(): ...@@ -46,8 +46,8 @@ class SentimentPrediction():
conf = parse_config(train_conf, "is_predict=1") conf = parse_config(train_conf, "is_predict=1")
self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config) self.network = swig_paddle.GradientMachine.createFromConfigProto(conf.model_config)
self.network.loadParameters(self.model_dir) self.network.loadParameters(self.model_dir)
slots = [integer_value_sequence(self.dict_dim)] input_types = [integer_value_sequence(self.dict_dim)]
self.converter = DataProviderConverter(slots) self.converter = DataProviderConverter(input_types)
def load_dict(self): def load_dict(self):
""" """
......
...@@ -14,27 +14,10 @@ limitations under the License. */ ...@@ -14,27 +14,10 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/parameter/Argument.h" #include "paddle/parameter/Argument.h"
struct ArgumentsPrivate {
std::vector<paddle::Argument> outputs;
inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
if (idx < outputs.size()) {
return outputs[idx];
} else {
RangeError e;
throw e;
}
}
template <typename T>
std::shared_ptr<T>& cast(void* rawPtr) const {
return *(std::shared_ptr<T>*)(rawPtr);
}
};
size_t Arguments::getSlotNum() const { return m->outputs.size(); } size_t Arguments::getSlotNum() const { return m->outputs.size(); }
Arguments* Arguments::createArguments(size_t slotNum) { Arguments* Arguments::createArguments(size_t slotNum) {
......
...@@ -40,6 +40,8 @@ configure_file( ...@@ -40,6 +40,8 @@ configure_file(
generate_python_api(python_swig_sources) generate_python_api(python_swig_sources)
file(GLOB PY_PADDLE_PYTHON_FILES ${PROJ_ROOT}/paddle/py_paddle/*.py)
# TODO(yuyang18) : make wheel name calculated by cmake # TODO(yuyang18) : make wheel name calculated by cmake
add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel
...@@ -55,6 +57,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp ...@@ -55,6 +57,7 @@ add_custom_command(OUTPUT ${PROJ_ROOT}/paddle/dist/.timestamp
paddle_trainer paddle_trainer
paddle_api paddle_api
paddle_cuda paddle_cuda
${PY_PADDLE_PYTHON_FILES}
) )
install(DIRECTORY ${PROJ_ROOT}/paddle/dist/ install(DIRECTORY ${PROJ_ROOT}/paddle/dist/
......
...@@ -14,17 +14,9 @@ limitations under the License. */ ...@@ -14,17 +14,9 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/trainer/Trainer.h" #include "paddle/trainer/Trainer.h"
struct TrainerConfigPrivate {
std::shared_ptr<paddle::TrainerConfig> conf;
TrainerConfigPrivate() : conf(std::make_shared<paddle::TrainerConfig>()) {}
};
struct ModelConfigPrivate {
std::shared_ptr<paddle::TrainerConfig> conf;
};
struct ParameterConfigPrivate { struct ParameterConfigPrivate {
paddle::ParameterPtr parameter; paddle::ParameterPtr parameter;
paddle::ParameterConfig config; paddle::ParameterConfig config;
...@@ -39,19 +31,6 @@ struct ParameterConfigPrivate { ...@@ -39,19 +31,6 @@ struct ParameterConfigPrivate {
} }
}; };
struct OptimizationConfigPrivate {
std::shared_ptr<paddle::TrainerConfig> trainer_config;
paddle::OptimizationConfig config;
paddle::OptimizationConfig& getConfig() {
if (trainer_config != nullptr) {
return *trainer_config->mutable_opt_config();
} else {
return config;
}
}
};
TrainerConfig::TrainerConfig() : m(new TrainerConfigPrivate()) {} TrainerConfig::TrainerConfig() : m(new TrainerConfigPrivate()) {}
TrainerConfig::~TrainerConfig() { delete m; } TrainerConfig::~TrainerConfig() { delete m; }
...@@ -59,10 +38,19 @@ TrainerConfig::~TrainerConfig() { delete m; } ...@@ -59,10 +38,19 @@ TrainerConfig::~TrainerConfig() { delete m; }
TrainerConfig* TrainerConfig::createFromTrainerConfigFile( TrainerConfig* TrainerConfig::createFromTrainerConfigFile(
const std::string& confPath) { const std::string& confPath) {
LOG(INFO) << "load trainer config from " << confPath; LOG(INFO) << "load trainer config from " << confPath;
paddle::TrainerConfigHelper helper(confPath); auto conf = std::make_shared<paddle::TrainerConfigHelper>(confPath);
//! TODO(yuyang18): Make TrainerConfigPrivate to TrainerConfigHelper
auto retv = new TrainerConfig(); auto retv = new TrainerConfig();
*retv->m->conf = helper.getConfig(); retv->m->conf = conf;
return retv;
}
TrainerConfig* TrainerConfig::createFromProtoString(
const std::string& str) {
auto retv = new TrainerConfig();
paddle::TrainerConfig trainerConfigProto;
auto conf = std::make_shared<paddle::TrainerConfigHelper>(trainerConfigProto);
CHECK(conf->getMutableConfig().ParseFromString(str));
retv->m->conf = conf;
return retv; return retv;
} }
...@@ -76,10 +64,6 @@ ModelConfig* TrainerConfig::getModelConfig() const { ...@@ -76,10 +64,6 @@ ModelConfig* TrainerConfig::getModelConfig() const {
return retv; return retv;
} }
void* ModelConfig::getPaddleModelConfig() const {
return m->conf->mutable_model_config();
}
ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {} ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {}
ParameterConfig::~ParameterConfig() { ParameterConfig::~ParameterConfig() {
...@@ -132,8 +116,6 @@ OptimizationConfig* TrainerConfig::getOptimizationConfig() const { ...@@ -132,8 +116,6 @@ OptimizationConfig* TrainerConfig::getOptimizationConfig() const {
return opt_config; return opt_config;
} }
void* OptimizationConfig::getRawPtr() { return &m->getConfig(); }
OptimizationConfig* OptimizationConfig::createFromProtoString( OptimizationConfig* OptimizationConfig::createFromProtoString(
const std::string& str) { const std::string& str) {
auto conf = new OptimizationConfig(); auto conf = new OptimizationConfig();
......
...@@ -14,30 +14,22 @@ limitations under the License. */ ...@@ -14,30 +14,22 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h" #include "PaddleAPIPrivate.h"
#include "paddle/gserver/gradientmachines/NeuralNetwork.h" #include "paddle/gserver/gradientmachines/NeuralNetwork.h"
#include "Internal.h" #include "Internal.h"
std::vector<int> GradientMachine::defaultParamTypes = { std::vector<int> GradientMachine::defaultParamTypes = {
PARAMETER_VALUE, PARAMETER_GRADIENT, PARAMETER_MOMENTUM}; PARAMETER_VALUE, PARAMETER_GRADIENT, PARAMETER_MOMENTUM};
struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine;
template <typename T>
inline T& cast(void* ptr) {
return *(T*)(ptr);
}
};
GradientMachine::GradientMachine() : m(new GradientMachinePrivate()) {} GradientMachine::GradientMachine() : m(new GradientMachinePrivate()) {}
GradientMachine::~GradientMachine() { delete m; } GradientMachine::~GradientMachine() { delete m; }
GradientMachine* GradientMachine::createFromPaddleModelPtr( GradientMachine* GradientMachine::createFromPaddleModelPtr(
void* confPtr, GradientMatchineCreateMode mode, const void* confPtr, GradientMatchineCreateMode mode,
const std::vector<int>& types) { const std::vector<int>& types) {
auto& conf = *(paddle::ModelConfig*)(confPtr); auto& conf = *(const paddle::ModelConfig*)(confPtr);
std::vector<ParameterType> realTypes; std::vector<ParameterType> realTypes;
staticCastVector(&realTypes, types); staticCastVector(&realTypes, types);
auto machineRawPtr = paddle::GradientMachine::create(conf, mode, realTypes); auto machineRawPtr = paddle::GradientMachine::create(conf, mode, realTypes);
...@@ -66,7 +58,7 @@ GradientMachine* GradientMachine::createByConfigProtoStr( ...@@ -66,7 +58,7 @@ GradientMachine* GradientMachine::createByConfigProtoStr(
GradientMachine* GradientMachine::createByModelConfig( GradientMachine* GradientMachine::createByModelConfig(
ModelConfig* conf, GradientMatchineCreateMode mode, ModelConfig* conf, GradientMatchineCreateMode mode,
const std::vector<int>& types) { const std::vector<int>& types) {
auto confPtr = (paddle::ModelConfig*)conf->getPaddleModelConfig(); auto confPtr = &conf->m->conf->getModelConfig();
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types); return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
} }
......
...@@ -446,7 +446,6 @@ struct OptimizationConfigPrivate; ...@@ -446,7 +446,6 @@ struct OptimizationConfigPrivate;
class OptimizationConfig { class OptimizationConfig {
DISABLE_COPY_AND_ASSIGN(OptimizationConfig); DISABLE_COPY_AND_ASSIGN(OptimizationConfig);
OptimizationConfig(); OptimizationConfig();
void* getRawPtr();
public: public:
static OptimizationConfig* createFromProtoString(const std::string& str); static OptimizationConfig* createFromProtoString(const std::string& str);
...@@ -462,6 +461,7 @@ private: ...@@ -462,6 +461,7 @@ private:
friend class TrainerConfig; friend class TrainerConfig;
friend class ParameterOptimizer; friend class ParameterOptimizer;
friend class Trainer;
}; };
struct ParameterPrivate; struct ParameterPrivate;
...@@ -515,8 +515,6 @@ public: ...@@ -515,8 +515,6 @@ public:
virtual ~ModelConfig(); virtual ~ModelConfig();
private: private:
void* getPaddleModelConfig() const;
ModelConfigPrivate* m; ModelConfigPrivate* m;
friend class TrainerConfig; friend class TrainerConfig;
friend struct TrainerConfigPrivate; friend struct TrainerConfigPrivate;
...@@ -539,6 +537,7 @@ public: ...@@ -539,6 +537,7 @@ public:
static TrainerConfig* createFromTrainerConfigFile( static TrainerConfig* createFromTrainerConfigFile(
const std::string& configPath); const std::string& configPath);
static TrainerConfig* createFromProtoString(const std::string& str);
ModelConfig* getModelConfig() const; ModelConfig* getModelConfig() const;
...@@ -546,6 +545,7 @@ public: ...@@ -546,6 +545,7 @@ public:
private: private:
TrainerConfigPrivate* m; TrainerConfigPrivate* m;
friend class Trainer;
}; };
/** /**
...@@ -700,11 +700,12 @@ private: ...@@ -700,11 +700,12 @@ private:
GradientMachinePrivate* m; GradientMachinePrivate* m;
static GradientMachine* createFromPaddleModelPtr( static GradientMachine* createFromPaddleModelPtr(
void* confPtr, GradientMatchineCreateMode mode, const void* confPtr, GradientMatchineCreateMode mode,
const std::vector<int>& types); const std::vector<int>& types);
// Not to use c++ 11 init-list, so we use static var as function default arg. // Not to use c++ 11 init-list, so we use static var as function default arg.
static std::vector<int> defaultParamTypes; static std::vector<int> defaultParamTypes;
friend class Trainer;
}; };
struct TrainerPrivate; struct TrainerPrivate;
...@@ -712,6 +713,7 @@ class Trainer { ...@@ -712,6 +713,7 @@ class Trainer {
private: private:
TrainerPrivate* m; TrainerPrivate* m;
Trainer(); Trainer();
Trainer(TrainerConfig* optConfig, GradientMachine* gm);
DISABLE_COPY_AND_ASSIGN(Trainer); DISABLE_COPY_AND_ASSIGN(Trainer);
public: public:
...@@ -720,38 +722,42 @@ public: ...@@ -720,38 +722,42 @@ public:
/// Create A Trainer By TrainerConfig. using paddle command line. /// Create A Trainer By TrainerConfig. using paddle command line.
static Trainer* createByCommandLine() throw(IOError); static Trainer* createByCommandLine() throw(IOError);
/// Start Train. static Trainer* create(TrainerConfig* optConfig, GradientMachine* gm)
throw(IOError);
/// Start training
void startTrain(); void startTrain();
/// Finish training
void finishTrain(); void finishTrain();
/// Start Pass. /// Start a pass.
void startTrainPass(); void startTrainPass();
void finishTrainPass();
void setBatchSize(size_t batchSize); /// Finish a pass
void finishTrainPass();
/** /**
* Train one batch, * Train one batch,
* *
* @param batchSize -1 wiil use command line or batch size set before,
* otherwise use this batchSize for train.
*
* @return true if all batch finished. * @return true if all batch finished.
*/ */
bool trainOneBatch(size_t batchSize = -1UL); bool trainOneBatch(size_t batchSize);
bool prepareBatchData(size_t batchSize = -1UL); void trainOneDataBatch(size_t batchSize, const Arguments& args);
void finishTrainOneBatch(); void startTestPeriod();
void testOneDataBatch(size_t batchSize, const Arguments& args);
void finishTestPeriod();
void forwardOneBatch() throw(UnsupportError); void forwardOneBatch(size_t batchSize);
Arguments* getNetworkOutput(); Arguments* getForwardOutput();
Matrix* getLayerOutput(const std::string& layerName); Matrix* getLayerOutput(const std::string& layerName);
}; };
/// The N-Best results generated from one input sequence. /// the N-Best results generated from one input sequence.
class ISequenceResults { class ISequenceResults {
public: public:
virtual ~ISequenceResults(); virtual ~ISequenceResults();
......
/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/trainer/TrainerConfigHelper.h"
#pragma once
struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine;
template <typename T>
inline T& cast(void* ptr) {
return *(T*)(ptr);
}
};
struct OptimizationConfigPrivate {
std::shared_ptr<paddle::TrainerConfigHelper> trainer_config;
paddle::OptimizationConfig config;
const paddle::OptimizationConfig& getConfig() {
if (trainer_config != nullptr) {
return trainer_config->getOptConfig();
} else {
return config;
}
}
};
struct TrainerConfigPrivate {
std::shared_ptr<paddle::TrainerConfigHelper> conf;
TrainerConfigPrivate() {}
};
struct ModelConfigPrivate {
std::shared_ptr<paddle::TrainerConfigHelper> conf;
};
struct ArgumentsPrivate {
std::vector<paddle::Argument> outputs;
inline paddle::Argument& getArg(size_t idx) throw(RangeError) {
if (idx < outputs.size()) {
return outputs[idx];
} else {
RangeError e;
throw e;
}
}
template <typename T>
std::shared_ptr<T>& cast(void* rawPtr) const {
return *(std::shared_ptr<T>*)(rawPtr);
}
};
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/parameter/ParameterOptimizer.h" #include "paddle/parameter/ParameterOptimizer.h"
#include "Internal.h" #include "Internal.h"
#include <algorithm> #include <algorithm>
...@@ -60,10 +61,9 @@ ParameterOptimizer::~ParameterOptimizer() { ...@@ -60,10 +61,9 @@ ParameterOptimizer::~ParameterOptimizer() {
ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) { ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) {
CHECK(config != nullptr); CHECK(config != nullptr);
auto opt_config_ptr = (paddle::OptimizationConfig*)config->getRawPtr();
auto retOptimizer = new ParameterOptimizer(); auto retOptimizer = new ParameterOptimizer();
retOptimizer->m->optimizer.reset( retOptimizer->m->optimizer.reset(
paddle::ParameterOptimizer::create(*opt_config_ptr, false)); paddle::ParameterOptimizer::create(config->m->getConfig(), false));
return retOptimizer; return retOptimizer;
} }
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "PaddleAPI.h" #include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include <stdlib.h> #include <stdlib.h>
#include <memory> #include <memory>
...@@ -30,31 +31,17 @@ P_DECLARE_string(config); ...@@ -30,31 +31,17 @@ P_DECLARE_string(config);
P_DECLARE_string(init_model_path); P_DECLARE_string(init_model_path);
P_DECLARE_int32(start_pass); P_DECLARE_int32(start_pass);
struct TrainPassContext {
int64_t batchId;
int32_t batchSize;
real avgTestCost;
int64_t numAvgTests;
int passInnerId;
paddle::DataBatch data;
std::vector<paddle::Argument> forwardOutput;
};
struct TrainerPrivate : public paddle::Trainer { struct TrainerPrivate : public paddle::Trainer {
void startTrain(); bool _trainOneBatch(size_t batchSize);
void finishTrain(); bool forwardOneBatch(size_t batchSize);
void forwardOneDataBatch(const std::vector<paddle::Argument>& inArgs);
void startTrainPass(); void setBatchSize(size_t batchSize);
void finishTrainPass(); std::vector<paddle::Argument>& getForwardOutput();
bool _trainOneBatch(); void startTestPeriod();
void finishTestPeriod();
bool _prepareBatchData(); void testOneDataBatch(const paddle::DataBatch& dataBatch);
void _forwardOneBatch() throw(UnsupportError);
TrainerPrivate() : paddle::Trainer() {} TrainerPrivate() : paddle::Trainer() {}
TrainPassContext trainPassContext;
}; };
Trainer::Trainer() : m(new TrainerPrivate()) { Trainer::Trainer() : m(new TrainerPrivate()) {
...@@ -75,61 +62,76 @@ Trainer* Trainer::createByCommandLine() throw(IOError) { ...@@ -75,61 +62,76 @@ Trainer* Trainer::createByCommandLine() throw(IOError) {
} }
} }
void Trainer::startTrain() { m->startTrain(); } Trainer::Trainer(TrainerConfig* config, GradientMachine* gm)
: m(new TrainerPrivate()) {
m->init(config->m->conf, /* testing= */false, gm ? gm->m->machine : nullptr);
}
void TrainerPrivate::startTrain() { Trainer* Trainer::create(TrainerConfig* config, GradientMachine* gm)
srand(this->config_->getConfig().start_pass() + 1); throw(IOError)
this->dataProvider_->reset(); {
this->trainerInternal_.getGradientMachine()->start(*config_, dataProvider_); auto retv = new Trainer(config, gm);
if (retv->m->getConfig().IsInitialized()) {
return retv;
} else {
retv->m->getConfig().CheckInitialized();
throw IOError();
}
} }
void Trainer::finishTrain() { m->finishTrain(); } void Trainer::startTrain() { m->startTrain(); }
void TrainerPrivate::finishTrain() { void Trainer::finishTrain() { m->finishTrain(); }
this->trainerInternal_.getGradientMachine()->finish();
}
void Trainer::startTrainPass() { m->startTrainPass(); } void Trainer::startTrainPass() { m->startTrainPass(); }
void TrainerPrivate::startTrainPass() {
this->stats_.reset();
this->trainPassContext.batchId = 0;
this->trainPassContext.batchSize = this->config_->getOptConfig().batch_size();
this->trainPassContext.avgTestCost = 0;
this->trainPassContext.numAvgTests = 0;
this->trainPassContext.passInnerId = 0;
this->trainerInternal_.getParameterUpdater()->startPass();
this->evaluator_->start();
}
void Trainer::finishTrainPass() { m->finishTrainPass(); } void Trainer::finishTrainPass() { m->finishTrainPass(); }
void TrainerPrivate::finishTrainPass() { void Trainer::trainOneDataBatch(size_t batchSize, const Arguments& inArgs) {
this->trainerInternal_.getGradientMachine()->onPassEnd(); paddle::DataBatch dataBatch;
this->trainerInternal_.getParameterUpdater()->finishPass(); dataBatch.getStreams() = inArgs.m->outputs;
evaluator_->finish(); dataBatch.setSize(batchSize);
m->trainOneDataBatch(dataBatch);
} }
void Trainer::setBatchSize(size_t batchSize) { bool Trainer::trainOneBatch(size_t batchSize) {
this->m->trainPassContext.batchSize = batchSize; return m->_trainOneBatch(batchSize);
} }
bool Trainer::trainOneBatch(size_t batchSize) { bool TrainerPrivate::_trainOneBatch(size_t batchSize) {
if (batchSize == -1UL) { paddle::DataBatch dataBatch;
this->setBatchSize(batchSize); CHECK(dataProvider_) << "data_provider is not specified";
int num = dataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
return false;
} }
return m->_trainOneBatch(); trainOneDataBatch(dataBatch);
return false;
} }
bool TrainerPrivate::_trainOneBatch() { void TrainerPrivate::startTestPeriod() {
if (this->_prepareBatchData()) { if (!tester_) {
return true; createTester();
} }
this->trainerInternal_.trainOneBatch(this->trainPassContext.batchId, tester_->startTestPeriod();
this->trainPassContext.data);
return false;
} }
void Trainer::startTestPeriod() { m->startTestPeriod(); }
void TrainerPrivate::testOneDataBatch(const paddle::DataBatch& dataBatch) {
tester_->testOneDataBatch(dataBatch, &forwardOutput_);
}
void Trainer::testOneDataBatch(size_t batchSize, const Arguments& args) {
paddle::DataBatch dataBatch;
dataBatch.getStreams() = args.m->outputs;
dataBatch.setSize(batchSize);
m->testOneDataBatch(dataBatch);
}
void TrainerPrivate::finishTestPeriod() { tester_->finishTestPeriod(); }
void Trainer::finishTestPeriod() { m->finishTestPeriod(); }
Matrix* Trainer::getLayerOutput(const std::string& layerName) { Matrix* Trainer::getLayerOutput(const std::string& layerName) {
auto nn = std::dynamic_pointer_cast<paddle::NeuralNetwork>( auto nn = std::dynamic_pointer_cast<paddle::NeuralNetwork>(
this->m->getGradientMachine()); this->m->getGradientMachine());
...@@ -138,46 +140,37 @@ Matrix* Trainer::getLayerOutput(const std::string& layerName) { ...@@ -138,46 +140,37 @@ Matrix* Trainer::getLayerOutput(const std::string& layerName) {
return Matrix::createByPaddleMatrixPtr(&m); return Matrix::createByPaddleMatrixPtr(&m);
} }
bool Trainer::prepareBatchData(size_t batchSize) { void Trainer::forwardOneBatch(size_t batchSize) { m->forwardOneBatch(batchSize); }
if (batchSize != -1UL) {
this->setBatchSize(batchSize); bool TrainerPrivate::forwardOneBatch(size_t batchSize) {
CHECK(dataProvider_) << "data_provider is not specified";
paddle::DataBatch dataBatch;
int num = dataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) {
return false;
} }
return this->m->_prepareBatchData();
}
bool TrainerPrivate::_prepareBatchData() { forwardOneDataBatch(dataBatch.getStreams());
int num = dataProvider_->getNextBatch(this->trainPassContext.batchSize, return true;
&this->trainPassContext.data);
return num == 0;
} }
void Trainer::finishTrainOneBatch() { ++m->trainPassContext.batchId; } void TrainerPrivate::forwardOneDataBatch(
const std::vector<paddle::Argument>& inArgs) {
void Trainer::forwardOneBatch() throw(UnsupportError) { m->_forwardOneBatch(); }
void TrainerPrivate::_forwardOneBatch() throw(UnsupportError) { std::vector<paddle::Argument>& outArgs = forwardOutput_;
auto& dataBatch = this->trainPassContext.data;
int64_t actualBatchSize = dataBatch.getSize();
if (actualBatchSize == 0) {
return;
}
const std::vector<paddle::Argument>& inArgs = dataBatch.getStreams();
std::vector<paddle::Argument>& outArgs = this->trainPassContext.forwardOutput;
outArgs.clear();
paddle::PassType passType =
this->trainerInternal_.getParameterUpdater()->startBatch(actualBatchSize);
if (config_->getOptConfig().use_sparse_remote_updater()) { if (config_->getOptConfig().use_sparse_remote_updater()) {
this->trainerInternal_.getGradientMachine()->prefetch(inArgs); trainerInternal_.getGradientMachine()->prefetch(inArgs);
this->trainerInternal_.getParameterUpdater()->getParametersRemote(); trainerInternal_.getParameterUpdater()->getParametersRemote();
} }
this->trainerInternal_.getGradientMachine()->forward( trainerInternal_.getGradientMachine()->forward(
inArgs, &outArgs, passType); inArgs, &outArgs, paddle::PASS_TEST);
}
Arguments* Trainer::getForwardOutput() {
return Arguments::createByPaddleArgumentVector(&m->getForwardOutput());
} }
Arguments* Trainer::getNetworkOutput() { std::vector<paddle::Argument>& TrainerPrivate::getForwardOutput() {
return Arguments::createByPaddleArgumentVector( return forwardOutput_;
&m->trainPassContext.forwardOutput);
} }
...@@ -30,7 +30,7 @@ source .test_env/bin/activate ...@@ -30,7 +30,7 @@ source .test_env/bin/activate
pip --timeout 600 install ../../dist/*.whl pip --timeout 600 install ../../dist/*.whl
test_list="testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py" test_list="testArguments.py testGradientMachine.py testMatrix.py testVector.py testTrain.py testTrainer.py"
export PYTHONPATH=$PWD/../../../python/ export PYTHONPATH=$PWD/../../../python/
......
...@@ -12,9 +12,8 @@ ...@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from py_paddle import swig_paddle, DataProviderWrapperConverter from py_paddle import swig_paddle
import paddle.trainer.config_parser import paddle.trainer.config_parser
from paddle.trainer.PyDataProviderWrapper import DenseSlot, IndexSlot
import numpy import numpy
import util import util
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer.config_parser import parse_config
from paddle.trainer.config_parser import logger
from py_paddle import swig_paddle
import util
def main():
trainer_config = parse_config(
"./testTrainConfig.py", "")
model = swig_paddle.GradientMachine.createFromConfigProto(
trainer_config.model_config)
trainer = swig_paddle.Trainer.create(trainer_config, model)
trainer.startTrain()
for train_pass in xrange(2):
trainer.startTrainPass()
num = 0
cost = 0
while True: # Train one batch
batch_size = 1000
data, atEnd = util.loadMNISTTrainData(batch_size)
if atEnd:
break
trainer.trainOneDataBatch(batch_size, data)
outs = trainer.getForwardOutput()
cost += sum(outs[0]['value'])
num += batch_size
trainer.finishTrainPass()
logger.info('train cost=%f' % (cost / num))
trainer.startTestPeriod()
num = 0
cost = 0
while True: # Test one batch
batch_size = 1000
data, atEnd = util.loadMNISTTrainData(batch_size)
if atEnd:
break
trainer.testOneDataBatch(batch_size, data)
outs = trainer.getForwardOutput()
cost += sum(outs[0]['value'])
num += batch_size
trainer.finishTestPeriod()
logger.info('test cost=%f' % (cost / num))
trainer.finishTrain()
if __name__ == '__main__':
swig_paddle.initPaddle("--use_gpu=0", "--trainer_count=1")
main()
...@@ -63,7 +63,7 @@ class SparseBinaryScanner(IScanner): ...@@ -63,7 +63,7 @@ class SparseBinaryScanner(IScanner):
def scan(self, dat): def scan(self, dat):
self.extend_cols(dat) self.extend_cols(dat)
self.__rows__.append(len(dat) + self.__rows__[-1]) self.__rows__.append(len(self.__cols__))
self.__height__ += 1 self.__height__ += 1
def extend_cols(self, dat): def extend_cols(self, dat):
......
...@@ -79,16 +79,7 @@ class __ParameterCallbackWrapper__(swig_paddle.UpdateCallback): ...@@ -79,16 +79,7 @@ class __ParameterCallbackWrapper__(swig_paddle.UpdateCallback):
else: else:
return __ParameterCallbackWrapper__(callback).__disown__() return __ParameterCallbackWrapper__(callback).__disown__()
def __arguments_to_numpy__(i, arg):
def __monkeypatch_gradient_machine__():
"""
Add some class methods to GradientMachine.
This method should be only used internally.
"""
swig_paddle.GradientMachine.loadFromConfigFile = \
staticmethod(loadGradientMachine)
def __arguments_to_numpy__(i, arg):
assert isinstance(arg, swig_paddle.Arguments) assert isinstance(arg, swig_paddle.Arguments)
value = arg.getSlotValue(i) value = arg.getSlotValue(i)
if value is not None: if value is not None:
...@@ -103,6 +94,15 @@ def __monkeypatch_gradient_machine__(): ...@@ -103,6 +94,15 @@ def __monkeypatch_gradient_machine__():
"id": ids "id": ids
} }
def __monkeypatch_gradient_machine__():
"""
Add some class methods to GradientMachine.
This method should be only used internally.
"""
swig_paddle.GradientMachine.loadFromConfigFile = \
staticmethod(loadGradientMachine)
def __matrix_to_numpy__(m): def __matrix_to_numpy__(m):
if isinstance(m, swig_paddle.Matrix): if isinstance(m, swig_paddle.Matrix):
return m.copyToNumpyMat() return m.copyToNumpyMat()
...@@ -126,7 +126,7 @@ def __monkeypatch_gradient_machine__(): ...@@ -126,7 +126,7 @@ def __monkeypatch_gradient_machine__():
:type paramTypes: list of int :type paramTypes: list of int
:return: paddle.GradientMachine :return: paddle.GradientMachine
""" """
assert isinstance(protoObj, paddle.proto.ModelConfig_pb2.ModelConfig) assert isinstance(protoObj, paddle.proto.ModelConfig)
return swig_paddle.GradientMachine.createByConfigProtoStr( return swig_paddle.GradientMachine.createByConfigProtoStr(
protoObj.SerializeToString(), createMode, paramTypes) protoObj.SerializeToString(), createMode, paramTypes)
...@@ -460,13 +460,29 @@ def __monkey_patch_protobuf_objects__(): ...@@ -460,13 +460,29 @@ def __monkey_patch_protobuf_objects__():
""" """
assert isinstance(protoObj, assert isinstance(protoObj,
paddle.proto.TrainerConfig_pb2.OptimizationConfig) paddle.proto.OptimizationConfig)
return swig_paddle.OptimizationConfig.createFromProtoString( return swig_paddle.OptimizationConfig.createFromProtoString(
protoObj.SerializeToString()) protoObj.SerializeToString())
swig_paddle.OptimizationConfig.createFromProto = staticmethod( swig_paddle.OptimizationConfig.createFromProto = staticmethod(
OptimizationConfig_createFromProto) OptimizationConfig_createFromProto)
def TrainerConfig_createFromProto(protoObj):
"""
Create a new paddle.TrainerConfig from proto.TrainerConfig
:param protoObj: proto.TrainerConfig
:return: paddle.TrainerConfig
"""
assert isinstance(protoObj,
paddle.proto.TrainerConfig)
return swig_paddle.TrainerConfig.createFromProtoString(
protoObj.SerializeToString())
swig_paddle.TrainerConfig.createFromProto = staticmethod(
TrainerConfig_createFromProto)
def __monkey_patch_parameter__(): def __monkey_patch_parameter__():
def getBufs(self): def getBufs(self):
...@@ -483,9 +499,66 @@ def __monkey_patch_parameter__(): ...@@ -483,9 +499,66 @@ def __monkey_patch_parameter__():
swig_paddle.Parameter.getBufs = getBufs swig_paddle.Parameter.getBufs = getBufs
def __monkey_patch_trainer__():
swig_paddle.Trainer.__create__ = staticmethod(swig_paddle.Trainer.create)
def Trainer_create(config, model=None):
"""
Create a trainer for a model with TrainerConfig trainer_config
trainer_config.model_config will be ignored when model is supplied.
Trainer.trainOneBatch() and Trainer.forwardOneBatch() can be used only
when trainer_config.data_config is set.
A typical usage for Trainer is:
.. code-block:: python
trainer = Trainer.create(trainer_config, model)
for p in xrange(num_passes):
    trainer.startTrainPass()
    while True:
        data = get_next_batch(batch_size)
        if not data:
            break
        trainer.trainOneDataBatch(batch_size, data)
    trainer.finishTrainPass()
trainer.finishTrain()
The trainer will take care of logging, model saving, distributed
training, etc.
:param config: trainer configuration
:type config: paddle.proto.TrainerConfig
:param model: the model to be trained
:type model: swig_paddle.GradientMachine
:return: a trainer
:rtype swig_paddle.Trainer
"""
assert isinstance(config, paddle.proto.TrainerConfig)
if model is not None:
assert isinstance(model, swig_paddle.GradientMachine)
return swig_paddle.Trainer.__create__(
swig_paddle.TrainerConfig.createFromProto(config), model)
swig_paddle.Trainer.create = staticmethod(Trainer_create)
swig_paddle.Trainer.__getForwardOutput__ = \
swig_paddle.Trainer.getForwardOutput
def getForwardOutput(self):
"""
Get the network outputs from the previous trainOneBatch(),
trainOneDataBatch(), testOneDataBatch(), or forwardOneBatch() call.
:return: list of dictionaries with keys ['id', 'value']; each value is a
numpy.ndarray.
"""
outArgs = self.__getForwardOutput__()
return [__arguments_to_numpy__(i, outArgs) for i in xrange(
outArgs.getSlotNum())]
swig_paddle.Trainer.getForwardOutput = getForwardOutput
def monkeypatches(): def monkeypatches():
patches = [__monkeypatch_init_paddle__, __monkeypatch_gradient_machine__, patches = [__monkeypatch_init_paddle__, __monkeypatch_gradient_machine__,
__monkey_patch_protobuf_objects__, __monkey_patch_protobuf_objects__,
__monkey_patch_parameter__] __monkey_patch_parameter__, __monkey_patch_trainer__]
for patch in patches: for patch in patches:
patch() patch()
...@@ -71,24 +71,36 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper> &config, ...@@ -71,24 +71,36 @@ Tester::Tester(const std::shared_ptr<TrainerConfigHelper> &config,
parameterUpdater_)); parameterUpdater_));
} }
void Tester::startTestPeriod() {
testEvaluator_->start();
testContext_.cost = 0;
testContext_.numSamples = 0;
parameterUpdater_->apply();
if (intconfig_->prevBatchState) {
gradientMachine_->getState(*intconfig_->trainState);
gradientMachine_->setState(*intconfig_->testState);
}
}
void Tester::testOneDataBatch(
const DataBatch& dataBatch, std::vector<Argument>* outArgs) {
testContext_.cost += forwardOneBatch(
dataBatch, testEvaluator_.get(), outArgs);
testContext_.numSamples += dataBatch.getSize();
}
void Tester::testOnePeriod() { void Tester::testOnePeriod() {
DataBatch dataBatch; DataBatch dataBatch;
int64_t batchSize = config_->getOptConfig().batch_size(); int64_t batchSize = config_->getOptConfig().batch_size();
testEvaluator_->start();
real cost = 0;
int64_t numSamples = 0;
bool testAllData = bool testAllData =
intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod; intconfig_->testPeriod == 0 || intconfig_->testAllDataInOnePeriod;
int batches = int batches =
testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod; testAllData ? std::numeric_limits<int>::max() : intconfig_->testPeriod;
parameterUpdater_->apply(); std::vector<Argument> outArgs;
if (intconfig_->prevBatchState) {
gradientMachine_->getState(*intconfig_->trainState);
gradientMachine_->setState(*intconfig_->testState);
}
startTestPeriod();
for (int i = 0; i < batches; ++i) { for (int i = 0; i < batches; ++i) {
int num = testDataProvider_->getNextBatch(batchSize, &dataBatch); int num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
if (num == 0) { if (num == 0) {
...@@ -102,13 +114,17 @@ void Tester::testOnePeriod() { ...@@ -102,13 +114,17 @@ void Tester::testOnePeriod() {
num = testDataProvider_->getNextBatch(batchSize, &dataBatch); num = testDataProvider_->getNextBatch(batchSize, &dataBatch);
} }
} }
cost += testOneBatch(dataBatch, testEvaluator_.get()); testOneDataBatch(dataBatch, &outArgs);
numSamples += num;
} }
}
void Tester::finishTestPeriod() {
testEvaluator_->finish(); testEvaluator_->finish();
CHECK_GT(numSamples, 0) << "There is no samples in your test batch. Possibly " CHECK_GT(testContext_.numSamples, 0)
<< "There is no samples in your test batch. Possibly "
"wrong implementation of DataProvidor.reset()"; "wrong implementation of DataProvidor.reset()";
LOG(INFO) << " Test samples=" << numSamples << " cost=" << cost / numSamples LOG(INFO) << " Test samples=" << testContext_.numSamples
<< " cost=" << testContext_.cost / testContext_.numSamples
<< " Eval: " << *testEvaluator_; << " Eval: " << *testEvaluator_;
parameterUpdater_->restore(); parameterUpdater_->restore();
if (intconfig_->prevBatchState) { if (intconfig_->prevBatchState) {
...@@ -128,9 +144,11 @@ int64_t Tester::testOneBatchById(int64_t batchId) { ...@@ -128,9 +144,11 @@ int64_t Tester::testOneBatchById(int64_t batchId) {
return 0; return 0;
} }
std::vector<Argument> outArgs;
stats_ += std::pair<int64_t, real>{ stats_ += std::pair<int64_t, real>{
actualBatchSize, actualBatchSize,
testOneBatch(dataBatch, testEvaluator_.get())}; forwardOneBatch(dataBatch, testEvaluator_.get(), &outArgs)};
if (((batchId + 1) % intconfig_->logPeriod) == 0) { if (((batchId + 1) % intconfig_->logPeriod) == 0) {
LOG(INFO) << " Batch=" << batchId + 1 << " " << stats_.getStats(false); LOG(INFO) << " Batch=" << batchId + 1 << " " << stats_.getStats(false);
...@@ -139,7 +157,10 @@ int64_t Tester::testOneBatchById(int64_t batchId) { ...@@ -139,7 +157,10 @@ int64_t Tester::testOneBatchById(int64_t batchId) {
return actualBatchSize; return actualBatchSize;
} }
real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) { real Tester::forwardOneBatch(const DataBatch &dataBatch,
Evaluator *evaluator,
std::vector<Argument>* pOutArgs) {
auto& outArgs = *pOutArgs;
const std::vector<Argument>& inArgs = dataBatch.getStreams(); const std::vector<Argument>& inArgs = dataBatch.getStreams();
if (intconfig_->loadsaveParametersInPserver) { if (intconfig_->loadsaveParametersInPserver) {
REGISTER_TIMER("prefetch"); REGISTER_TIMER("prefetch");
...@@ -148,12 +169,11 @@ real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) { ...@@ -148,12 +169,11 @@ real Tester::testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator) {
true /*after apply*/); true /*after apply*/);
} }
std::vector<Argument> outArgs;
gradientMachine_->forward(inArgs, &outArgs, PASS_TEST); gradientMachine_->forward(inArgs, &outArgs, PASS_TEST);
// write features if set this flag and outArgs is not empty // write features if set this flag and outArgs is not empty
std::string featFile = intconfig_->featFile; std::string featFile = intconfig_->featFile;
if (!featFile.empty() && !outArgs.empty()) { if (!featFile.empty() && outArgs.empty()) {
size_t numOutputs = outArgs.size(); size_t numOutputs = outArgs.size();
std::vector<MatrixPtr> featMatrices; std::vector<MatrixPtr> featMatrices;
featMatrices.resize(numOutputs); featMatrices.resize(numOutputs);
......
...@@ -68,6 +68,10 @@ public: ...@@ -68,6 +68,10 @@ public:
* is training at same time. * is training at same time.
*/ */
void testOnePeriod(); void testOnePeriod();
void startTestPeriod();
void finishTestPeriod();
void testOneDataBatch(const DataBatch& dataBatch,
std::vector<Argument>* outArgs);
/** /**
* Test for given data batch. * Test for given data batch.
...@@ -75,7 +79,9 @@ public: ...@@ -75,7 +79,9 @@ public:
* @param evaluator Evaluator * @param evaluator Evaluator
* @return cost * @return cost
*/ */
real testOneBatch(const DataBatch &dataBatch, Evaluator *evaluator); real forwardOneBatch(const DataBatch& dataBatch,
Evaluator* evaluator,
std::vector<Argument>* outArgs);
/** /**
...@@ -99,6 +105,10 @@ protected: ...@@ -99,6 +105,10 @@ protected:
std::ofstream os_; std::ofstream os_;
std::vector<MatrixPtr> cpuMat_; std::vector<MatrixPtr> cpuMat_;
std::vector<IVectorPtr> cpuVec_; std::vector<IVectorPtr> cpuVec_;
struct {
int64_t numSamples;
real cost;
} testContext_;
private: private:
/** /**
......
...@@ -196,7 +196,8 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config, ...@@ -196,7 +196,8 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
if (!dataProvider_ && config_->hasDataConfig()) { if (!dataProvider_ && config_->hasDataConfig()) {
dataProvider_.reset(DataProvider::create(*config_, *config_, gpuData)); dataProvider_.reset(DataProvider::create(*config_, *config_, gpuData));
} }
if (dataProvider_) { if (!testDataProvider_) {
// No evaluator_ if there is testDataProvider but no dataProvider.
evaluator_.reset(trainerInternal_.getGradientMachine()->makeEvaluator()); evaluator_.reset(trainerInternal_.getGradientMachine()->makeEvaluator());
currentEvaluator_.reset( currentEvaluator_.reset(
trainerInternal_.getGradientMachine()->makeEvaluator()); trainerInternal_.getGradientMachine()->makeEvaluator());
...@@ -215,10 +216,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config, ...@@ -215,10 +216,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
DataProvider::create(config_->getTestDataConfig(), *config_, gpuData)); DataProvider::create(config_->getTestDataConfig(), *config_, gpuData));
} }
if (testDataProvider_) { if (testDataProvider_) {
tester_.reset(new Tester(config_, createTesterConfig(), createTester();
trainerInternal_.getGradientMachine(),
trainerInternal_.getParameterUpdater(),
testDataProvider_));
} }
if (!testing && if (!testing &&
...@@ -258,34 +256,25 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config, ...@@ -258,34 +256,25 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper> &config,
} }
} }
// set current evaluator and evalutor // set current evaluator and evalutor
trainerInternal_.setCurrentEvaluator(currentEvaluator_.get()); trainerInternal_.setCurrentEvaluator(currentEvaluator_.get());
trainerInternal_.setEvaluator(evaluator_.get()); trainerInternal_.setEvaluator(evaluator_.get());
} }
void Trainer::train(size_t numPasses) { void Trainer::train(size_t numPasses) {
srand(config_->getConfig().start_pass() + 1); startTrain();
dataProvider_->reset();
if (this->testDataProvider_) {
this->testDataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
for (size_t i = 0; i < numPasses; ++i) { for (size_t i = 0; i < numPasses; ++i) {
if (IGradientMachineMode::trainWholeDataInOneBatch(mode_)) { if (IGradientMachineMode::trainWholeDataInOneBatch(mode_)) {
trainOnePassBatch(config_->getConfig().start_pass() + i); trainOnePassBatch(config_->getConfig().start_pass() + i);
} else { } else {
trainOnePass(config_->getConfig().start_pass() + i); trainOnePass();
} }
if (i < numPasses - 1) { if (i < numPasses - 1) {
dataProvider_->reset(); dataProvider_->reset();
} }
} }
trainerInternal_.getGradientMachine()->finish(); finishTrain();
} }
...@@ -387,13 +376,30 @@ real Trainer::checkGradient() { ...@@ -387,13 +376,30 @@ real Trainer::checkGradient() {
return maxDiff; return maxDiff;
} }
void Trainer::trainOnePass(int passId) { void Trainer::startTrain() {
this->stats_->reset(); trainPassContext_.passId = config_->getConfig().start_pass();
int64_t batchId = 0; srand(config_->getConfig().start_pass() + 1);
int32_t batchSize = config_->getOptConfig().batch_size(); if (dataProvider_) {
real avgTestCost = 0; dataProvider_->reset();
int64_t numAvgTests = 0; }
int passInnerId = 1;
if (this->testDataProvider_) {
this->testDataProvider_->reset();
}
trainerInternal_.getGradientMachine()->start(*config_, dataProvider_);
}
void Trainer::finishTrain() {
trainerInternal_.getGradientMachine()->finish();
}
void Trainer::startTrainPass() {
stats_->reset();
trainPassContext_.batchId = 0;
trainPassContext_.avgTestCost = 0;
trainPassContext_.numAvgTests = 0;
trainPassContext_.passInnerId = 1;
trainerInternal_.getParameterUpdater()->startPass(); trainerInternal_.getParameterUpdater()->startPass();
evaluator_->start(); evaluator_->start();
...@@ -401,18 +407,12 @@ void Trainer::trainOnePass(int passId) { ...@@ -401,18 +407,12 @@ void Trainer::trainOnePass(int passId) {
trainerInternal_.getGradientMachine()->resetState(); trainerInternal_.getGradientMachine()->resetState();
trainerInternal_.getGradientMachine()->getState(testState_); trainerInternal_.getGradientMachine()->getState(testState_);
} }
while (true) { }
DataBatch dataBatch;
int num = 0;
{
REGISTER_TIMER("getTrainBatch");
num = dataProvider_->getNextBatch(batchSize, &dataBatch);
}
if (num == 0) break;
void Trainer::trainOneDataBatch(DataBatch& dataBatch) {
int num = dataBatch.getSize();
if (averageEvaluator_) { if (averageEvaluator_) {
int64_t mod = batchId % FLAGS_average_test_period; int64_t mod = trainPassContext_.batchId % FLAGS_average_test_period;
if (mod >= FLAGS_average_test_period - FLAGS_log_period) { if (mod >= FLAGS_average_test_period - FLAGS_log_period) {
if (mod == FLAGS_average_test_period - FLAGS_log_period) { if (mod == FLAGS_average_test_period - FLAGS_log_period) {
averageEvaluator_->start(); averageEvaluator_->start();
...@@ -421,61 +421,69 @@ void Trainer::trainOnePass(int passId) { ...@@ -421,61 +421,69 @@ void Trainer::trainOnePass(int passId) {
if (FLAGS_prev_batch_state) { if (FLAGS_prev_batch_state) {
trainerInternal_.getGradientMachine()->getState(trainState_); trainerInternal_.getGradientMachine()->getState(trainState_);
} }
avgTestCost += trainPassContext_.avgTestCost +=
tester_->testOneBatch(dataBatch, averageEvaluator_.get()); tester_->forwardOneBatch(
dataBatch, averageEvaluator_.get(), &forwardOutput_);
if (FLAGS_prev_batch_state) { if (FLAGS_prev_batch_state) {
trainerInternal_.getGradientMachine()->setState(trainState_); trainerInternal_.getGradientMachine()->setState(trainState_);
} }
numAvgTests += num; trainPassContext_.numAvgTests += num;
trainerInternal_.getParameterUpdater()->restore(); trainerInternal_.getParameterUpdater()->restore();
} }
} }
{ {
REGISTER_TIMER("TrainBatch"); REGISTER_TIMER("TrainBatch");
trainerInternal_.trainOneBatch(batchId, dataBatch); trainerInternal_.trainOneBatch(
trainPassContext_.batchId, dataBatch, &forwardOutput_);
} }
if (averageEvaluator_ && if (averageEvaluator_ &&
batchId % FLAGS_average_test_period == FLAGS_average_test_period - 1) { trainPassContext_.batchId % FLAGS_average_test_period
== FLAGS_average_test_period - 1) {
averageEvaluator_->finish(); averageEvaluator_->finish();
LOG(INFO) << " Averaged parameter:" LOG(INFO) << " Averaged parameter:"
<< " cost=" << avgTestCost / numAvgTests << " cost=" << trainPassContext_.avgTestCost
/ trainPassContext_.numAvgTests
<< " Eval: " << *averageEvaluator_; << " Eval: " << *averageEvaluator_;
numAvgTests = 0; trainPassContext_.numAvgTests = 0;
avgTestCost = 0; trainPassContext_.avgTestCost = 0;
} }
++batchId; ++trainPassContext_.batchId;
if (batchId % FLAGS_log_period == 0) { if (trainPassContext_.batchId % FLAGS_log_period == 0) {
FOR_TIMING(globalStat.setThreadInfo(true)); FOR_TIMING(globalStat.setThreadInfo(true));
FOR_TIMING(globalStat.printAllStatus()); FOR_TIMING(globalStat.printAllStatus());
FOR_TIMING(globalStat.reset()); FOR_TIMING(globalStat.reset());
} }
if (testDataProvider_ && FLAGS_test_period > 0 && if (testDataProvider_ && FLAGS_test_period > 0 &&
batchId % FLAGS_test_period == 0) { trainPassContext_.batchId % FLAGS_test_period == 0) {
tester_->testOnePeriod(); tester_->testOnePeriod();
} }
if (FLAGS_saving_period_by_batches > 0 && if (FLAGS_saving_period_by_batches > 0 &&
batchId > FLAGS_saving_period_by_batches * passInnerId && trainPassContext_.batchId
> FLAGS_saving_period_by_batches * trainPassContext_.passInnerId &&
0 == FLAGS_trainer_id) { 0 == FLAGS_trainer_id) {
trainerInternal_.getParameterUpdater()->catchUpWith(); trainerInternal_.getParameterUpdater()->catchUpWith();
if (testDataProvider_) { if (testDataProvider_) {
tester_->testOnePeriod(); tester_->testOnePeriod();
} }
paramUtil_->saveParametersOnePass(passId, passInnerId); paramUtil_->saveParametersOnePass(
++passInnerId; trainPassContext_.passId, trainPassContext_.passInnerId);
} ++trainPassContext_.passInnerId;
} }
}
if (batchId == 0) { void Trainer::finishTrainPass() {
if (trainPassContext_.batchId == 0) {
// This means no more data from DataProvider // This means no more data from DataProvider
return; return;
} }
trainerInternal_.finishTrainPass(passId, batchId); trainerInternal_.finishTrainPass(
trainPassContext_.passId, trainPassContext_.batchId);
FOR_TIMING(globalStat.setThreadInfo(true)); FOR_TIMING(globalStat.setThreadInfo(true));
FOR_TIMING(globalStat.printAllStatus()); FOR_TIMING(globalStat.printAllStatus());
...@@ -485,9 +493,30 @@ void Trainer::trainOnePass(int passId) { ...@@ -485,9 +493,30 @@ void Trainer::trainOnePass(int passId) {
tester_->testOnePeriod(); tester_->testOnePeriod();
} }
if (passId % FLAGS_saving_period == 0 && FLAGS_trainer_id == 0) { if (trainPassContext_.passId % FLAGS_saving_period == 0
paramUtil_->saveParametersOnePass(passId); && FLAGS_trainer_id == 0) {
paramUtil_->saveParametersOnePass(trainPassContext_.passId);
} }
++trainPassContext_.passId;
}
void Trainer::trainOnePass() {
startTrainPass();
size_t batchSize = config_->getOptConfig().batch_size();
while (true) {
DataBatch dataBatch;
int num = 0;
{
REGISTER_TIMER("getTrainBatch");
num = dataProvider_->getNextBatch(batchSize, &dataBatch);
}
if (num == 0) break;
CHECK_EQ(num, dataBatch.getSize());
trainOneDataBatch(dataBatch);
}
finishTrainPass();
} }
void Trainer::trainOnePassBatch(int passId) { void Trainer::trainOnePassBatch(int passId) {
...@@ -582,6 +611,13 @@ void Trainer::clearGradient() { ...@@ -582,6 +611,13 @@ void Trainer::clearGradient() {
int Trainer::getBatchSize() { return config_->getOptConfig().batch_size(); } int Trainer::getBatchSize() { return config_->getOptConfig().batch_size(); }
void Trainer::createTester() {
tester_.reset(new paddle::Tester(config_, createTesterConfig(),
trainerInternal_.getGradientMachine(),
trainerInternal_.getParameterUpdater(),
testDataProvider_));
}
void Trainer::test() { void Trainer::test() {
tester_->test(); tester_->test();
} }
......
...@@ -94,6 +94,11 @@ public: ...@@ -94,6 +94,11 @@ public:
*/ */
real checkGradient(); real checkGradient();
void startTrain();
void finishTrain();
void startTrainPass();
void finishTrainPass();
void trainOneDataBatch(DataBatch& dataBatch);
/** /**
* given a dataBatch and the current parameter value * given a dataBatch and the current parameter value
...@@ -144,11 +149,11 @@ public: ...@@ -144,11 +149,11 @@ public:
protected: protected:
/** /**
* Train one pass of data. passId starts from 0 * Train one pass of data.
* *
* SGD Method. * SGD Method.
*/ */
void trainOnePass(int passId); void trainOnePass();
/** /**
* Train one pass in one batch. * Train one pass in one batch.
...@@ -161,6 +166,8 @@ protected: ...@@ -161,6 +166,8 @@ protected:
*/ */
void clearGradient(); void clearGradient();
void createTester();
private: private:
std::unique_ptr<TesterConfig> createTesterConfig(); std::unique_ptr<TesterConfig> createTesterConfig();
...@@ -173,6 +180,17 @@ protected: ...@@ -173,6 +180,17 @@ protected:
MachineState trainState_; MachineState trainState_;
MachineState testState_; MachineState testState_;
struct TrainPassContext {
int64_t batchId;
real avgTestCost;
int64_t numAvgTests;
int passId;
int passInnerId;
};
std::vector<paddle::Argument> forwardOutput_;
TrainPassContext trainPassContext_;
std::unique_ptr<Evaluator> evaluator_; std::unique_ptr<Evaluator> evaluator_;
std::unique_ptr<Evaluator> currentEvaluator_; std::unique_ptr<Evaluator> currentEvaluator_;
std::unique_ptr<Evaluator> averageEvaluator_; std::unique_ptr<Evaluator> averageEvaluator_;
......
...@@ -55,6 +55,8 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config, ...@@ -55,6 +55,8 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config,
gradientMachine_ = gradientMachine; gradientMachine_ = gradientMachine;
if (!gradientMachine) { if (!gradientMachine) {
CHECK(config_->getConfig().has_model_config())
<< "Missing model_config in trainer_config";
gradientMachine_.reset(GradientMachine::create( gradientMachine_.reset(GradientMachine::create(
config_->getConfig().model_config(), intconfig_->mode, config_->getConfig().model_config(), intconfig_->mode,
parameterUpdater_->getParameterTypes())); parameterUpdater_->getParameterTypes()));
...@@ -62,7 +64,8 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config, ...@@ -62,7 +64,8 @@ void TrainerInternal::init(const std::shared_ptr<TrainerConfigHelper> &config,
} }
void TrainerInternal::trainOneBatch(int64_t batchId, void TrainerInternal::trainOneBatch(int64_t batchId,
const DataBatch& dataBatch) { const DataBatch& dataBatch,
std::vector<Argument>* outArgs) {
// true means updating parameter whenever gradient is ready during backward() // true means updating parameter whenever gradient is ready during backward()
bool doPipelineUpdate = bool doPipelineUpdate =
(intconfig_->mode != GradientMachine::kSgdSparseCpuTraining) && (intconfig_->mode != GradientMachine::kSgdSparseCpuTraining) &&
...@@ -84,7 +87,6 @@ void TrainerInternal::trainOneBatch(int64_t batchId, ...@@ -84,7 +87,6 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
} }
const std::vector<Argument>& inArgs = dataBatch.getStreams(); const std::vector<Argument>& inArgs = dataBatch.getStreams();
std::vector<Argument> outArgs;
PassType passType = parameterUpdater_->startBatch(actualBatchSize); PassType passType = parameterUpdater_->startBatch(actualBatchSize);
...@@ -114,7 +116,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId, ...@@ -114,7 +116,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
timer.start(); timer.start();
#endif #endif
REGISTER_TIMER("forwardBackward"); REGISTER_TIMER("forwardBackward");
forwardBackwardBatch(inArgs, outArgs, passType, updateCallback, forwardBackwardBatch(inArgs, *outArgs, passType, updateCallback,
doPipelineUpdate); doPipelineUpdate);
#ifndef PADDLE_DISABLE_TIMER #ifndef PADDLE_DISABLE_TIMER
timer.stop(); timer.stop();
...@@ -132,7 +134,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId, ...@@ -132,7 +134,7 @@ void TrainerInternal::trainOneBatch(int64_t batchId,
real cost = 0; real cost = 0;
{ {
REGISTER_TIMER("sumCost"); REGISTER_TIMER("sumCost");
cost = Argument::sumCosts(outArgs); cost = Argument::sumCosts(*outArgs);
} }
if (batchId % intconfig_->log_period == 0) { if (batchId % intconfig_->log_period == 0) {
......
...@@ -81,7 +81,9 @@ public: ...@@ -81,7 +81,9 @@ public:
* @param batchId current batch id * @param batchId current batch id
* @param dataBatch data for the batch * @param dataBatch data for the batch
*/ */
void trainOneBatch(int64_t batchId, const DataBatch& dataBatch); void trainOneBatch(int64_t batchId,
const DataBatch& dataBatch,
std::vector<Argument>* outArgs);
/** /**
* showParameterStats * showParameterStats
......
...@@ -130,7 +130,7 @@ message OptimizationConfig { ...@@ -130,7 +130,7 @@ message OptimizationConfig {
}; };
message TrainerConfig { message TrainerConfig {
required ModelConfig model_config = 1; optional ModelConfig model_config = 1;
optional DataConfig data_config = 2; optional DataConfig data_config = 2;
required OptimizationConfig opt_config = 3; required OptimizationConfig opt_config = 3;
optional DataConfig test_data_config = 4; optional DataConfig test_data_config = 4;
......
...@@ -12,3 +12,5 @@ ...@@ -12,3 +12,5 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from paddle.proto.TrainerConfig_pb2 import OptimizationConfig, TrainerConfig
from paddle.proto.ModelConfig_pb2 import ModelConfig
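With these re-exports, the generated protobuf classes can be imported directly from paddle.proto, which is what the monkey-patched createFromProto helpers in py_paddle/util.py assert against. A small sketch of the resulting round trip (the config filename is a placeholder and the initPaddle flags are illustrative):

from paddle.proto import TrainerConfig
from paddle.trainer.config_parser import parse_config
from py_paddle import swig_paddle

swig_paddle.initPaddle("--use_gpu=0")
# parse_config returns a TrainerConfig protobuf message.
conf = parse_config("trainer_config.lr.py", "")
assert isinstance(conf, TrainerConfig)
# Wrap the proto for the C++ API via the monkey-patched helper added above.
api_conf = swig_paddle.TrainerConfig.createFromProto(conf)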