提交 d7ee6ba1 编写于 作者: X xiexionghang

commit local runtimeEnvironment

此差异已折叠。
此差异已折叠。
......@@ -332,7 +332,7 @@ class ChannelReader {
}
if (cursor_ >= buffer_.size()) {
cursor_ = 0;
if (channel_->read(buffer_) == 0) {
if (channel_->Read(buffer_) == 0) {
failed_ = true;
return *this;
}
......
......@@ -149,7 +149,7 @@ std::vector<std::string> localfs_list(const std::string& path) {
std::shared_ptr<FILE> pipe;
int err_no = 0;
pipe = shell_popen(
string::format_string("find %s -type f -maxdepth 1", path.c_str()), "r",
string::format_string("find %s -maxdepth 1 -type f", path.c_str()), "r",
&err_no);
string::LineFileReader reader;
std::vector<std::string> list;
......@@ -452,5 +452,24 @@ void fs_mkdir(const std::string& path) {
LOG(FATAL) << "Not supported";
}
}
// Joins a directory and a relative path with exactly one '/' separator.
// An empty dir yields the path unchanged; a dir already ending in '/'
// does not receive a second separator.
std::string fs_path_join(const std::string& dir, const std::string &path) {
    if (dir.empty()) {
        return path;
    }
    std::string joined(dir);
    if (joined.back() != '/') {
        joined += '/';
    }
    joined += path;
    return joined;
}
// Splits a path at its last '/' into {directory, filename}.
// A path with no separator is treated as a file in the current
// directory: {".", path}. Note the inverse of fs_path_join only up to
// the root case ("/x" splits to {"", "x"}).
std::pair<std::string, std::string> fs_path_split(const std::string &path) {
    const size_t sep = path.rfind('/');
    if (sep == std::string::npos) {
        return std::make_pair(std::string("."), path);
    }
    return std::make_pair(path.substr(0, sep), path.substr(sep + 1));
}
} // end namespace framework
} // end namespace paddle
......@@ -97,5 +97,9 @@ extern std::string fs_tail(const std::string& path);
extern bool fs_exists(const std::string& path);
extern void fs_mkdir(const std::string& path);
extern std::string fs_path_join(const std::string& dir, const std::string &path);
extern std::pair<std::string, std::string> fs_path_split(const std::string &path);
} // namespace framework
} // namespace paddle
......@@ -136,6 +136,18 @@ std::string join_strings(const Container& strs, char delim) {
return str;
}
// Returns true when `main_str` ends with suffix `str`.
// An empty suffix always matches; a suffix longer than the string never does.
static inline bool end_with(const std::string& main_str, const std::string& str) {
    if (main_str.length() < str.length()) {
        return false;
    }
    const size_t offset = main_str.length() - str.length();
    return main_str.compare(offset, str.length(), str) == 0;
}
// Returns true when `main_str` starts with prefix `str`.
// An empty prefix always matches; a prefix longer than the string never does.
static inline bool begin_with(const std::string& main_str, const std::string& str) {
    return main_str.length() >= str.length() &&
           main_str.compare(0, str.length(), str) == 0;
}
// A helper class for reading lines from file. A line buffer is maintained. It
// doesn't need to know the maximum possible length of a line.
......
BasedOnStyle: Google
AccessModifierOffset: -4
AlignAfterOpenBracket: AlwaysBreak
AlignOperands: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BreakConstructorInitializers: AfterColon
ColumnLimit: 100
ConstructorInitializerIndentWidth: 8
ContinuationIndentWidth: 8
DerivePointerAlignment: true
FixNamespaceComments: true
IndentCaseLabels: false
IndentWidth: 4
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 500
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 400
PointerAlignment: Left
SortIncludes: false
#pragma once
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Loads epoch bookkeeping state from the configured "done file".
// Reads model_root_path (and optional "donefile" name) from `config`,
// then parses the tab-separated tail line of the done file into
// _done_status and caches the epoch / checkpoint fields from it.
// Always returns 0: a missing done file is only logged, not fatal.
int EpochAccessor::initialize(YAML::Node config,
    std::shared_ptr<TrainerContext> context_ptr) {
    _model_root_path = config["model_root_path"].as<std::string>() + "/";
    _done_file_path = _model_root_path;
    if (config["donefile"]) {
        _done_file_path.append(config["donefile"].as<std::string>());
    } else {
        // Default done-file name when the config omits "donefile".
        _done_file_path.append("epoch_donefile.txt");
    }
    if (!context_ptr->file_system->exists(_done_file_path)) {
        // Likely a first run; continue with whatever tail() returns.
        VLOG(0) << "missing done file, path:" << _done_file_path;
    }
    // The last line of the done file holds the current status record,
    // columns separated by tabs (see EpochStatusFiled for the layout).
    std::string done_text = context_ptr->file_system->tail(_done_file_path);
    _done_status = paddle::string::split_string(done_text, std::string("\t"));
    _current_epoch_id = get_status<uint64_t>(EpochStatusFiled::EpochIdField);
    _last_checkpoint_epoch_id = get_status<uint64_t>(EpochStatusFiled::CheckpointIdField);
    _last_checkpoint_path = get_status<std::string>(EpochStatusFiled::CheckpointPathField);
    return 0;
}
// Initializes the hourly accessor by delegating to the base implementation.
// Previously the base-class return value was discarded and 0 returned
// unconditionally; propagate it so a future base-class failure surfaces
// to callers (backward compatible: the base currently always returns 0).
int HourlyEpochAccessor::initialize(YAML::Node config,
    std::shared_ptr<TrainerContext> context_ptr) {
    return EpochAccessor::initialize(config, context_ptr);
}
// Advances the accessor's current epoch to the next hourly epoch id.
void HourlyEpochAccessor::next_epoch() {
    _current_epoch_id = next_epoch_id(_current_epoch_id);
}
// Renders an epoch id as text for logs / done-file entries.
// The original body contained an unreachable second return
// (`format_timestamp(epoch_id, "%Y%m%d delta-%H")`) after the first
// return statement; that dead code is removed here. Behavior is
// unchanged: callers receive the raw epoch id as a decimal string.
std::string HourlyEpochAccessor::text(uint64_t epoch_id) {
    return std::to_string(epoch_id);
}
// Reports whether the data for `epoch_id` is ready for training.
// The hourly accessor assumes upstream data is always available.
bool HourlyEpochAccessor::data_ready(uint64_t epoch_id) {
    return true;
}
int HourlyEpochAccessor::next_epoch_id(uint64_t epoch_id) {
uint64_t HourlyEpochAccessor::next_epoch_id(uint64_t epoch_id) {
if (epoch_id == 0) {
struct timeval now;
gettimeofday(&now, NULL);
......@@ -25,15 +50,19 @@ namespace feed {
}
return epoch_id + 3600;
}
// True when the epoch id (a unix timestamp aligned to the hour) falls
// in the last hour of its day (23:00).
bool HourlyEpochAccessor::is_last_epoch(uint64_t epoch_id) {
    const uint64_t hour_of_day = (epoch_id / 3600) % 24;
    return hour_of_day == 23;
}
}
// Seconds of data covered by one epoch: one hour.
uint64_t HourlyEpochAccessor::epoch_time_interval() {
    return 3600;
}
// Hourly epoch ids are themselves unix timestamps, so return as-is.
uint64_t HourlyEpochAccessor::epoch_timestamp(uint64_t epoch_id) {
    return epoch_id;
}
}
bool HourlyEpochAccessor::need_save_model(uint64_t epoch_id, ModelSaveWay save_way) {
if (epoch_id == 0) {
return false;
......@@ -47,6 +76,7 @@ namespace feed {
}
return false;
}
std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) {
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
......@@ -57,6 +87,7 @@ namespace feed {
}
return "";
}
REGISTER_CLASS(EpochAccessor, HourlyEpochAccessor);
} // namespace feed
......
#pragma once
#include <boost/lexical_cast.hpp>
#include "paddle/fluid/string/to_string.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
......@@ -6,20 +9,41 @@ namespace paddle {
namespace custom_trainer {
namespace feed {
// Column layout of one tab-separated record in the epoch done file,
// indexed into _done_status by EpochAccessor::get_status().
// NOTE(review): "Filed" is a historical typo for "Field"; kept as-is
// because the name is part of the accessor interface.
enum class EpochStatusFiled {
    DateField = 0,
    TimestampField = 1,
    CheckpointPathField = 2,
    EpochIdField = 3,
    CheckpointIdField = 4
};
class EpochAccessor : public Accessor {
public:
EpochAccessor() {}
virtual ~EpochAccessor() {}
virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr) = 0;
std::shared_ptr<TrainerContext> context_ptr);
virtual uint64_t current_epoch_id() {
return _current_epoch_id;
}
virtual void next_epoch() = 0;
virtual std::string text(uint64_t epoch_id) = 0;
virtual bool data_ready(uint64_t epoch_id) = 0;
virtual int next_epoch_id(uint64_t epoch_id) = 0;
virtual const std::string& checkpoint_path() {
return _last_checkpoint_path;
}
template <class T>
T get_status(EpochStatusFiled field) {
auto status = paddle::string::trim_spaces(_done_status[static_cast<int>(field)]);
return boost::lexical_cast<T>(status.c_str());
}
virtual void next_epoch() = 0;
virtual std::string model_root_path() {
return _model_root_path;
}
virtual std::string text(uint64_t epoch_id) = 0;
virtual uint64_t next_epoch_id(uint64_t epoch_id) = 0;
virtual bool is_last_epoch(uint64_t epoch_id) = 0;
//epoch间的数据时间间隔(秒)
virtual uint64_t epoch_time_interval() = 0;
......@@ -28,7 +52,13 @@ public:
virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0;
virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0;
protected:
uint64_t _current_epoch_id;
std::string _done_file_path;
std::string _model_root_path;
uint64_t _current_epoch_id = 0;
std::string _last_checkpoint_path;
uint64_t _last_checkpoint_epoch_id = 0;
std::vector<std::string> _done_status; //当前完成状态,统一存成string
};
REGISTER_REGISTERER(EpochAccessor);
......@@ -40,15 +70,12 @@ public:
std::shared_ptr<TrainerContext> context_ptr);
virtual void next_epoch();
virtual std::string text(uint64_t epoch_id);
virtual bool data_ready(uint64_t epoch_id);
virtual int next_epoch_id(uint64_t epoch_id);
virtual uint64_t next_epoch_id(uint64_t epoch_id);
virtual bool is_last_epoch(uint64_t epoch_id);
virtual uint64_t epoch_time_interval();
virtual uint64_t epoch_timestamp(uint64_t epoch_id);
virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way);
virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way);
private:
std::string _model_root_path;
};
} // namespace feed
......
#pragma once
#include <thread>
#include "paddle/fluid/framework/archive.h"
namespace paddle {
......
......@@ -106,7 +106,7 @@ BaseClassMap& global_factory_map_cpp();
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
base_class##Registerer::CreateInstanceByName(name)
}//namespace feed
}//namespace custom_trainer
......
#include <mpi.h>
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
......@@ -68,26 +69,67 @@ public:
virtual void barrier(EnvironmentRole role) {
MPI_Barrier(mpi_node_info(role).mpi_comm);
}
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
auto& node_info = mpi_node_info(role);
int len = (int)ar.length();
int len = (int)ar.Length();
MPI_Bcast(&len, 1, MPI_INT, root_id, node_info.mpi_comm);
ar.resize(len);
ar.set_cursor(ar.buffer());
MPI_Bcast(ar.buffer(), len, MPI_BYTE, root, node_info.mpi_comm);
ar.Resize(len);
ar.SetCursor(ar.Buffer());
MPI_Bcast(ar.Buffer(), len, MPI_BYTE, root_id, node_info.mpi_comm);
}
protected:
virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str);
virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) {
if (type == EnvironmentLogType::MASTER_LOG && !is_master_node(role)) {
return;
}
VLOG(static_cast<int>(level)) << log_str;
}
inline MpiNodeInfo& mpi_node_info(EnvironmentRole role) {
return _roles_node_info[static_cast<int>(role)];
}
private:
std::vector<MpiNodeInfo> _roles_node_info;
};
REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
//用于本地模式单机训练
// Runtime environment for single-process local training: a degenerate
// one-node "cluster" in which every collective operation is a no-op.
class LocalRuntimeEnvironment : public RuntimeEnvironment {
public:
    LocalRuntimeEnvironment() {}
    virtual ~LocalRuntimeEnvironment() {}
    // Nothing external to connect to; nothing to configure.
    virtual int initialize(YAML::Node config) {
        return 0;
    }
    // No peers to wire up.
    virtual int wireup() {
        return 0;
    }
    // The single local node is always rank 0.
    virtual uint32_t rank_id(EnvironmentRole role) {
        return 0;
    }
    // Exactly one node, regardless of role.
    virtual uint32_t node_num(EnvironmentRole role) {
        return 1;
    }
    virtual int set_role(EnvironmentRole role) {
        return 0;
    }
    // Barrier is a no-op with a single participant.
    virtual void barrier(EnvironmentRole role) {
        return;
    }
    // Broadcast is a no-op: the only node already holds the data.
    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
        return;
    }
protected:
    // Local mode logs everything directly (no master/worker distinction).
    virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
        EnvironmentLogLevel level, const std::string& log_str) {
        VLOG(static_cast<int>(level)) << log_str;
    }
};
REGISTER_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment);
} // namespace feed
} // namespace custom_trainer
......
......@@ -55,9 +55,9 @@ public:
//环境定制化log
template<class... ARGS>
void log(EnvironmentLogType type, EnvironmentLogLevel level,
const char* fmt, ARGS && ... args) {
print_log(type, level, paddle::string::format_string(fmt, args...));
void log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const char* fmt, ARGS && ... args) {
print_log(role, type, level, paddle::string::format_string(fmt, args...));
}
//多线程可调用接口 End
......@@ -69,14 +69,13 @@ public:
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
//接口只允许在主线程调用 End
protected:
virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) = 0;
virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) = 0;
};
REGISTER_REGISTERER(RuntimeEnvironment);
std::string format_timestamp(time_t time, const char* format);
std::string format_timestamp(time_t time, const std::string& format) {
inline std::string format_timestamp(time_t time, const std::string& format) {
return format_timestamp(time, format.c_str());
}
......
train_thread_num : 10
environment :
environment_class : MPIRuntimeEnvironment
environment_class : LocalRuntimeEnvironment
io :
file_systems :
afs :
class : HadoopFileSystem
buffer_size : 1024000
ugis :
'default': 'feed_video,D3a0z8'
'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'
local :
class : LocalFileSystem
buffer_size : 1024000
dataset :
data_list :
train_sample :
prefetch_num : 2
root_path : ./sample
data_spit_interval : 300
data_path_formater : '%Y%m%d/%H%M'
data_reader : LineDataReader
done_file : to.hadoop.done
filename_prefix : part
pipeline_cmd : cat
parser :
class : LineDataParser
epoch:
epoch_class : HourlyEpochAccessor
epoch_class : HourlyEpochAccessor
model_root_path : ./model/
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include <cstdio>
#include <atomic>
#include <glog/logging.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Parses text lines of the form "<id><space><payload>" into a DataItem:
// everything before the first space becomes data.id, the remainder
// becomes data.data. Lines without a space (or ending at one) fail.
class LineDataParser : public DataParser {
public:
    LineDataParser() {}
    virtual ~LineDataParser() {}
    // No configuration needed for plain line parsing.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        return 0;
    }
    // Length-delimited variant: scans at most `len` bytes for the
    // id/payload separator. Returns 0 on success, -1 if no space found.
    virtual int parse(const char* str, size_t len, DataItem& data) const {
        size_t pos = 0;
        while (pos < len && str[pos] != ' ') {
            ++pos;
        }
        if (pos >= len) {
            VLOG(2) << "fail to parse line: " << std::string(str, len) << ", strlen: " << len;
            return -1;
        }
        // NOTE(review): streaming `str` here assumes it is NUL-terminated
        // even though a length was supplied — confirm against callers.
        VLOG(5) << "getline: " << str << " , pos: " << pos << ", len: " << len;
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1, len - pos - 1);
        return 0;
    }
    // NUL-terminated variant: scans until '\0' or the separator space.
    // Returns 0 on success, -1 if the string ends before a space.
    virtual int parse(const char* str, DataItem& data) const {
        size_t pos = 0;
        while (str[pos] != '\0' && str[pos] != ' ') {
            ++pos;
        }
        if (str[pos] == '\0') {
            VLOG(2) << "fail to parse line: " << str << ", get '\\0' at pos: " << pos;
            return -1;
        }
        VLOG(5) << "getline: " << str << " , pos: " << pos;
        data.id.assign(str, pos);
        data.data.assign(str + pos + 1);
        return 0;
    }
    // No sample conversion for plain lines; intentionally a no-op.
    virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const {
        return 0;
    }
};
REGISTER_CLASS(DataParser, LineDataParser);
// Creates and initializes the configured DataParser and records the
// pipeline command used to filter file streams before reading.
// Returns 0 on success, -1 if the parser cannot be created/initialized.
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    _parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as<std::string>()));
    if (_parser == nullptr) {
        VLOG(2) << "fail to get parser: " << config["parser"]["class"].as<std::string>();
        return -1;
    }
    if (_parser->initialize(config["parser"], context) != 0) {
        VLOG(2) << "fail to initialize parser" << config["parser"]["class"].as<std::string>();
        return -1;
    }
    // Each data file is piped through this command (e.g. "cat", "zcat")
    // before lines are read.
    _pipeline_cmd = config["pipeline_cmd"].as<std::string>();
    return 0;
}
// Reads text data line-by-line through a FileSystem, optionally piping
// each file through _pipeline_cmd, and converts every line into a
// DataItem via the configured DataParser before pushing it to a channel.
class LineDataReader : public DataReader {
public:
    LineDataReader() {}
    virtual ~LineDataReader() {}
    // Reads done_file / filename_prefix settings and selects a file
    // system: an explicitly configured one, else the trainer context's,
    // else a freshly created LocalFileSystem.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
        if (DataReader::initialize(config, context) != 0) {
            return -1;
        }
        _done_file_name = config["done_file"].as<std::string>();
        _filename_prefix = config["filename_prefix"].as<std::string>("");
        if (config["file_system"] && config["file_system"]["class"]) {
            _file_system.reset(
                CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
            if (_file_system == nullptr ||
                _file_system->initialize(config["file_system"], context) != 0) {
                VLOG(2) << "fail to create class: "
                        << config["file_system"]["class"].as<std::string>();
                return -1;
            }
        } else if (context->file_system != nullptr) {
            _file_system = context->file_system;
        } else {
            // Fall back to a local file system with empty configuration.
            _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
                VLOG(2) << "fail to init file system";
                return -1;
            }
        }
        return 0;
    }
    // Checks whether the sample data is ready; readiness (the done file
    // existing under data_dir) signals that download may start.
    virtual bool is_data_ready(const std::string& data_dir) {
        auto done_file_path = _file_system->path_join(data_dir, _done_file_name);
        if (_file_system->exists(done_file_path)) {
            return true;
        }
        return false;
    }
    // Lists data files under data_dir, skipping the done file and any
    // file not carrying _filename_prefix.
    virtual std::vector<std::string> data_file_list(const std::string& data_dir) {
        std::vector<std::string> data_files;
        for (auto& filepath : _file_system->list(data_dir)) {
            auto filename = _file_system->path_split(filepath).second;
            if (filename != _done_file_name &&
                string::begin_with(filename, _filename_prefix)) {
                data_files.push_back(std::move(filepath));
            }
        }
        return data_files;
    }
    // Reads every data file under data_dir into the sample channel.
    virtual int read_all(const std::string& data_dir, framework::Channel<DataItem> data_channel) {
        auto file_list = data_file_list(data_dir);
        return read_all(file_list, data_channel);
    }
    // Reads the given files (iterated in parallel via OpenMP) into
    // data_channel. Returns 0 on success, -1 if any open/read/write failed.
    virtual int read_all(const std::vector<std::string>& file_list, ::paddle::framework::Channel<DataItem> data_channel) {
        // Custom deleter flushes buffered items before destroying the writer,
        // so early exits still deliver everything written so far.
        auto deleter = [](framework::ChannelWriter<DataItem> *writer) {
            if (writer) {
                writer->Flush();
                VLOG(3) << "writer auto flush";
            }
            delete writer;
        };
        std::unique_ptr<framework::ChannelWriter<DataItem>, decltype(deleter)> writer(new framework::ChannelWriter<DataItem>(data_channel.get()), deleter);
        DataItem data_item;
        int file_list_size = file_list.size();
        std::atomic<bool> is_failed(false);
        // NOTE(review): `writer` and `data_item` are shared by all OpenMP
        // threads below without synchronization — confirm ChannelWriter is
        // thread-safe (or make these per-thread) before running with more
        // than one OMP thread.
        #pragma omp parallel for
        for (int i = 0; i < file_list_size; ++i) {
            const auto& filepath = file_list[i];
            // is_failed is a best-effort early-out: remaining iterations
            // still execute but skip their work once a failure is seen.
            if (!is_failed) {
                std::shared_ptr<FILE> fin = _file_system->open_read(filepath, _pipeline_cmd);
                if (fin == nullptr) {
                    VLOG(2) << "fail to open file: " << filepath << ", with cmd: " << _pipeline_cmd;
                    is_failed = true;
                    continue;
                }
                char *buffer = nullptr;
                size_t buffer_size = 0;
                ssize_t line_len = 0;
                while ((line_len = getline(&buffer, &buffer_size, fin.get())) != -1) {
                    // Strip the trailing newline before parsing.
                    if (line_len > 0 && buffer[line_len - 1] == '\n') {
                        buffer[--line_len] = '\0';
                    }
                    if (line_len <= 0) {
                        continue;
                    }
                    if (_parser->parse(buffer, line_len, data_item) == 0) {
                        (*writer) << std::move(data_item);
                    }
                }
                if (buffer != nullptr) {
                    // getline() allocates with malloc; release it here.
                    free(buffer);
                    buffer = nullptr;
                    buffer_size = 0;
                }
                if (ferror(fin.get()) != 0) {
                    VLOG(2) << "fail to read file: " << filepath;
                    is_failed = true;
                    continue;
                }
            }
            if (_file_system->err_no() != 0) {
                _file_system->reset_err_no();
                is_failed = true;
                continue;
            }
        }
        // Final flush; the writer's bool conversion reports write health.
        writer->Flush();
        if (!(*writer)) {
            VLOG(2) << "fail when write to channel";
            is_failed = true;
        }
        data_channel->Close();
        return is_failed ? -1 : 0;
    }
    virtual const DataParser* get_parser() {
        return _parser.get();
    }
private:
    std::string _done_file_name; // done-file name, without data_dir
    std::string _filename_prefix; // only files with this prefix are read
    std::shared_ptr<FileSystem> _file_system;
};
REGISTER_CLASS(DataReader, LineDataReader);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -48,9 +48,10 @@ public:
virtual ~DataParser() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
virtual int parse(const std::string& str, DataItem& data) const {
return parse(str.c_str(), str.size(), data);
return parse(str.c_str(), data);
}
virtual int parse(const char* str, size_t len, DataItem& data) const = 0;
virtual int parse(const char* str, DataItem& data) const = 0;
virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
};
REGISTER_REGISTERER(DataParser);
......@@ -59,29 +60,24 @@ class DataReader {
public:
DataReader() {}
virtual ~DataReader() {}
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context);
//判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) = 0;
//读取dir下文件列表
virtual std::vector<std::string> data_file_list(const std::string& data_dir);
virtual std::vector<std::string> data_file_list(const std::string& data_dir) = 0;
//读取目录下数据到样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
//读取指定文件列表的数据到样本流中
virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem> data_channel) = 0;
virtual const DataParser* get_parser() {
return _parser.get();
}
private:
protected:
std::shared_ptr<DataParser> _parser;//数据格式转换
std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
};
REGISTER_REGISTERER(DataReader);
//TODO
//可读取HDFS/DISK上数据的Reader,数据按行分隔
//HDFS/DISK - FileLineReader
}//namespace feed
}//namespace custom_trainer
}//namespace paddle
......@@ -6,15 +6,14 @@ namespace feed {
int Dataset::initialize(
const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
if (!config["data_list"]) {
VLOG(0) << "miss data_list config in dataset, please check";
if (config["data_list"].Type() != YAML::NodeType::Map) {
VLOG(0) << "miss data_list config in dataset, or type error please check";
return -1;
}
int data_num = config["data_list"].size();
for (int i = 0; i < data_num; ++i) {
std::string name = config["data_list"][i]["name"].as<std::string>();
for (auto& data_config : config["data_list"]) {
std::string name = data_config.first.as<std::string>();
auto data_ptr = std::make_shared<DatasetContainer>();
if (data_ptr->initialize(config["data_list"][i], context) != 0) {
if (data_ptr->initialize(data_config.second, context) != 0) {
VLOG(0) << "dataset initialize failed, name:" << name;
return -1;
}
......@@ -23,12 +22,27 @@ int Dataset::initialize(
return 0;
}
// Triggers pre-detection of data for `epoch_id` on every registered
// data container.
inline void Dataset::pre_detect_data(uint64_t epoch_id) {
    for (auto& container_pair : _data_containers) {
        container_pair.second->pre_detect_data(epoch_id);
    }
}
// Triggers pre-detection of data for `epoch_id` on one named container.
// NOTE(review): operator[] default-constructs a null entry for an
// unknown data_name, which would then be dereferenced — presumably
// callers only pass registered names; confirm.
inline void Dataset::pre_detect_data(
    const std::string& data_name, uint64_t epoch_id) {
    _data_containers[data_name]->pre_detect_data(epoch_id);
    return;
}
// Aggregates the epoch's data status across all containers: the overall
// status is the minimum (least-ready) of the per-container statuses,
// starting from Ready.
inline DatasetStatus Dataset::epoch_data_status(uint64_t epoch_id) {
    int merged = static_cast<int>(DatasetStatus::Ready);
    for (auto& container_pair : _data_containers) {
        const int candidate =
            static_cast<int>(container_pair.second->epoch_data_status(epoch_id));
        if (candidate < merged) {
            merged = candidate;
        }
    }
    return static_cast<DatasetStatus>(merged);
}
inline DatasetStatus Dataset::epoch_data_status(
const std::string& data_name, uint64_t epoch_id) {
return _data_containers[data_name]->epoch_data_status(epoch_id);
......
......@@ -22,9 +22,11 @@ public:
const YAML::Node& config, std::shared_ptr<TrainerContext> context);
//触发可预取的数据判断
virtual void pre_detect_data(uint64_t epoch_id);
virtual void pre_detect_data(const std::string& data_name, uint64_t epoch_id);
//获取数据状态
virtual DatasetStatus epoch_data_status(uint64_t epoch_id);
virtual DatasetStatus epoch_data_status(const std::string& data_name, uint64_t epoch_id);
//返回各DataContainer内的原始数据(maybe 压缩格式)
......
......@@ -23,6 +23,9 @@ int DatasetContainer::initialize(
//预取n轮样本数据
_prefetch_num = config["prefetch_num"].as<int>();
_dataset_list.resize(_prefetch_num);
for (int i = 0; i < _prefetch_num; ++i) {
_dataset_list[i].reset(new DatasetInfo);
}
_data_root_paths = paddle::string::split_string(
config["root_path"].as<std::string>(), " ");
......@@ -48,21 +51,32 @@ void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
LOG(FATAL) << "timestamp:" << timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
return;
}
size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
uint64_t data_timestamp = timestamp % _data_split_interval == 0 ? timestamp : (timestamp / _data_split_interval + 1) * _data_split_interval;
std::vector<std::string> data_path_list;
for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
for (int j = 0; j < data_num && status == 0; ++j) {
std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
std::string data_dir = _data_root_paths[i] + "/" + path_suffix;
status = read_data_list(data_dir, data_path_list);
}
if (_downloader_thread == nullptr) {
_downloader_thread.reset(new std::thread([this, timestamp](){
async_download_data(timestamp);
}));
}
if (status == 0) {
auto dataset_info = dataset(timestamp);
dataset_info->timestamp = timestamp;
dataset_info->file_path_list = std::move(data_path_list);
dataset_info->status = DatasetStatus::Detected;
for (int detect_idx = 0 ; detect_idx < _prefetch_num; ++detect_idx) {
if (DatasetStatus::Empty != data_status(timestamp)) {
continue;
}
size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
uint64_t data_timestamp = timestamp % _data_split_interval == 0 ? timestamp : (timestamp / _data_split_interval + 1) * _data_split_interval;
std::vector<std::string> data_path_list;
for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
for (int j = 0; j < data_num && status == 0; ++j) {
std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
std::string data_dir = _data_root_paths[i] + "/" + path_suffix;
status = read_data_list(data_dir, data_path_list);
}
}
if (status == 0) {
auto dataset_info = dataset(timestamp);
dataset_info->timestamp = timestamp;
dataset_info->file_path_list = std::move(data_path_list);
dataset_info->status = DatasetStatus::Detected;
}
timestamp += epoch_accessor->epoch_time_interval();
}
return;
}
......@@ -134,7 +148,7 @@ void DatasetContainer::async_download_data(uint64_t start_timestamp) {
LOG(FATAL) << "timestamp:" << start_timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
return;
}
while (true) {
while (!_stop_download) {
auto dataset_info = dataset(start_timestamp);
while (data_status(start_timestamp) != DatasetStatus::Detected) {
sleep(30);
......
......@@ -30,6 +30,7 @@ enum class DatasetStatus {
Downloding = 2,
Ready = 3
};
struct DatasetInfo {
uint64_t timestamp = 0;
std::vector<std::string> file_path_list;
......@@ -40,10 +41,14 @@ struct DatasetInfo {
class DatasetContainer {
public:
DatasetContainer() {}
virtual ~DatasetContainer() {}
virtual ~DatasetContainer() {
if (_downloader_thread != nullptr) {
_stop_download = true;
_downloader_thread->join();
}
}
virtual int initialize(
const YAML::Node& config, std::shared_ptr<TrainerContext> context);
virtual void run();
//触发可预取的数据判断
virtual void pre_detect_data(uint64_t epoch_id);
//获取数据状态
......@@ -62,6 +67,7 @@ protected:
virtual std::shared_ptr<DatasetInfo> dataset(uint64_t timestamp);
int _prefetch_num = 0;
bool _stop_download = false;
int _data_split_interval = 60; //样本切分周期(秒)
YAML::Node _dataset_config;
std::string _data_path_formater;
......
......@@ -43,86 +43,84 @@ std::unique_ptr<paddle::framework::ProgramDesc> Load(
}
struct SimpleExecutor::Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
SimpleExecutor::SimpleExecutor() {
}
SimpleExecutor::~SimpleExecutor() {
}
int SimpleExecutor::initialize(YAML::Node exe_config,
class SimpleExecutor : public Executor {
public:
SimpleExecutor() {};
virtual ~SimpleExecutor() {};
virtual int initialize(YAML::Node exe_config,
std::shared_ptr<TrainerContext> context_ptr) {
paddle::framework::InitDevices(false);
if (exe_config["num_threads"]) {
paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>());
} else {
paddle::platform::SetNumThreads(1);
}
if (!exe_config["startup_program"] ||
!exe_config["main_program"]) {
VLOG(2) << "fail to load config";
return -1;
}
paddle::framework::InitDevices(false);
if (exe_config["num_threads"]) {
paddle::platform::SetNumThreads(exe_config["num_threads"].as<int>());
} else {
paddle::platform::SetNumThreads(1);
}
try {
_context.reset(new SimpleExecutor::Context(context_ptr->cpu_place));
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
if (!exe_config["startup_program"] ||
!exe_config["main_program"]) {
VLOG(2) << "fail to load config";
return -1;
}
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>());
if (_context->main_program == nullptr) {
VLOG(2) << "fail to load main_program: " << exe_config["main_program"].as<std::string>();
try {
_context.reset(new SimpleExecutor::Context(context_ptr->cpu_place));
auto startup_program = Load(&_context->executor, exe_config["startup_program"].as<std::string>());
if (startup_program == nullptr) {
VLOG(2) << "fail to load startup_program: " << exe_config["startup_program"].as<std::string>();
return -1;
}
_context->executor.Run(*startup_program, this->scope(), 0, false, true);
_context->main_program = Load(&_context->executor, exe_config["main_program"].as<std::string>());
if (_context->main_program == nullptr) {
VLOG(2) << "fail to load main_program: " << exe_config["main_program"].as<std::string>();
return -1;
}
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
_context.reset(nullptr);
return -1;
}
_context->prepare_context = _context->executor.Prepare(*_context->main_program, 0);
_context->executor.CreateVariables(*_context->main_program, this->scope(), 0);
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
_context.reset(nullptr);
return -1;
}
return 0;
}
int SimpleExecutor::run() {
if (_context == nullptr) {
VLOG(2) << "need initialize before run";
return -1;
return 0;
}
try {
_context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(),
false, /* don't create local scope each time*/
false /* don't create variable each time */);
// For some other vector like containers not cleaned after each batch.
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope());
_context->tensor_array_batch_cleaner.ResetNoTensorVars();
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
return -1;
virtual int run() {
if (_context == nullptr) {
VLOG(2) << "need initialize before run";
return -1;
}
try {
_context->executor.RunPreparedContext(_context->prepare_context.get(), this->scope(),
false, /* don't create local scope each time*/
false /* don't create variable each time */);
// For some other vector like containers not cleaned after each batch.
_context->tensor_array_batch_cleaner.CollectNoTensorVars(this->scope());
_context->tensor_array_batch_cleaner.ResetNoTensorVars();
} catch (::paddle::platform::EnforceNotMet& err) {
VLOG(2) << err.what();
return -1;
}
return 0;
}
return 0;
}
protected:
struct Context {
Context(const ::paddle::platform::Place& place) : place(place), executor(place) {
}
const ::paddle::platform::Place& place;
::paddle::framework::Executor executor;
::std::unique_ptr<::paddle::framework::ProgramDesc> main_program;
::std::unique_ptr<framework::ExecutorPrepareContext> prepare_context;
details::TensorArrayBatchCleaner tensor_array_batch_cleaner;
};
std::unique_ptr<Context> _context;
};
REGISTER_CLASS(Executor, SimpleExecutor);
} // namespace feed
......
......@@ -42,18 +42,6 @@ protected:
};
REGISTER_REGISTERER(Executor);
// Executor implementation backed by paddle::framework::Executor: prepares a
// program once during initialize() and replays it on every run().
class SimpleExecutor : public Executor {
public:
    SimpleExecutor();
    virtual ~SimpleExecutor();
    // Builds the execution Context from the YAML config; 0 on success.
    virtual int initialize(YAML::Node exe_config,
        std::shared_ptr<TrainerContext> context_ptr);
    // Runs one pass of the prepared program; 0 on success, -1 on error.
    virtual int run();
protected:
    struct Context;  // opaque execution state, defined in the .cc file
    std::unique_ptr<Context> _context;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
#include <unordered_map>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "glog/logging.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// FileSystem facade that routes each path to a concrete FileSystem chosen by
// the path's scheme prefix (everything up to and including the first ':',
// e.g. "afs:"); paths without a registered prefix fall back to the "default"
// entry, which is a LocalFileSystem unless configured otherwise.
class AutoFileSystem : public FileSystem {
public:
    // Builds the prefix -> FileSystem table from config["file_systems"] and
    // guarantees a usable "default" entry. Returns 0 on success, -1 on error.
    int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
        _file_system.clear();
        if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) {
            for (auto& prefix_fs : config["file_systems"]) {
                std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
                if (fs == nullptr) {
                    VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
                    return -1;
                }
                if (fs->initialize(prefix_fs.second, context) != 0) {
                    VLOG(2) << "fail to initialize class: " << prefix_fs.second["class"].as<std::string>("");
                    // BUGFIX: propagate the failure. The original returned 0
                    // here, so callers would treat a half-built file-system
                    // table as a successful initialization.
                    return -1;
                }
                _file_system.emplace(prefix_fs.first.as<std::string>(""), std::move(fs));
            }
        }
        // Ensure a fallback file system always exists for unprefixed paths.
        if (_file_system.find("default") == _file_system.end()) {
            std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
            if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) {
                return -1;
            }
            _file_system.emplace("default", std::move(fs));
        }
        return 0;
    }
    // All operations below delegate to the file system selected by prefix.
    std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter)
        override {
        return get_file_system(path)->open_read(path, converter);
    }
    std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter)
        override {
        return get_file_system(path)->open_write(path, converter);
    }
    int64_t file_size(const std::string& path) override {
        return get_file_system(path)->file_size(path);
    }
    void remove(const std::string& path) override {
        get_file_system(path)->remove(path);
    }
    std::vector<std::string> list(const std::string& path) override {
        return get_file_system(path)->list(path);
    }
    std::string tail(const std::string& path) override {
        return get_file_system(path)->tail(path);
    }
    bool exists(const std::string& path) override {
        return get_file_system(path)->exists(path);
    }
    void mkdir(const std::string& path) override {
        get_file_system(path)->mkdir(path);
    }
    // Picks the file system registered for the path's "scheme:" prefix,
    // falling back to "default" when no prefix matches.
    FileSystem* get_file_system(const std::string& path) {
        auto pos = path.find_first_of(":");
        if (pos != std::string::npos) {
            // Keep the ':' — the table keys are stored as "xxx:".
            auto substr = path.substr(0, pos + 1);
            auto fs_it = _file_system.find(substr);
            if (fs_it != _file_system.end()) {
                return fs_it->second.get();
            }
        }
        VLOG(5) << "path: " << path << ", select default file system";
        return _file_system["default"].get();
    }
    // Aggregated error state: reports non-zero if any child file system
    // recorded an error since the last reset.
    int err_no() const override {
        if (_err_no == 0) {
            for (const auto& file_system : _file_system) {
                if (file_system.second->err_no() != 0) {
                    // const_cast keeps the inherited const signature while
                    // caching the aggregated result in _err_no.
                    const_cast<int&>(_err_no) = -1;
                    break;
                }
            }
        }
        return FileSystem::err_no();
    }
    // Clears the aggregated flag and each child's flag.
    void reset_err_no() override {
        _err_no = 0;
        for (auto& file_system : _file_system) {
            file_system.second->reset_err_no();
        }
    }

private:
    std::unordered_map<std::string, std::unique_ptr<FileSystem>> _file_system;
};
REGISTER_CLASS(FileSystem, AutoFileSystem);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
namespace paddle {
namespace custom_trainer {
namespace feed {
// Joins `dir` and `path` with exactly one '/' separator; an empty `dir`
// yields `path` unchanged.
std::string FileSystem::path_join(const std::string& dir, const std::string& path) {
    if (dir.empty()) {
        return path;
    }
    const bool already_has_sep = (dir.back() == '/');
    return already_has_sep ? dir + path : dir + '/' + path;
}
// Splits `path` at its final '/' into {directory, basename}. A path with no
// '/' is treated as living in the current directory: {".", path}.
std::pair<std::string, std::string> FileSystem::path_split(const std::string& path) {
    const size_t sep = path.find_last_of('/');
    if (sep == std::string::npos) {
        return std::make_pair(std::string("."), path);
    }
    return std::make_pair(path.substr(0, sep), path.substr(sep + 1));
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <cstdio>
#include <vector>
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include <yaml-cpp/yaml.h>
namespace paddle {
namespace custom_trainer {
namespace feed {
// Abstract interface for file operations used by the trainer. Concrete
// implementations (local, HDFS, prefix-routing) are registered via
// REGISTER_CLASS and selected through configuration.
class FileSystem {
public:
    FileSystem() {}
    virtual ~FileSystem() {}
    // Configures the file system from YAML; returns 0 on success.
    virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
    // Opens `path` for reading, piping bytes through `converter` if non-empty.
    virtual std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) = 0;
    // Opens `path` for writing, piping bytes through `converter` if non-empty.
    virtual std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) = 0;
    // Size of the file in bytes.
    virtual int64_t file_size(const std::string& path) = 0;
    // Removes a file or directory tree (best effort in implementations).
    virtual void remove(const std::string& path) = 0;
    // Lists the regular files directly under `path`.
    virtual std::vector<std::string> list(const std::string& path) = 0;
    // Returns the last line of the file.
    virtual std::string tail(const std::string& path) = 0;
    // True if `path` exists.
    virtual bool exists(const std::string& path) = 0;
    // Creates the directory (with parents where the backend supports it).
    virtual void mkdir(const std::string& path) = 0;
    // Joins `dir` and `path` with a single '/' separator.
    virtual std::string path_join(const std::string& dir, const std::string& path);
    // Splits `path` into {directory, basename} at the last '/'.
    virtual std::pair<std::string, std::string> path_split(const std::string& path);
    // Sticky error flag: 0 means no error since the last reset.
    virtual int err_no() const {
        return _err_no;
    }
    // Convenience: a FileSystem converts to true while error-free.
    inline operator bool() {
        return err_no() == 0;
    }
    virtual void reset_err_no() {
        _err_no = 0;
    }
protected:
    int _err_no = 0;
};
REGISTER_REGISTERER(FileSystem);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
#include <unordered_map>
#include <tuple>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/string/piece.h"
#include "glog/logging.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class HadoopFileSystem : public FileSystem {
public:
int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
_buffer_size = config["buffer_size"].as<size_t>(0);
_hdfs_command = config["hdfs_command"].as<std::string>("hadoop fs");
_ugi.clear();
if (config["ugis"] && config["ugis"].Type() == YAML::NodeType::Map) {
for (const auto& prefix_ugi : config["ugis"]) {
_ugi.emplace(prefix_ugi.first.as<std::string>(), prefix_ugi.second.as<std::string>());
}
}
if (_ugi.find("default") == _ugi.end()) {
VLOG(2) << "fail to load default ugi";
return -1;
}
return 0;
}
std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter)
override {
std::string cmd;
if (string::end_with(path, ".gz")) {
cmd = string::format_string(
"%s -text \"%s\"", hdfs_command(path).c_str(), path.c_str());
} else {
cmd = string::format_string("%s -cat \"%s\"", hdfs_command(path).c_str(), path.c_str());
}
bool is_pipe = true;
shell_add_read_converter(cmd, is_pipe, converter);
return shell_open(cmd, is_pipe, "r", _buffer_size, &_err_no);
}
std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter)
override {
std::string cmd =
string::format_string("%s -put - \"%s\"", hdfs_command(path).c_str(), path.c_str());
bool is_pipe = true;
if (string::end_with(path, ".gz\"")) {
shell_add_write_converter(cmd, is_pipe, "gzip");
}
shell_add_write_converter(cmd, is_pipe, converter);
return shell_open(cmd, is_pipe, "w", _buffer_size, &_err_no);
}
int64_t file_size(const std::string& path) override {
_err_no = -1;
VLOG(2) << "not support";
return 0;
}
void remove(const std::string& path) override {
if (path == "") {
return;
}
shell_execute(string::format_string(
"%s -rmr %s &>/dev/null; true", _hdfs_command.c_str(), path.c_str()));
}
std::vector<std::string> list(const std::string& path) override {
if (path == "") {
return {};
}
auto paths = split_path(path);
int err_no = 0;
std::vector<std::string> list;
do {
err_no = 0;
std::shared_ptr<FILE> pipe;
pipe = shell_popen(
string::format_string(
"%s -ls %s | ( grep ^- ; [ $? != 2 ] )",
hdfs_command(path).c_str(),
path.c_str()),
"r",
&err_no);
string::LineFileReader reader;
list.clear();
while (reader.getline(&*pipe)) {
std::vector<std::string> line = string::split_string(reader.get());
if (line.size() != 8) {
continue;
}
list.push_back(get_prefix(paths) + line[7]);
}
} while (err_no == -1);
return list;
}
std::string tail(const std::string& path) override {
if (path == "") {
return "";
}
return shell_get_command_output(string::format_string(
"%s -text %s | tail -1 ", hdfs_command(path).c_str(), path.c_str()));
}
bool exists(const std::string& path) override {
std::string test = shell_get_command_output(string::format_string(
"%s -test -e %s ; echo $?", hdfs_command(path).c_str(), path.c_str()));
if (string::trim_spaces(test) == "0") {
return true;
}
return false;
}
void mkdir(const std::string& path) override {
if (path == "") {
return;
}
shell_execute(string::format_string(
"%s -mkdir %s; true", hdfs_command(path).c_str(), path.c_str()));
}
std::string hdfs_command(const std::string& path) {
auto paths = split_path(path);
auto it = _ugi.find(std::get<1>(paths).ToString());
if (it != _ugi.end()) {
return hdfs_command_with_ugi(it->second);
}
VLOG(5) << "path: " << path << ", select default ugi";
return hdfs_command_with_ugi(_ugi["default"]);
}
std::string hdfs_command_with_ugi(std::string ugi) {
return string::format_string(
"%s -Dhadoop.job.ugi=\"%s\"", _hdfs_command.c_str(), ugi.c_str());
}
private:
std::string get_prefix(const std::tuple<string::Piece, string::Piece, string::Piece>& paths) {
if (std::get<1>(paths).len() == 0) {
return std::get<0>(paths).ToString();
}
return std::get<0>(paths).ToString() + "//" + std::get<1>(paths).ToString();
}
// parse "xxx://abc.def:8756/user" as "xxx:", "abc.def:8756", "/user"
// parse "xxx:/user" as "xxx:", "", "/user"
// parse "xxx://abc.def:8756" as "xxx:", "abc.def:8756", ""
// parse "other" as "", "", "other"
std::tuple<string::Piece, string::Piece, string::Piece> split_path(string::Piece path) {
std::tuple<string::Piece, string::Piece, string::Piece> result{string::SubStr(path, 0, 0), string::SubStr(path, 0, 0), path};
auto fs_pos = string::Find(path, ':', 0) + 1;
if (path.len() > fs_pos) {
std::get<0>(result) = string::SubStr(path, 0, fs_pos);
path = string::SkipPrefix(path, fs_pos);
if (string::HasPrefix(path, "//")) {
path = string::SkipPrefix(path, 2);
auto end_pos = string::Find(path, '/', 0);
if (end_pos != string::Piece::npos) {
std::get<1>(result) = string::SubStr(path, 0, end_pos);
std::get<2>(result) = string::SkipPrefix(path, end_pos);
} else {
std::get<1>(result) = path;
}
} else {
std::get<2>(result) = path;
}
}
return result;
}
size_t _buffer_size = 0;
std::string _hdfs_command;
std::unordered_map<std::string, std::string> _ugi;
};
REGISTER_CLASS(FileSystem, HadoopFileSystem);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include <string>
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "glog/logging.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// FileSystem over the local POSIX file system. Reads and writes go through
// shell pipelines so converters and ".gz" (de)compression compose uniformly
// with the other FileSystem implementations.
class LocalFileSystem : public FileSystem {
public:
    int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) override {
        _buffer_size = config["buffer_size"].as<size_t>(0);
        return 0;
    }
    // Opens `path` for reading; ".gz" files are piped through zcat.
    std::shared_ptr<FILE> open_read(const std::string& path, const std::string& converter) override {
        std::string cmd = path;
        bool is_pipe = false;
        if (string::end_with(path, ".gz")) {
            shell_add_read_converter(cmd, is_pipe, "zcat");
        }
        shell_add_read_converter(cmd, is_pipe, converter);
        // BUGFIX: forward &_err_no (as HadoopFileSystem does). The original
        // relied on shell_open's null-pointer default, so pipe failures were
        // never reported through err_no() and shell_popen's error paths would
        // dereference a null err_no pointer.
        return shell_open(cmd, is_pipe, "r", _buffer_size, &_err_no);
    }
    // Opens `path` for writing, creating the parent directory first;
    // ".gz" targets are compressed with gzip.
    std::shared_ptr<FILE> open_write(const std::string& path, const std::string& converter) override {
        std::string cmd = path;
        shell_execute(string::format_string("mkdir -p $(dirname \"%s\")", path.c_str()));
        bool is_pipe = false;
        if (string::end_with(path, ".gz")) {
            shell_add_write_converter(cmd, is_pipe, "gzip");
        }
        shell_add_write_converter(cmd, is_pipe, converter);
        // BUGFIX: forward &_err_no here as well (see open_read).
        return shell_open(cmd, is_pipe, "w", _buffer_size, &_err_no);
    }
    // Size in bytes via stat(); aborts (LOG(FATAL)) when stat fails.
    int64_t file_size(const std::string& path) override {
        struct stat buf;
        if (0 != stat(path.c_str(), &buf)) {
            LOG(FATAL) << "file stat not zero";
            return -1;
        }
        return (int64_t)buf.st_size;
    }
    void remove(const std::string& path) override {
        if (path == "") {
            return;
        }
        shell_execute(string::format_string("rm -rf %s", path.c_str()));
    }
    // Regular files directly under `path` (non-recursive).
    std::vector<std::string> list(const std::string& path) override {
        if (path == "") {
            return {};
        }
        std::shared_ptr<FILE> pipe;
        pipe = shell_popen(
                string::format_string("find %s -maxdepth 1 -type f", path.c_str()), "r", &_err_no);
        string::LineFileReader reader;
        std::vector<std::string> list;
        while (reader.getline(&*pipe)) {
            list.push_back(reader.get());
        }
        return list;
    }
    std::string tail(const std::string& path) override {
        if (path == "") {
            return "";
        }
        return shell_get_command_output(string::format_string("tail -1 %s ", path.c_str()));
    }
    // True for an existing regular file or directory.
    bool exists(const std::string& path) override {
        std::string test_f = shell_get_command_output(
                string::format_string("[ -f %s ] ; echo $?", path.c_str()));
        if (string::trim_spaces(test_f) == "0") {
            return true;
        }
        std::string test_d = shell_get_command_output(
                string::format_string("[ -d %s ] ; echo $?", path.c_str()));
        if (string::trim_spaces(test_d) == "0") {
            return true;
        }
        return false;
    }
    void mkdir(const std::string& path) override {
        if (path == "") {
            return;
        }
        shell_execute(string::format_string("mkdir -p %s", path.c_str()));
    }

private:
    size_t _buffer_size = 0;
};
REGISTER_CLASS(FileSystem, LocalFileSystem);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Wraps `path` with `converter` so that written bytes pass through the
// converter first. An empty converter is a no-op. On the first call
// (`is_pipe` false) `path` is a plain file target and becomes a shell
// pipeline redirecting into it; further converters are prepended.
void shell_add_write_converter(std::string& path, bool& is_pipe, // NOLINT
                               const std::string& converter) {
    if (converter.empty()) {
        return;
    }
    if (is_pipe) {
        // Already a pipeline: written data enters the converter first.
        path = converter + " | " + path;
    } else {
        // Plain file target: run the converter and redirect into the file.
        path = "( " + converter + " ) > \"" + path + "\"";
        is_pipe = true;
    }
}
// Wraps `path` with `converter` so that read bytes pass through the
// converter. An empty converter is a no-op. On the first call (`is_pipe`
// false) `path` is a plain file and becomes a pipeline reading from it;
// further converters are appended to the pipeline's tail.
void shell_add_read_converter(std::string& path, bool& is_pipe, const std::string& converter) {
    if (converter.empty()) {
        return;
    }
    if (is_pipe) {
        // Already a pipeline: the converter consumes its output.
        path = path + " | " + converter;
    } else {
        // Plain file source: feed the file into the converter.
        path = "( " + converter + " ) < \"" + path + "\"";
        is_pipe = true;
    }
}
// Opens `path` either as a regular file (shell_fopen) or as a shell
// pipeline (shell_popen), optionally installing a full buffer of
// `buffer_size` bytes whose lifetime is tied to the returned handle via a
// chained deleter. `err_no` is forwarded to shell_popen for pipe error
// reporting, so it must outlive the returned handle for pipe opens.
std::shared_ptr<FILE> shell_open(
    const std::string& path,
    bool is_pipe,
    const std::string& mode,
    size_t buffer_size,
    int* err_no) {
  std::shared_ptr<FILE> fp = nullptr;
  if (!is_pipe) {
    fp = shell_fopen(path, mode);
  } else {
    fp = shell_popen(path, mode, err_no);
  }
  // BUGFIX: shell_popen can return null; the original passed the null
  // handle straight into setvbuf and crashed instead of reporting failure.
  if (fp == nullptr) {
    return fp;
  }
  if (buffer_size > 0) {
    char* buffer = new char[buffer_size];
    CHECK_EQ(0, setvbuf(&*fp, buffer, _IOFBF, buffer_size));
    // Chain a deleter that keeps the original handle alive until the
    // buffered wrapper is released, then frees the buffer.
    fp = {&*fp, [fp, buffer](FILE*) mutable {  // NOLINT
            CHECK(fp.unique());  // NOLINT
            fp = nullptr;
            delete[] buffer;
          }};
  }
  return fp;
}
// Opens a regular file with fopen and wraps it in a shared_ptr whose
// deleter closes it; open or close failure aborts via LOG(FATAL).
// Returns nullptr on unsupported platforms (Windows / macOS).
std::shared_ptr<FILE> shell_fopen(const std::string& path, const std::string& mode) {
#if defined _WIN32 || defined __APPLE__
  return nullptr;
#else
  if (shell_verbose()) {
    LOG(INFO) << "Opening file[" << path << "] with mode[" << mode << "]";
  }
  FILE* handle = fopen(path.c_str(), mode.c_str());
  if (handle == NULL) {
    LOG(FATAL) << "fopen fail, path[" << path << "], mode[" << mode << "]";
  }
  auto closer = [path](FILE* fp) {
    if (shell_verbose()) {
      LOG(INFO) << "Closing file[" << path << "]";
    }
    if (0 != fclose(fp)) {
      LOG(FATAL) << "fclose fail, path[" << path << "]";
    }
  };
  return {handle, closer};
#endif
}
// Close all open file descriptors
// The implementation is async signal safe
// Mostly copy from CPython code
static int close_open_fds_internal() {
#if defined _WIN32 || defined __APPLE__
  return 0;
#else
  // Raw dirent layout used with the SYS_getdents syscall; readdir() is not
  // async-signal-safe, so /proc/self/fd is parsed by hand.
  struct linux_dirent {
    long d_ino = 0;  // NOLINT
    off_t d_off;
    unsigned short d_reclen = 0;  // NOLINT
    char d_name[256];
  };

  int dir_fd = -1;
  if ((dir_fd = open("/proc/self/fd", O_RDONLY)) < 0) {
    LOG(FATAL) << "proc/self/fd open fail";
    return -1;
  }
  char buffer[sizeof(linux_dirent)];

  for (;;) {
    int bytes = 0;
    if ((bytes =
             syscall(SYS_getdents,
                     dir_fd,
                     reinterpret_cast<linux_dirent*>(buffer),
                     sizeof(buffer))) < 0) {
      LOG(FATAL) << "syscall fail";
      return -1;
    }

    if (bytes == 0) {
      break;  // end of directory
    }

    linux_dirent* entry = NULL;

    for (int offset = 0; offset < bytes; offset += entry->d_reclen) {
      entry = reinterpret_cast<linux_dirent*>(buffer + offset);
      int fd = 0;
      const char* s = entry->d_name;

      // Parse the numeric fd from the entry name by hand (strtol is not
      // async-signal-safe).
      while (*s >= '0' && *s <= '9') {
        fd = fd * 10 + (*s - '0');
        s++;
      }

      // Skip non-numeric entries ("."/".."), the directory fd itself, and
      // stdin/stdout/stderr (fds 0-2).
      if (s != entry->d_name && fd != dir_fd && fd >= 3) {
        close(fd);
      }
    }
  }

  close(dir_fd);
  return 0;
#endif
}
// Forks a child that executes `real_cmd` via `bash -c`, wiring `child_end`
// to the child's stdout (when the parent reads) or stdin (when it writes).
// Returns the child's pid in the parent, -1 on failure; never returns in
// the child (it execs or exits 127).
static int shell_popen_fork_internal(
    const char* real_cmd,
    bool do_read,
    int parent_end,
    int child_end) {
#if defined _WIN32 || defined __APPLE__
  return 0;
#else
  int child_pid = -1;
  // Too frequent calls to fork() makes openmpi very slow. Use vfork() instead.
  // But vfork() is very dangerous. Be careful.
  if ((child_pid = vfork()) < 0) {
    return -1;
  }

  // The following code is async signal safe (No memory allocation, no access to
  // global data, etc.)
  if (child_pid != 0) {
    return child_pid;
  }

  // Child: fd 1 (stdout) when the parent reads, fd 0 (stdin) when it writes.
  int child_std_end = do_read ? 1 : 0;
  close(parent_end);

  if (child_end != child_std_end) {
    if (dup2(child_end, child_std_end) != child_std_end) {
      return -1;
    }
    close(child_end);
  }

  // Don't leak inherited descriptors into the spawned shell.
  close_open_fds_internal();
  if (execl("/bin/bash", "bash", "-c", real_cmd, NULL) < 0) {
    return -1;
  }
  // Only reached if execl itself failed.
  exit(127);
#endif
}
// Runs `cmd` under `bash -c` with one end of a pipe attached to it and
// returns the parent's end as a FILE*. Mode "r" reads the child's stdout;
// mode "w" writes to its stdin. The returned handle's deleter closes the
// stream, reaps the child, and records failures in *err_no — so `err_no`
// must outlive the returned handle.
std::shared_ptr<FILE> shell_popen(const std::string& cmd, const std::string& mode, int* err_no) {
#if defined _WIN32 || defined __APPLE__
  return nullptr;
#else
  bool do_read = mode == "r";
  bool do_write = mode == "w";
  if (!(do_read || do_write)) {
    *err_no = -1;
    return NULL;
  }

  if (shell_verbose()) {
    LOG(INFO) << "Opening pipe[" << cmd << "] with mode[" << mode << "]";
  }

  // pipefail makes a failure anywhere in the pipeline fail the whole command.
  std::string real_cmd = "set -o pipefail; " + cmd;

  int pipe_fds[2];
  if (pipe(pipe_fds) != 0) {
    *err_no = -1;
    return NULL;
  }
  int parent_end = 0;
  int child_end = 0;

  if (do_read) {
    parent_end = pipe_fds[0];
    child_end = pipe_fds[1];
  } else if (do_write) {
    parent_end = pipe_fds[1];
    child_end = pipe_fds[0];
  }

  int child_pid = shell_popen_fork_internal(real_cmd.c_str(), do_read, parent_end, child_end);
  close(child_end);
  // Keep the parent's pipe fd out of any further exec'd children.
  fcntl(parent_end, F_SETFD, FD_CLOEXEC);
  FILE* fp;
  if ((fp = fdopen(parent_end, mode.c_str())) == NULL) {
    *err_no = -1;
    return NULL;
  }
  return {fp, [child_pid, cmd, err_no](FILE* fp) {
            if (shell_verbose()) {
              LOG(INFO) << "Closing pipe[" << cmd << "]";
            }

            if (fclose(fp) != 0) {
              *err_no = -1;
            }
            int wstatus = -1;
            // Reap the child. 128+SIGPIPE (shifted into waitpid's status
            // encoding) is tolerated because a reader may legitimately close
            // the pipe early; ECHILD means someone else already reaped it.
            waitpid(child_pid, &wstatus, 0);
            if (wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
                (wstatus == -1 && errno == ECHILD)) {
            } else {
              *err_no = -1;
              LOG(WARNING) << "status[" << wstatus << "], cmd[" << cmd << "]"
                           << ", err_no[" << *err_no << "]";
            }
            if (wstatus == -1 && errno == ECHILD) {
              LOG(WARNING) << "errno is ECHILD";
            }
          }};
#endif
}
// Forks a child running `real_cmd` via `sh -c` with its stdout connected to
// pipein and its stdin to pipeout (the parent keeps the opposite ends).
// Returns the child's pid in the parent, -1 on failure; never returns in
// the child (it execs or exits 127).
static int shell_p2open_fork_internal(const char* real_cmd, int pipein_fds[2], int pipeout_fds[2]) {
#if defined _WIN32 || defined __APPLE__
  return 0;
#else
  int child_pid = -1;
  if ((child_pid = fork()) < 0) {
    return -1;
  }

  if (child_pid != 0) {
    return child_pid;
  }

  // Child: keep only the write end of pipein (becomes stdout) and the read
  // end of pipeout (becomes stdin).
  close(pipein_fds[0]);
  close(pipeout_fds[1]);

  if (pipein_fds[1] != 1) {
    if (dup2(pipein_fds[1], 1) != 1) {
      return -1;
    }
    close(pipein_fds[1]);
  }

  if (pipeout_fds[0] != 0) {
    if (dup2(pipeout_fds[0], 0) != 0) {
      return -1;
    }
    close(pipeout_fds[0]);
  }

  // Don't leak inherited descriptors into the spawned shell.
  close_open_fds_internal();
  if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
    return -1;
  }
  // Only reached if execl itself failed.
  exit(127);
#endif
}
// Starts `cmd` with both standard streams piped and returns {read end of
// the child's stdout, write end of the child's stdin}. Both handles share
// ownership of the child via `child_life`; the child is reaped when the
// last handle is released. Unexpected failures abort via PCHECK.
std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
  return {};
#else
  if (shell_verbose()) {
    LOG(INFO) << "Opening bidirectional pipe[" << cmd << "]";
  }

  // pipefail makes a failure anywhere in the pipeline fail the whole command.
  std::string real_cmd = "set -o pipefail; " + cmd;

  int pipein_fds[2];
  int pipeout_fds[2];
  if (pipe(pipein_fds) != 0) {
    return {NULL, NULL};
  }
  if (pipe(pipeout_fds) != 0) {
    return {NULL, NULL};
  }

  int child_pid = shell_p2open_fork_internal(real_cmd.c_str(), pipein_fds, pipeout_fds);

  // Parent keeps pipein read end and pipeout write end only.
  close(pipein_fds[1]);
  close(pipeout_fds[0]);
  fcntl(pipein_fds[0], F_SETFD, FD_CLOEXEC);
  fcntl(pipeout_fds[1], F_SETFD, FD_CLOEXEC);

  // Shared token whose deleter reaps the child once both FILE handles are
  // gone. 128+SIGPIPE and ECHILD are tolerated (same policy as shell_popen).
  std::shared_ptr<int> child_life = {
      NULL, [child_pid, cmd](void*) {
        if (shell_verbose()) {
          LOG(INFO) << "Closing bidirectional pipe[" << cmd << "]";
        }

        int wstatus, ret;

        // Retry waitpid on EINTR.
        do {
          PCHECK((ret = waitpid(child_pid, &wstatus, 0)) >= 0 ||
                 (ret == -1 && errno == EINTR));
        } while (ret == -1 && errno == EINTR);

        PCHECK(wstatus == 0 || wstatus == (128 + SIGPIPE) * 256 ||
               (wstatus == -1 && errno == ECHILD))
            << "status[" << wstatus << "], cmd[" << cmd << "]";

        if (wstatus == -1 && errno == ECHILD) {
          LOG(WARNING) << "errno is ECHILD";
        }
      }};

  FILE* in_fp;
  PCHECK((in_fp = fdopen(pipein_fds[0], "r")) != NULL);
  FILE* out_fp;
  PCHECK((out_fp = fdopen(pipeout_fds[1], "w")) != NULL);
  return {{in_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }},
          {out_fp, [child_life](FILE* fp) { PCHECK(fclose(fp) == 0); }}};
#endif
}
// Runs `cmd` and returns its first output record (read up to a '\0'
// delimiter). Retries the whole command, sleeping 10s between attempts, for
// as long as the pipe layer reports a transient failure (err_no == -1);
// returns "" when the command produced no readable output.
// NOTE(review): a command that always fails keeps this looping forever —
// confirm callers rely on that retry-until-success contract.
std::string shell_get_command_output(const std::string& cmd) {
#if defined _WIN32 || defined __APPLE__
  return "";
#else
  int err_no = 0;
  do {
    if (err_no == -1) {
      sleep(10);  // back off before retrying a failed command
    }
    err_no = 0;
    std::shared_ptr<FILE> pipe = shell_popen(cmd, "r", &err_no);
    string::LineFileReader reader;
    if (reader.getdelim(&*pipe, 0)) {
      // Close the pipe first so the child's exit status is folded into err_no.
      pipe = nullptr;
      if (err_no == 0) {
        return reader.get();
      }
    }
  } while (err_no == -1);
  return "";
#endif
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fcntl.h>
#include <sys/stat.h>
#ifdef _WIN32
#include <windows.h>
#else
#include <sys/syscall.h>
#endif
#include <sys/types.h>
#ifndef _WIN32
#include <sys/wait.h>
#endif
#include <memory>
#include <string>
#include <utility>
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Process-wide verbosity flag for the shell_* helpers, stored behind a
// function-local static so the header can be included from many translation
// units without a separately defined global.
inline bool& shell_verbose_internal() {
    static bool verbose_flag = false;
    return verbose_flag;
}
// Returns whether verbose shell logging is currently enabled.
inline bool shell_verbose() {
    return shell_verbose_internal();
}
// Enables or disables verbose shell logging.
inline void shell_set_verbose(bool x) {
    shell_verbose_internal() = x;
}
extern std::shared_ptr<FILE> shell_fopen(const std::string& path, const std::string& mode);
extern std::shared_ptr<FILE> shell_popen(
const std::string& cmd,
const std::string& mode,
int* err_no);
extern std::pair<std::shared_ptr<FILE>, std::shared_ptr<FILE>> shell_p2open(const std::string& cmd);
// Runs `cmd` through a write-mode pipe (its output is not captured) and
// retries whenever the pipe layer reports a transient failure (status -1).
inline void shell_execute(const std::string& cmd) {
    int status = 0;
    do {
        status = 0;
        shell_popen(cmd, "w", &status);
    } while (status == -1);
}
extern std::string shell_get_command_output(const std::string& cmd);
extern void shell_add_read_converter(std::string& path, bool& is_pipe, const std::string& converter);
extern std::shared_ptr<FILE> shell_open(const std::string& path, bool is_pipe, const std::string& mode, size_t buffer_size, int* err_no = 0);
extern void shell_add_write_converter(std::string& path, bool& is_pipe, const std::string& converter);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
20190710 1562775817 afs:/user/feed/mlarch/feed_multiTarget_model/magnet_duration_model_new_label2/batch_model/20190710_18 21 18
20190710 1562779976 afs:/user/feed/mlarch/feed_multiTarget_model/magnet_duration_model_new_label2/batch_model/20190710_18 22 18
20190711 1562783841 afs:/user/feed/mlarch/feed_multiTarget_model/magnet_duration_model_new_label2/batch_model/20190711_0 1565625600 1565625600
......@@ -5,6 +5,8 @@
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
......@@ -20,26 +22,45 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
YAML::Node config = _context_ptr->trainer_config;
//environment
std::string env_class = config["environment"]["environment_class"].as<std::string>();
auto* environment = CREATE_CLASS(RuntimeEnvironment, env_class);
if (environment->initialize(config["environment"]) != 0) {
context_ptr->environment.reset(CREATE_CLASS(RuntimeEnvironment, env_class));
if (context_ptr->environment->initialize(config["environment"]) != 0) {
return -1;
}
//file_system
context_ptr->file_system.reset(CREATE_CLASS(FileSystem, "AutoFileSystem"));
if (context_ptr->file_system->initialize(config["io"], context_ptr) != 0) {
return -1;
}
context_ptr->environment.reset(environment);
//epoch
std::string epoch_class = config["epoch"]["epoch_class"].as<std::string>();
auto* epoch = CREATE_CLASS(EpochAccessor, epoch_class);
if (epoch->initialize(config["epoch"], context_ptr) != 0) {
context_ptr->epoch_accessor.reset(CREATE_CLASS(EpochAccessor, epoch_class));
if (context_ptr->epoch_accessor->initialize(config["epoch"], context_ptr) != 0) {
return -1;
}
//Dataset
context_ptr->dataset.reset(new Dataset());
if (context_ptr->dataset->initialize(config["dataset"], context_ptr) != 0) {
return -1;
}
context_ptr->epoch_accessor.reset(epoch);
VLOG(3) << "Env initialize success";
return 0;
}
int InitEnvProcess::run() {
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
VLOG(3) << "Trainer Resume From epoch:" << epoch_accessor->current_epoch_id();
auto next_epoch_id = epoch_accessor->next_epoch_id(epoch_accessor->current_epoch_id());
_context_ptr->dataset->pre_detect_data(next_epoch_id);
//step 1. psserver init
//step2. psserver load
VLOG(3) << "Psserver Start Success";
//context_ptr->pslib_client()->load_model();
VLOG(3) << "Psserver Load Model Success";
return 0;
}
......
......@@ -3,6 +3,7 @@
*Train样本
*/
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/learner_process.h"
......@@ -49,7 +50,7 @@ std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, Mod
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
auto* environment = _context_ptr->environment.get();
if (!environment->is_master_node()) {
if (!environment->is_master_node(EnvironmentRole::WORKER)) {
return 0;
}
int ret_size = 0;
......@@ -69,16 +70,17 @@ int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
}
int LearnerProcess::run() {
auto* dataset = _context_ptr->dataset.get();
auto* environment = _context_ptr->environment.get();
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
//判断是否先dump出base
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceBase);
environment->barrier_all();
environment->barrier(EnvironmentRole::WORKER);
while (true) {
epoch_accessor->next_epoch();
......@@ -87,16 +89,17 @@ int LearnerProcess::run() {
"train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
//Step1. 等待样本ready
environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Start %s, wait data ready", epoch_log_title.c_str());
while (!epoch_accessor->data_ready(epoch_id)) {
while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
sleep(30);
environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
dataset->pre_detect_data(epoch_id);
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data not ready, wait 30s", epoch_log_title.c_str());
}
environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data is ready, start traning", epoch_log_title.c_str());
environment->barrier_all();
environment->barrier(EnvironmentRole::WORKER);
//Step2. 运行训练网络
bool already_dump_inference_model = false;
......@@ -111,13 +114,13 @@ int LearnerProcess::run() {
for (int i = 0; i < _train_thread_num; ++i) {
train_threads[i]->join();
}
environment->barrier_all();
environment->barrier(EnvironmentRole::WORKER);
if (_threads_executor[0][i]->is_dump_all_model()) {
already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
environment->barrier_all();
environment->barrier(EnvironmentRole::WORKER);
}
//Step3. Dump Model For Delta&&Checkpoint
......@@ -126,7 +129,7 @@ int LearnerProcess::run() {
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
environment->barrier_all();
environment->barrier(EnvironmentRole::WORKER);
//Step4. Output Monitor && RunStatus
//TODO
......
......@@ -12,6 +12,8 @@ namespace custom_trainer {
namespace feed {
class Process;
class Dataset;
class FileSystem;
class EpochAccessor;
enum class ModelSaveWay {
......@@ -35,10 +37,13 @@ class TrainerContext {
public:
YAML::Node trainer_config;
paddle::platform::CPUPlace cpu_place;
std::vector<TableMeta> params_table_list;
std::shared_ptr<EpochAccessor> epoch_accessor;
std::shared_ptr<RuntimeEnvironment> environment;
std::vector<std::shared_ptr<Process>> process_list;
std::shared_ptr<Dataset> dataset; //训练样本
std::shared_ptr<FileSystem> file_system; //文件操作辅助类
std::vector<TableMeta> params_table_list; //参数表
std::shared_ptr<EpochAccessor> epoch_accessor; //训练轮次控制
std::shared_ptr<RuntimeEnvironment> environment; //运行环境
std::vector<std::shared_ptr<Process>> process_list; //训练流程
};
} // namespace feed
......
......@@ -8,5 +8,6 @@ int32_t main(int32_t argc, char** argv) {
::google::InitGoogleLogging(argv[0]);
::testing::InitGoogleTest(&argc, argv);
::google::ParseCommandLineFlags(&argc, &argv, true);
return RUN_ALL_TESTS();
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
namespace {
const char test_data_dir[] = "test_data";
}
class DataReaderTest : public testing::Test {
public:
static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
{
std::ofstream fout(fs->path_join(test_data_dir, "a.txt"));
fout << "abc 123456" << std::endl;
fout << "def 234567" << std::endl;
fout.close();
}
{
std::ofstream fout(fs->path_join(test_data_dir, "b.txt"));
fout << "ghi 345678" << std::endl;
fout << "jkl 456789" << std::endl;
fout.close();
}
}
static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
virtual void SetUp() {
thread_num = omp_get_max_threads();
omp_set_num_threads(1);
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext());
}
virtual void TearDown() {
omp_set_num_threads(thread_num);
fs = nullptr;
context_ptr = nullptr;
}
std::shared_ptr<TrainerContext> context_ptr;
std::unique_ptr<FileSystem> fs;
int thread_num = 1;
};
// Unit-tests the LineDataParser in isolation: a record is "<id> <data>";
// input without a space separator must be rejected, and the size-limited
// overload must honor the supplied byte count.
TEST_F(DataReaderTest, LineDataParser) {
    std::unique_ptr<DataParser> parser(CREATE_CLASS(DataParser, "LineDataParser"));
    ASSERT_NE(nullptr, parser);
    auto config = YAML::Load("");
    ASSERT_EQ(0, parser->initialize(config, context_ptr));

    DataItem item;
    // std::string overload: no separator -> error; "<id> <data>" -> ok.
    ASSERT_NE(0, parser->parse(std::string("1abcd123456"), item));
    ASSERT_EQ(0, parser->parse(std::string("2abc 123456"), item));
    ASSERT_STREQ("2abc", item.id.c_str());
    ASSERT_STREQ("123456", item.data.c_str());

    // const char* overload behaves identically.
    ASSERT_NE(0, parser->parse("3abcd123456", item));
    ASSERT_EQ(0, parser->parse("4abc 123456", item));
    ASSERT_STREQ("4abc", item.id.c_str());
    ASSERT_STREQ("123456", item.data.c_str());

    // Length-limited overload: 4 bytes ends before the separator (error);
    // 5 bytes covers "6abc " -> id parsed, data empty.
    ASSERT_NE(0, parser->parse("5abc 123456", 4, item));
    ASSERT_EQ(0, parser->parse("6abc 123456", 5, item));
    ASSERT_STREQ("6abc", item.id.c_str());
    ASSERT_STREQ("", item.data.c_str());
    // 8 bytes keeps only the first three data characters.
    ASSERT_EQ(0, parser->parse("7abc 123456", 8, item));
    ASSERT_STREQ("7abc", item.id.c_str());
    ASSERT_STREQ("123", item.data.c_str());
}
// Exercises the full LineDataReader flow on the local test directory:
// data_file_list() enumeration, done-file readiness gating, and reading
// every line of both sample files through a framework channel.
TEST_F(DataReaderTest, LineDataReader) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        " class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
    // Both sample files must be listed; sort for a stable comparison order.
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    ASSERT_EQ(2, data_file_list.size());
    std::sort(data_file_list.begin(), data_file_list.end());
    ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
    ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "b.txt"), data_file_list[1]);
    // The directory is only "ready" once the configured done_file exists.
    ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
    std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
    fout << "done";
    fout.close();
    ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
    auto channel = framework::MakeChannel<DataItem>(128);
    ASSERT_NE(nullptr, channel);
    ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
    // Single OpenMP thread (see fixture SetUp), so the items arrive in
    // file order: a.txt's two records, then b.txt's two records.
    framework::ChannelReader<DataItem> reader(channel.get());
    DataItem data_item;
    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("abc", data_item.id.c_str());
    ASSERT_STREQ("123456", data_item.data.c_str());
    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("def", data_item.id.c_str());
    ASSERT_STREQ("234567", data_item.data.c_str());
    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("ghi", data_item.id.c_str());
    ASSERT_STREQ("345678", data_item.data.c_str());
    reader >> data_item;
    ASSERT_TRUE(reader);
    ASSERT_STREQ("jkl", data_item.id.c_str());
    ASSERT_STREQ("456789", data_item.data.c_str());
    // Channel exhausted: the next read must put the reader in a failed state.
    reader >> data_item;
    ASSERT_FALSE(reader);
}
// Same flow as the LineDataReader test, but with "filename_prefix: a",
// so only a.txt is listed and read.
TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        " class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n"
        "filename_prefix: a");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));

    // Only files starting with "a" survive the prefix filter.
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    ASSERT_EQ(1, data_file_list.size());
    ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);

    auto channel = framework::MakeChannel<DataItem>(128);
    ASSERT_NE(nullptr, channel);
    ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
    framework::ChannelReader<DataItem> reader(channel.get());

    // a.txt holds exactly these two records, in this order.
    const char* expected[][2] = {{"abc", "123456"}, {"def", "234567"}};
    DataItem item;
    for (const auto& record : expected) {
        reader >> item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ(record[0], item.id.c_str());
        ASSERT_STREQ(record[1], item.data.c_str());
    }
    // No further records beyond a.txt's contents.
    reader >> item;
    ASSERT_FALSE(reader);
}
// Configures the reader with an AutoFileSystem that routes afs:/hdfs: URIs
// to a HadoopFileSystem, then reads first from the local directory and then
// from a remote AFS path.
// NOTE(review): this YAML embeds real-looking ugi credentials
// ('feed_video,D3a0z8') in source control — they should be injected from the
// environment or a secret store instead.
// NOTE(review): the second scope depends on a live external AFS cluster and
// network access, so this test is not hermetic and will fail offline.
TEST_F(DataReaderTest, LineDataReader_FileSystem) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        " class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n"
        "filename_prefix: a\n"
        "file_system:\n"
        " class: AutoFileSystem\n"
        " file_systems:\n"
        " 'afs:': &HDFS \n"
        " class: HadoopFileSystem\n"
        " hdfs_command: 'hadoop fs'\n"
        " ugis:\n"
        " 'default': 'feed_video,D3a0z8'\n"
        " 'xingtian.afs.baidu.com:9902': 'feed_video,D3a0z8'\n"
        " \n"
        " 'hdfs:': *HDFS\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
    {
        // Local path: same expectations as the filename_prefix test.
        auto data_file_list = data_reader->data_file_list(test_data_dir);
        ASSERT_EQ(1, data_file_list.size());
        ASSERT_EQ(string::format_string("%s/%s", test_data_dir, "a.txt"), data_file_list[0]);
        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
        framework::ChannelReader<DataItem> reader(channel.get());
        DataItem data_item;
        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("abc", data_item.id.c_str());
        ASSERT_STREQ("123456", data_item.data.c_str());
        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("def", data_item.id.c_str());
        ASSERT_STREQ("234567", data_item.data.c_str());
        reader >> data_item;
        ASSERT_FALSE(reader);
    }
    {
        // Remote AFS path: routed to HadoopFileSystem by the afs: prefix.
        // The fixture pre-populated this directory out of band — TODO confirm.
        char test_hadoop_dir[] = "afs://xingtian.afs.baidu.com:9902/user/feed_video/user/rensilin/paddle_trainer_test_dir";
        ASSERT_TRUE(data_reader->is_data_ready(test_hadoop_dir));
        auto data_file_list = data_reader->data_file_list(test_hadoop_dir);
        ASSERT_EQ(1, data_file_list.size());
        ASSERT_EQ(string::format_string("%s/%s", test_hadoop_dir, "a.txt"), data_file_list[0]);
        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        ASSERT_EQ(0, data_reader->read_all(test_hadoop_dir, channel));
        framework::ChannelReader<DataItem> reader(channel.get());
        DataItem data_item;
        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("hello", data_item.id.c_str());
        ASSERT_STREQ("world", data_item.data.c_str());
        reader >> data_item;
        ASSERT_TRUE(reader);
        ASSERT_STREQ("hello", data_item.id.c_str());
        ASSERT_STREQ("hadoop", data_item.data.c_str());
        reader >> data_item;
        ASSERT_FALSE(reader);
    }
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <algorithm>
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
namespace {
const char test_data_dir[] = "test_data";
}
class DataReaderOmpTest : public testing::Test {
public:
static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
std_items.clear();
sorted_std_items.clear();
for (char c = 'a'; c <= 'z'; ++c) {
DataItem item;
item.id = c;
item.data = std::to_string(c - 'a');
std::ofstream fout(fs->path_join(test_data_dir, string::format_string("%c.txt", c)));
fout << item.id << " " << item.data << std::endl;
fout.close();
sorted_std_items.push_back(std::move(item));
}
for (const auto& filename: fs->list(test_data_dir)) {
std::ifstream fin(filename);
DataItem item;
fin >> item.id >> item.data;
fin.close();
std_items.push_back(std::move(item));
}
}
static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
virtual void SetUp() {
thread_num = omp_get_max_threads();
omp_set_num_threads(1);
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext());
}
virtual void TearDown() {
omp_set_num_threads(thread_num);
fs = nullptr;
context_ptr = nullptr;
}
static bool is_same(const std::vector<DataItem>& a, const std::vector<DataItem>& b) {
int a_size = a.size();
if (a_size != b.size()) {
return false;
}
for (int i = 0; i < a_size; ++i) {
if (a[i].id != b[i].id || a[i].data != b[i].data) {
return false;
}
}
return true;
}
static bool is_same_with_std_items(const std::vector<DataItem>& items) {
return is_same(items, std_items);
}
static bool is_same_with_sorted_std_items(const std::vector<DataItem>& items) {
return is_same(items, sorted_std_items);
}
static std::vector<DataItem> std_items;
static std::vector<DataItem> sorted_std_items;
std::shared_ptr<TrainerContext> context_ptr;
std::unique_ptr<FileSystem> fs;
int thread_num = 1;
const int n_run = 5;
};
// Out-of-class definitions for the fixture's static expectation vectors.
std::vector<DataItem> DataReaderOmpTest::std_items;
std::vector<DataItem> DataReaderOmpTest::sorted_std_items;
// With OpenMP pinned to one thread (fixture SetUp), read_all() must yield
// items in exactly the file-listing order (std_items) on every run.
TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        " class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
    // The listing must contain one "<id>.txt" per reference item, in order.
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    const int std_items_size = std_items.size();
    ASSERT_EQ(std_items_size, data_file_list.size());
    for (int i = 0; i < std_items_size; ++i) {
        ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]);
    }
    int same_count = 0;
    for (int i = 0; i < n_run; ++i) {
        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
        std::vector<DataItem> items;
        channel->ReadAll(items);
        if (is_same_with_std_items(items)) {
            ++same_count;
        }
    }
    // Identical on all n_run runs: single-threaded reads are deterministic.
    ASSERT_EQ(n_run, same_count);
}
// Runs read_all() with 4 OpenMP threads n_run times: at least one run must
// differ from the sequential order (evidence of parallelism), yet after
// sorting by id every run must contain exactly the same item set.
// NOTE(review): "Muilt" is a typo for "Multi"; renaming would change the
// test's identity, so it is only flagged here.
TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
    std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
    ASSERT_NE(nullptr, data_reader);
    auto config = YAML::Load(
        "parser:\n"
        " class: LineDataParser\n"
        "pipeline_cmd: cat\n"
        "done_file: done_file\n");
    ASSERT_EQ(0, data_reader->initialize(config, context_ptr));
    auto data_file_list = data_reader->data_file_list(test_data_dir);
    const int std_items_size = std_items.size();
    ASSERT_EQ(std_items_size, data_file_list.size());
    for (int i = 0; i < std_items_size; ++i) {
        ASSERT_EQ(string::format_string("%s/%s.txt", test_data_dir, std_items[i].id.c_str()), data_file_list[i]);
    }
    // Readiness gating works the same as in the single-thread flow.
    ASSERT_FALSE(data_reader->is_data_ready(test_data_dir));
    std::ofstream fout(fs->path_join(test_data_dir, "done_file"));
    fout << "done";
    fout.close();
    ASSERT_TRUE(data_reader->is_data_ready(test_data_dir));
    int same_count = 0;
    int sort_same_count = 0;
    for (int i = 0; i < n_run; ++i) {
        auto channel = framework::MakeChannel<DataItem>(128);
        ASSERT_NE(nullptr, channel);
        // Raise the thread count before each parallel read.
        omp_set_num_threads(4);
        ASSERT_EQ(0, data_reader->read_all(test_data_dir, channel));
        std::vector<DataItem> items;
        channel->ReadAll(items);
        if (is_same_with_std_items(items)) {
            ++same_count;
        }
        std::sort(items.begin(), items.end(), [] (const DataItem& a, const DataItem& b) {
            return a.id < b.id;
        });
        if (is_same_with_sorted_std_items(items)) {
            ++sort_same_count;
        }
    }
    // Some of the n_run runs differ from sequential order (proves multi-threading).
    // NOTE(review): this is probabilistic — in principle all runs could come
    // out ordered and this assertion would flake.
    ASSERT_EQ(4, omp_get_max_threads());
    ASSERT_GT(n_run, same_count);
    // But after sorting, every run holds the same item set.
    ASSERT_EQ(n_run, sort_same_count);
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -13,66 +13,118 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <iostream>
#include <fstream>
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/train/custom_trainer/feed/io/file_system.h"
#include "paddle/fluid/train/custom_trainer/feed/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
TEST(testSimpleExecutor, initialize) {
SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>();
YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, execute.initialize(config, context_ptr));
config = YAML::Load("{startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr));
config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr));
namespace {
const char test_data_dir[] = "test_data";
const char main_program_path[] = "test_data/main_program";
const char startup_program_path[] = "test_data/startup_program";
}
// Returns a pseudo-random float uniformly distributed in [min, max].
// NOTE(review): relies on rand() without seeding here, so the sequence is
// deterministic across runs unless srand() is called elsewhere — acceptable
// for generating test data.
float uniform(float min, float max) {
    // static_cast instead of a C-style cast; divide by RAND_MAX to map
    // rand()'s [0, RAND_MAX] range onto [0, 1].
    const float unit = static_cast<float>(rand()) / static_cast<float>(RAND_MAX);
    return min + unit * (max - min);
}
class SimpleExecutorTest : public testing::Test
{
public:
static void SetUpTestCase()
{
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
{
std::unique_ptr<paddle::framework::ProgramDesc> startup_program(
new paddle::framework::ProgramDesc());
std::ofstream fout(startup_program_path, std::ios::out | std::ios::binary);
ASSERT_TRUE(fout);
fout << startup_program->Proto()->SerializeAsString();
fout.close();
}
{
std::unique_ptr<paddle::framework::ProgramDesc> main_program(
new paddle::framework::ProgramDesc());
auto load_block = main_program->MutableBlock(0);
framework::OpDesc* op = load_block->AppendOp();
op->SetType("mean");
op->SetInput("X", {"x"});
op->SetOutput("Out", {"mean"});
op->CheckAttrs();
std::ofstream fout(main_program_path, std::ios::out | std::ios::binary);
ASSERT_TRUE(fout);
fout << main_program->Proto()->SerializeAsString();
fout.close();
}
}
void next_batch(int batch_size, const paddle::platform::Place& place, paddle::framework::LoDTensor* x_tensor, paddle::framework::LoDTensor* y_tensor) {
x_tensor->Resize({batch_size, 2});
auto x_data = x_tensor->mutable_data<float>(place);
static void TearDownTestCase()
{
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
y_tensor->Resize({batch_size, 1});
auto y_data = y_tensor->mutable_data<float>(place);
virtual void SetUp()
{
context_ptr.reset(new TrainerContext());
}
for (int i = 0; i < batch_size; ++i) {
x_data[i * 2] = uniform(-2, 2);
x_data[i * 2 + 1] = uniform(-2, 2);
float dis = x_data[i * 2] * x_data[i * 2] + x_data[i * 2 + 1] * x_data[i * 2 + 1];
y_data[i] = dis < 1.0 ? 1.0 : 0.0;
virtual void TearDown()
{
context_ptr = nullptr;
}
std::shared_ptr<TrainerContext> context_ptr;
};
// SimpleExecutor must reject a non-map config and accept configs that
// provide startup_program/main_program paths (thread_num is optional).
TEST_F(SimpleExecutorTest, initialize) {
    std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
    ASSERT_NE(nullptr, executor);
    // A YAML sequence is not a valid executor configuration.
    YAML::Node config = YAML::Load("[1, 2, 3]");
    ASSERT_NE(0, executor->initialize(config, context_ptr));
    config = YAML::Load(string::format_string("{startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
    ASSERT_EQ(0, executor->initialize(config, context_ptr));
    // Re-initialization with an extended config on the same instance is allowed.
    config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
    ASSERT_EQ(0, executor->initialize(config, context_ptr));
}
TEST(testSimpleExecutor, run) {
SimpleExecutor execute;
auto context_ptr = std::make_shared<TrainerContext>();
auto config = YAML::Load("{thread_num: 2, startup_program: ./data/startup_program, main_program: ./data/main_program}");
ASSERT_EQ(0, execute.initialize(config, context_ptr));
TEST_F(SimpleExecutorTest, run) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor);
auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
ASSERT_EQ(0, executor->initialize(config, context_ptr));
auto x_var = execute.mutable_var<::paddle::framework::LoDTensor>("x");
auto y_var = execute.mutable_var<::paddle::framework::LoDTensor>("y");
auto x_var = executor->mutable_var<::paddle::framework::LoDTensor>("x");
executor->mutable_var<::paddle::framework::LoDTensor>("mean");
ASSERT_NE(nullptr, x_var);
ASSERT_NE(nullptr, y_var);
next_batch(1024, context_ptr->cpu_place, x_var, y_var);
int x_len = 10;
x_var->Resize({1, x_len});
auto x_data = x_var->mutable_data<float>(context_ptr->cpu_place);
std::cout << "x: ";
for (int i = 0; i < x_len; ++i) {
x_data[i] = i;
std::cout << i << " ";
}
std::cout << std::endl;
ASSERT_EQ(0, execute.run());
ASSERT_EQ(0, executor->run());
auto loss_var = execute.var<::paddle::framework::LoDTensor>("loss");
auto loss = loss_var.data<float>()[0];
std::cout << "loss: " << loss << std::endl;
auto mean_var = executor->var<::paddle::framework::LoDTensor>("mean");
auto mean = mean_var.data<float>()[0];
std::cout << "mean: " << mean << std::endl;
ASSERT_NEAR(4.5, mean, 1e-9);
}
} // namespace feed
......
......@@ -3,3 +3,9 @@ mkdir -p so
cp baidu_third-party_mklml/so/* so
rm -rf baidu_third-party_mklml
cp baidu_third-party_openmpi/so/* so
rm -rf baidu_third-party_openmpi
rm lib/libfake_paddle_proto.a
rmdir lib 2>/dev/null || :
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册