From c555948caad61e028061bc529a7a7db290d8f525 Mon Sep 17 00:00:00 2001 From: wangguibao Date: Wed, 24 Oct 2018 23:24:52 +0800 Subject: [PATCH] AsyncExecutor: C++ side --- paddle/fluid/framework/CMakeLists.txt | 5 + paddle/fluid/framework/async_executor.cc | 570 +++++++++++++++++++++ paddle/fluid/framework/async_executor.h | 175 +++++++ paddle/fluid/framework/data_feed.cc | 162 ++++++ paddle/fluid/framework/data_feed.h | 333 ++++++++++++ paddle/fluid/framework/datafeed_creator.cc | 26 + paddle/fluid/framework/datafeed_creator.h | 22 + paddle/fluid/pybind/CMakeLists.txt | 2 +- proto/FeedDataParameter.proto | 48 ++ 9 files changed, 1342 insertions(+), 1 deletion(-) create mode 100644 paddle/fluid/framework/async_executor.cc create mode 100644 paddle/fluid/framework/async_executor.h create mode 100644 paddle/fluid/framework/data_feed.cc create mode 100644 paddle/fluid/framework/data_feed.h create mode 100644 paddle/fluid/framework/datafeed_creator.cc create mode 100644 paddle/fluid/framework/datafeed_creator.h create mode 100644 proto/FeedDataParameter.proto diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 84429114060..ba7dc258a2d 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -174,6 +174,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS fast_threaded_ssa_graph_executor) endif() # NOT WIN32 +cc_library(async_executor + SRCS async_executor.cc data_feed.cc datafeed_creator.cc + DEPS op_registry device_context scope framework_proto glog + lod_rank_table feed_fetch_method graph_to_program_pass) + cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc new file mode 100644 index 00000000000..943cc20e80b --- /dev/null +++ b/paddle/fluid/framework/async_executor.cc @@ -0,0 +1,570 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/async_executor.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +#include "gflags/gflags.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/pybind/pybind.h" + +namespace paddle { +namespace framework { +std::mutex ExecutorThreadWorker::_s_locker_for_pick_file; +unsigned int ExecutorThreadWorker::_s_current_file_idx = 0; +size_t ExecutorThreadWorker::_s_current_finished_file_cnt = 0; +unsigned int ExecutorThreadWorker::_s_current_epoch = 0; +int ExecutorThreadWorker::_s_current_save_epoch = 0; +bool ExecutorThreadWorker::_s_is_first_worker = false; +std::vector ExecutorThreadWorker::_s_thread_filelist; + +void CreateTensor(Variable* var, proto::VarType::Type var_type) { + if (var_type == proto::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == proto::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else if (var_type == proto::VarType::FEED_MINIBATCH) { + var->GetMutable(); + } else if (var_type == proto::VarType::FETCH_LIST) { + var->GetMutable(); + } else if (var_type == proto::VarType::STEP_SCOPES) { + var->GetMutable>(); + } else if (var_type == proto::VarType::LOD_RANK_TABLE) { + var->GetMutable(); + } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { + var->GetMutable(); + } else if (var_type == proto::VarType::PLACE_LIST) { + var->GetMutable(); + } else if (var_type == proto::VarType::READER) { + var->GetMutable(); + } else if (var_type == proto::VarType::RAW) { + // GetMutable will be called in operator + } else { + PADDLE_THROW( + "Variable type %d is not in " + "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " + "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]", + var_type); + } +} + +static void read_binary_file(const std::string& filename, + std::string* content) { + std::string &contents = *content; + std::ifstream fin(filename, std::ios::in | std::ios::binary); + if (!fin.good()) { + LOG(ERROR) << "Cannot open file " << filename.c_str(); + } + fin.seekg(0, std::ios::end); + contents.clear(); + contents.resize(fin.tellg()); + fin.seekg(0, std::ios::beg); + fin.read(&contents[0], contents.size()); + fin.close(); +} + +static void save_model( + const std::unique_ptr & main_program, + Scope* scope, + const std::vector & param_names, + const std::string & model_name, + bool save_combine) { + auto place = platform::CPUPlace(); + const BlockDesc& global_block = main_program->Block(0); + std::vector paralist; + + for (auto* var : global_block.AllVars()) { + bool is_model_param = false; + for (auto param_name : param_names) { + if (var->Name() == param_name) { + is_model_param = true; + break; + } + } + + if (!is_model_param) continue; + + if (!save_combine) { + LOG(ERROR) << "model var name: " << var->Name().c_str(); + + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", model_name + "/" + var->Name()}); + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save", + {{"X", {var->Name()}}}, + {}, + attrs); + + save_op->Run(*scope, place); + } else { + paralist.push_back(var->Name()); + } + } + + if (save_combine) { + std::sort(paralist.begin(), paralist.end()); + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", model_name}); + auto save_op = paddle::framework::OpRegistry::CreateOp( + "save_combine", + {{"X", paralist}}, + {}, + attrs); + save_op->Run(*scope, place); + } +} // end save_model + + +void ExecutorThreadWorker::add_train_file(const std::string& file) { + _s_thread_filelist.push_back(file); +} + +void ExecutorThreadWorker::create_thread_operators(const ProgramDesc& program) { + auto& block = program.Block(0); + _op_names.clear(); + for (auto& op_desc : block.AllOps()) { + std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); + _op_names.push_back(op_desc->Type()); + OperatorBase* local_op_ptr = local_op.release(); + _ops.push_back(local_op_ptr); + continue; + } +} + +void ExecutorThreadWorker::create_thread_scope(const ProgramDesc& program) { + auto& block = program.Block(0); + _thread_scope = &_root_scope->NewScope(); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto* ptr = _root_scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + // LOGERR("create Persistable var[%s] finished", + // var->Name().c_str()); + } else { + auto* ptr = _thread_scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + // LOGERR("create unpersistable var[%s] finished", + // var->Name().c_str()); + } + } +} + +void ExecutorThreadWorker::set_datafeed(const std::shared_ptr& datafeed) { + _local_reader = datafeed; +} + +void ExecutorThreadWorker::binding_datafeed_memory() { + const std::vector& input_feed = _local_reader->get_use_slot_alias(); + for (auto name : input_feed) { + _local_reader->add_feed_var(_thread_scope->Var(name), name); + } +} + +void ExecutorThreadWorker::set_inspect_var_name( + const std::string& inspect_var_name) { + _inspect_var_name = inspect_var_name; +} + +void ExecutorThreadWorker::set_model_param_names( + const std::vector& param_names) { + _model_param_names = param_names; +} + +void ExecutorThreadWorker::set_sparse_comm_data( + const std::map& param_names) { + _sparse_comm_data = param_names; +} + +void ExecutorThreadWorker::set_device() { + static unsigned priority[] = { + 0, 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, + 18, 19, 20, 21, 22, 23, + 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, + 36, 37, 38, 39, 40, 41, + 42, 43, 44, 45, 46, 47 + }; + + unsigned int i = this->_thread_id; + + if (i < sizeof(priority) / sizeof(unsigned)) { + unsigned proc = priority[i]; + + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(proc, &mask); + + if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) { + LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i; + } else { + CPU_ZERO(&mask); + if ((0 == sched_getaffinity(0, sizeof(mask), &mask)) + && CPU_ISSET(proc, &mask)) { + LOG(ERROR) << "TRACE: Thread " << i << " is running on processor " << proc << "..."; + } + } + } +} + +void ExecutorThreadWorker::update_epoch_num() { + _s_current_finished_file_cnt++; + + if (_s_current_finished_file_cnt >= _s_thread_filelist.size()) { + _s_current_finished_file_cnt = 0; + _s_current_epoch++; + } +} + +const char* ExecutorThreadWorker::pick_one_file() { + std::string file_to_be_preocessed; + std::lock_guard lock(_s_locker_for_pick_file); + + if (_s_current_file_idx >= _s_thread_filelist.size()) { + std::random_shuffle(_s_thread_filelist.begin(), + _s_thread_filelist.end()); + _s_current_file_idx = 0; + // _s_current_epoch++; //example: when one file, one thread, it's bug + LOG(ERROR) << "thread " << _thread_id + << ": finish traing for epoch " << _s_current_epoch + 1; + } + file_to_be_preocessed = _s_thread_filelist[_s_current_file_idx]; + + _s_current_file_idx++; + return file_to_be_preocessed.c_str(); +} + +void ExecutorThreadWorker::train() { + LOG(ERROR) << "begin to train"; + set_device(); +#ifdef LOCAL_PROF + std::vector op_total_time; + std::vector op_name; + // int total_batch = 0; + for (auto& op : _ops) { + op_name.push_back(op->Type()); + } + op_total_time.resize(_ops.size()); + for (int i = 0; i < op_total_time.size(); ++i) { + op_total_time[i] = 0.0; + } +#endif + std::string inspect_key = "inspect"; + if (!_inspect_var_name.empty()) { + inspect_key = _inspect_var_name.substr(0, + _inspect_var_name.find_first_of('_')); + } + + for (unsigned i = 0; i < _max_epoch; ++i) { + LOG(ERROR) << "epoch: " << i; +#ifdef LOCAL_PROF + Timer timeline; + double total_time = 0.0; + double read_time = 0.0; +#endif + float total_inspect = 0; + int batch_num = 1; + while (i == _s_current_epoch) { + const char* filename = pick_one_file(); + _local_reader->set_file(filename); + while (true) { +#ifdef LOCAL_PROF + timeline.start(); +#endif + bool flag = _local_reader->read_batch(); + if (!flag) { + break; + } +#ifdef LOCAL_PROF + timeline.pause(); + read_time += timeline.elapsed_sec(); + total_time += timeline.elapsed_sec(); +#endif + if (!flag) { + break; + } + + for (unsigned int i = 0; i < _ops.size(); ++i) { +#ifdef LOCAL_PROF + timeline.start(); +#endif + _ops[i]->Run(*_thread_scope, _place); +#ifdef LOCAL_PROF + timeline.pause(); + op_total_time[i] += timeline.elapsed_sec(); + total_time += timeline.elapsed_sec(); +#endif + } + batch_num++; + float avg_inspect = 0.0; + if (!_inspect_var_name.empty()) { + avg_inspect = _thread_scope->FindVar(_inspect_var_name) + ->GetMutable() + ->data()[0]; + } + total_inspect += avg_inspect; + _thread_scope->DropKids(); + } + update_epoch_num(); + LOG(ERROR) << "memory used after epoch " << i + 1 + << " called: " << memory::memory_usage(_place); + +#ifdef LOCAL_PROF + for (int i = 0; i < op_total_time.size(); ++i) { + std::cerr << "op_name:[" << i << "][" << op_name[i] << "]" + << " op_mean_time:[" << op_total_time[i] << "s]" + << std::endl; + } + std::cerr << "read time: " << read_time << "s" << std::endl; +#endif + } +#ifdef LOCAL_PROF + LOG(ERROR) << "mean " << inspect_key.c_str() + << " of epoch " << i + 1 << ": " << total_inspect / batch_num + << ", total_time: " << total_time; +#else + LOG(ERROR) << "mean " << inspect_key.c_str() + << " of epoch " << i + 1 << ": " << total_inspect / batch_num; +#endif + if (_thread_id == 0) { + char modelfile[1024]; + snprintf(&modelfile[0], + sizeof(modelfile), + "%s_epoch%d.model", + _model_prefix.c_str(), + i); + std::string model_filename = std::string(modelfile); + // this save_inference_model can only save imdbtask, should make this + // general + // + // currently comment it + LOG(ERROR) << "Going to save model " << modelfile; + save_model(_main_program, + _thread_scope, + _model_param_names, + model_filename, + true); + } + } +} + +void ExecutorThreadWorker::set_thread_id(int tid) { + _thread_id = tid; +} + +void ExecutorThreadWorker::set_place(const platform::Place& place) { + _place = place; +} + +void ExecutorThreadWorker::set_main_program( + const ProgramDesc& main_program_desc) { + _main_program.reset(new ProgramDesc(main_program_desc)); +} + +void ExecutorThreadWorker::set_root_scope(Scope* g_scope) { + _root_scope = g_scope; +} + +void ExecutorThreadWorker::set_max_training_epoch(int max_epoch) { + _max_epoch = max_epoch; +} + +MultiExecutor::MultiExecutor(const platform::Place& place) : _place(place) {} + +void MultiExecutor::init_root_scope(Scope* scope) { + _root_scope = scope; +} + +void MultiExecutor::set_max_training_epoch(int max_epoch) { + _max_epoch = max_epoch; +} + +void MultiExecutor::set_datafeed_name(const char* feedname) { + _feed_name = std::string(feedname); +} + +void MultiExecutor::set_model_prefix(const std::string& model_prefix) { + _model_prefix = model_prefix; +} + +void MultiExecutor::run_startup_program(const ProgramDesc& program, + Scope* scope) { + auto& block = program.Block(0); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto* ptr = scope->Var(var->Name()); + CreateTensor(ptr, var->GetType()); + // LOGERR("Persistable Var Name:%s", var->Name().c_str()); + } + } + + std::map param_dict; + std::vector ops; + for (auto& op_desc : block.AllOps()) { + std::vector param_name_vec = op_desc->OutputArgumentNames(); + bool need_to_run = false; + for (auto& name : param_name_vec) { + if (param_dict.find(name) == param_dict.end()) { + param_dict[name] = 1; + need_to_run = true; + } + } + if (need_to_run) { + std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); + OperatorBase* local_op_ptr = local_op.release(); + ops.push_back(local_op_ptr); + } + } + // LOGERR("There are %d parameters in startup program, %d op needs to run", + // param_dict.size(), ops.size()); + + for (auto& op : ops) { + op->Run(*scope, _place); + } + // LOGERR("total time for startup program: %fs", timeline.elapsed_sec()); + for (auto& op : ops) { + delete op; + } + // LOGERR("run startup program done."); +} + +std::unique_ptr MultiExecutor::load_desc_from_file( + const std::string& f) { + std::string program_desc_str; + read_binary_file(f, &program_desc_str); + std::unique_ptr program(new ProgramDesc(program_desc_str)); + return program; +} + +void MultiExecutor::set_dense_comm_tensor( + const std::vector& dense_comm_tensor) { + _dense_comm_tensor.resize(dense_comm_tensor.size()); + for (unsigned int i = 0; i < dense_comm_tensor.size(); ++i) { + _dense_comm_tensor[i] = dense_comm_tensor[i]; + } +} + +void MultiExecutor::set_sparse_comm_tensor( + const std::vector& sparse_comm_tensor) { + _sparse_comm_tensor.resize(sparse_comm_tensor.size()); + for (unsigned int i = 0; i < sparse_comm_tensor.size(); ++i) { + _sparse_comm_tensor[i] = sparse_comm_tensor[i]; + } +} + +void MultiExecutor::set_sparse_comm_data( + const std::map& sparse_comm_data) { + _sparse_comm_data = sparse_comm_data; + LOG(INFO) << "Sparse comm data: " << _sparse_comm_data.size(); +} + +void MultiExecutor::set_filelist(const char* filelist) { + _filelist.clear(); + std::ifstream fin(filelist); + std::string filename; + while (fin >> filename) { + LOG(ERROR) << "add " << filename.c_str() << " to filelist"; + _filelist.push_back(filename); + } + fin.close(); +} + +void MultiExecutor::set_filelist(std::vector tfiles) { + _filelist.clear(); + _filelist.insert(_filelist.end(), tfiles.begin(), tfiles.end()); + return; +} + +void MultiExecutor::set_inspect_var_name(const std::string& inspect_var_name) { + _inspect_var_name = inspect_var_name; +} + +void MultiExecutor::set_param_names(const std::vector& param_names) { + _model_param_names = param_names; +} + +void MultiExecutor::set_thread_num(const int thread_num) { + _thread_num = thread_num; +} + +void MultiExecutor::prepare_threads(const ProgramDesc& host_program) { + _workers.resize(_thread_num); + for (unsigned i = 0; i < _thread_num; ++i) { + _workers[i].reset(new ExecutorThreadWorker); + _workers[i]->set_thread_id(i); + _workers[i]->create_thread_operators(host_program); + _workers[i]->set_root_scope(_root_scope); + _workers[i]->set_place(_place); + _workers[i]->set_max_training_epoch(_max_epoch); + _workers[i]->create_thread_scope(host_program); + _workers[i]->set_inspect_var_name(_inspect_var_name); + _workers[i]->set_model_param_names(_model_param_names); + _workers[i]->set_sparse_comm_data(_sparse_comm_data); + _workers[i]->set_main_program(host_program); + _workers[i]->set_model_prefix(_model_prefix); + } + + for (unsigned i = 0; i < _filelist.size(); ++i) { + // suppose at least one trainer thread here, and + // filelist is static so that we only add filelist once + _workers[0]->add_train_file(_filelist[i]); + } + // mpi_wrapper::ModelParam model_param(true); + // _workers[0]->register_parallel_training_param(model_param); + + for (unsigned i = 0; i < _thread_num; ++i) { + // new a datafeed here + std::shared_ptr local_feed = create_datafeed(_feed_name.c_str()); + local_feed->init(_data_feed_param); + local_feed->set_batch_size(_batch_size); + _workers[i]->set_datafeed(local_feed); + _workers[i]->binding_datafeed_memory(); + _workers[i]->set_thread_id(i); + } +} + +void MultiExecutor::run_multi_executor(const ProgramDesc& host_program) { + // thread binding here? + prepare_threads(host_program); + for (unsigned i = 0; i < _thread_num; ++i) { + _threads.push_back(std::thread(&ExecutorThreadWorker::train, + _workers[i].get())); + } + + for (auto& th : _threads) { + th.join(); + } +} + +} // end namespace framework +} // end namespace paddle + +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h new file mode 100644 index 00000000000..56b46b8afef --- /dev/null +++ b/paddle/fluid/framework/async_executor.h @@ -0,0 +1,175 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_ +#define PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_ + +#include +#include // NOLINT +#include +#include +#include +#include // NOLINT +#include +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/datafeed_creator.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +void CreateTensor(Variable* var, proto::VarType::Type var_type); + +class ExecutorThreadWorker { + public: + ExecutorThreadWorker() {} + virtual ~ExecutorThreadWorker() {} + void create_thread_scope(const framework::ProgramDesc& program); + void set_datafeed(const DataFeed& datafeed); + void set_thread_id(int tid); + void create_thread_operators(const framework::ProgramDesc& program); + void set_root_scope(Scope* g_scope); + void set_device(); + virtual void add_fid_set(); + void set_comm_batch(int comm_batch) { _comm_batch = comm_batch; } + void add_train_file(const std::string& filename); + void set_main_program(const ProgramDesc& main_program_desc); + void set_place(const paddle::platform::Place& place); + void set_max_training_epoch(const int max_epoch); + void binding_datafeed_memory(); + void set_model_prefix(const std::string& prefix) { _model_prefix = prefix; } + void set_inspect_var_name(const std::string& inspect_var_name); + void set_model_param_names(const std::vector& param_names); + void set_sparse_comm_data(const std::map& param_names); + void set_datafeed(const std::shared_ptr& datafeed); + virtual void mpi_train(); + void gpu_train(); + void train(); + virtual const char* pick_one_file(); + void update_epoch_num(); + + virtual void set_dense_comm_tensor( + const std::vector& param_names) {} + virtual void initialize() {} + + public: + static std::mutex _s_locker_for_pick_file; + static unsigned int _s_current_file_idx; + static size_t _s_current_finished_file_cnt; + static unsigned int _s_current_epoch; + static int _s_current_save_epoch; + static std::vector _s_thread_filelist; // filelist + static bool _s_is_first_worker; + + protected: + // thread index + int _thread_id; + + // current training file + int _cur_fileidx; + + // max epoch for each thread + unsigned int _max_epoch; + + // instances learned currently + int _comm_batch; + std::string _model_prefix; + std::vector _op_names; + + // local ops for forward and backward + std::vector _ops; + + // main program for training + std::unique_ptr _main_program; + + // binary data reader + std::shared_ptr _local_reader; + + std::string _inspect_var_name; + std::vector _model_param_names; + std::map _sparse_comm_data; + std::vector _ids_buffer; + + // execution place + platform::Place _place; + + // root scope for model parameters + Scope* _root_scope; + + // a thread scope, father scope is global score which is shared + Scope* _thread_scope; +}; + +class MultiExecutor { + public: + explicit MultiExecutor(const platform::Place& place); + virtual ~MultiExecutor() {} + static std::unique_ptr load_desc_from_file( + const std::string& filename); + void init_root_scope(Scope* scope); + void set_inspect_var_name(const std::string& inspect_var_name); + void set_param_names(const std::vector& param_names); + void set_max_training_epoch(const int max_epoch); + Scope* get_root_scope() { return _root_scope; } + void set_thread_num(const int thread_num); + void set_batch_size(const int batch_size) { _batch_size = batch_size; } + void set_filelist(const char* filelist); + void set_filelist(const std::vector filelist); + void set_datafeed_name(const char* feedname); + + void set_data_feed_param(const datafeed::DataFeedParameter& feed_param) { + _data_feed_param = feed_param; + } + + void set_comm_batch(int comm_batch) { + _comm_batch = comm_batch; + } + + void set_model_prefix(const std::string& model_prefix); + void set_dense_comm_tensor(const std::vector& dense_comm_tensor); + void set_sparse_comm_tensor( + const std::vector& sparse_comm_tensor); + void set_sparse_comm_data(const std::map& sparse_comm_data); + virtual void prepare_threads(const framework::ProgramDesc& host_program); + void run_startup_program(const framework::ProgramDesc& program, + framework::Scope* scope); + void run_multi_executor(const ProgramDesc& host_program); + + public: + unsigned int _thread_num; + datafeed::DataFeedParameter _data_feed_param; + int _max_epoch; + int _batch_size; + int _comm_batch; + std::vector > _workers; + std::vector _threads; + std::vector _filelist; + std::string _inspect_var_name; + std::vector _model_param_names; + std::vector _dense_comm_tensor; + std::vector _sparse_comm_tensor; + std::map _sparse_comm_data; + int node_num; + std::string _model_prefix; + ProgramDesc _host_program; + std::string _feed_name; + Scope* _root_scope; + platform::Place _place; +}; + +} // namespace framework +} // namespace paddle +#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_ +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc new file mode 100644 index 00000000000..97b0aa73d9a --- /dev/null +++ b/paddle/fluid/framework/data_feed.cc @@ -0,0 +1,162 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" + +#include "gflags/gflags.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/framework/data_feed.h" + +DEFINE_bool(is_text_feed, false, "is_text_feed"); + +namespace paddle { +namespace framework { +void TextClassDataFeed::init(const datafeed::DataFeedParameter& feed_param) { + // hard coding for a specific datafeed + _feed_vec.resize(2); + // _feed_vec[0].reset(new LoDTensor); + // _feed_vec[1].reset(new LoDTensor); + _all_slot_ids = {0, 1}; + _use_slot_ids = {0, 1}; + _use_slot_alias = {"words", "label"}; + _file_content_buffer_host.reset(new char[200*1024*1024], + [](char *p) {delete[] p;}); + _file_content_buffer = _file_content_buffer_host.get(); + _file_content_buffer_ptr = _file_content_buffer; + _batch_id_host.reset(new int[10240*1024], + [](int *p) {delete[] p;}); // max word num in a batch + _label_host.reset(new int[10240], + [](int *p) {delete[] p;}); // max label in a batch + _batch_id_buffer = _batch_id_host.get(); + _label_ptr = _label_host.get(); +} + + // todo: use elegant implemention for this function +bool TextClassDataFeed::read_batch() { + paddle::framework::Vector offset; + int tlen = 0; + int llen = 0; + int inst_idx = 0; + offset.resize(_batch_size + 1); + offset[0] = 0; + while (inst_idx < _batch_size) { + int ptr_offset = 0; + if (_file_content_buffer_ptr - _file_content_buffer >= _file_size) { + break; + } + + memcpy(reinterpret_cast(&llen), + _file_content_buffer_ptr + ptr_offset, + sizeof(int)); + ptr_offset += sizeof(int); + + memcpy(reinterpret_cast(_batch_id_buffer + tlen), + _file_content_buffer_ptr + ptr_offset, + llen * sizeof(int)); + tlen += llen; + + offset[inst_idx + 1] = offset[inst_idx] + llen; + ptr_offset += sizeof(int) * llen; + + memcpy(reinterpret_cast(_label_ptr + inst_idx), + _file_content_buffer_ptr + ptr_offset, + sizeof(int)); + ptr_offset += sizeof(int); + + _file_content_buffer_ptr += ptr_offset; + inst_idx++; + } + + if (inst_idx != _batch_size) { + return false; + } + + LoD input_lod{offset}; + paddle::framework::Vector label_offset; + label_offset.resize(_batch_size + 1); + for (int i = 0; i <= _batch_size; ++i) { + label_offset[i] = i; + } + + LoD label_lod{label_offset}; + int64_t* input_ptr = _feed_vec[0]->mutable_data( + {static_cast(offset.back()), 1}, + platform::CPUPlace()); + int64_t* label_ptr = _feed_vec[1]->mutable_data({_batch_size, 1}, + platform::CPUPlace()); + for (unsigned int i = 0; i < offset.back(); ++i) { + input_ptr[i] = static_cast(_batch_id_buffer[i]); + } + for (int i = 0; i < _batch_size; ++i) { + label_ptr[i] = static_cast(_label_ptr[i]); + } + _feed_vec[0]->set_lod(input_lod); + _feed_vec[1]->set_lod(label_lod); + return true; +} + +void TextClassDataFeed::add_feed_var(Variable* feed, const std::string& name) { + for (unsigned int i = 0; i < _use_slot_alias.size(); ++i) { + if (name == _use_slot_alias[i]) { + _feed_vec[i] = feed->GetMutable(); + } + } +} + +bool TextClassDataFeed::set_file(const char* filename) { + // termnum termid termid ... termid label + int filesize = read_whole_file(filename, _file_content_buffer); + // todo , remove magic number + if (filesize < 0 || filesize >= 1024 * 1024 * 1024) { + return false; + } + _file_content_buffer_ptr = _file_content_buffer; + _file_size = filesize; + return true; +} + +int TextClassDataFeed::read_whole_file(const std::string& filename, + char* buffer) { + std::ifstream ifs(filename.c_str(), std::ios::binary); + if (ifs.fail()) { + return -1; + } + + ifs.seekg(0, std::ios::end); + int file_size = ifs.tellg(); + ifs.seekg(0, std::ios::beg); + ifs.read(buffer, file_size); + return file_size; +} + +} // namespace framework +} // namespace paddle +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ + diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h new file mode 100644 index 00000000000..1680e5c480c --- /dev/null +++ b/paddle/fluid/framework/data_feed.h @@ -0,0 +1,333 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ +#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ + +#include +#include +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" +#include "proto/FeedDataParameter.pb.h" + +namespace paddle { +namespace framework { +typedef uint64_t FeatureKey; + +struct FeatureItem { + FeatureItem() {} + FeatureItem(FeatureKey sign_, uint16_t slot_) { + sign() = sign_; + slot() = slot_; + } + + FeatureKey& sign() { + return *(reinterpret_cast(sign_buffer())); + } + + const FeatureKey& sign() const { + return *(const FeatureKey*)sign_buffer(); + } + + uint16_t& slot() { + return _slot; + } + + const uint16_t& slot() const { + return _slot; + } + + private: + char _sign[sizeof(FeatureKey)]; + uint16_t _slot; + char* sign_buffer() const { + return (char *)_sign; + } +}; + +// Record(average:14031B) is smaller than Sample(average:16530B) +struct Record { + int show, click; + std::vector feas; + std::string lineid; + std::string tags; +}; + +struct Gauc { + int show, click; + uint64_t fea; + std::string lineid; +}; + +struct Instance { + std::vector> feed_vec_buffer; + std::vector> feed_vec_lod; + std::vector other_label; + std::vector gauc_vec; +}; + +struct Sample { + uint64_t label; + std::map> feas; + + bool from_string(const std::string& input, const std::set& slots) { + size_t end = input.find_first_of(' '); + if (end == std::string::npos) { + LOG(ERROR) << "[ERROR] Fail in parsing:" << input; + return false; + } + label = input[end + 3] - '0'; + CHECK(label == 0 || label == 1) << "invalid label:" << label; + + std::stringstream ss(input); + + std::string token; + uint16_t slot_id = 0; + uint64_t feature_id = 0; + int num_nonfeas_token = 0; + std::ostringstream os; + while (ss >> token) { + size_t end = token.find_first_of(':'); + if (end == std::string::npos) { + ++num_nonfeas_token; + continue; + } + + try { + slot_id = stoi(token.substr(end + 1)); + } catch (...) { + LOG(ERROR) << "Error in parsing slot id:" << token; + return false; + } + + try { + feature_id = stoull(token.substr(0, end)); + } catch (...) { + LOG(ERROR) << "Error in parsing feature id:" << token; + return false; + } + + if (slot_id <= 0) { + LOG(ERROR) << "invalid slot:" << slot_id << " feasign:" << feature_id + << " line:" << input; + return false; + } + + if (slots.find(slot_id) == slots.end()) { + continue; + } + + feas[slot_id].push_back(feature_id); + } + + if (num_nonfeas_token != 4) { + LOG(ERROR) << "Format error. Invalid number of non-feasign token:" + << num_nonfeas_token; + return false; + } + + return true; + } +}; + +struct TeacherStudentSample { + uint64_t label; + std::map> feas; + float q_score; + + void print() { + LOG(ERROR) << "label: " << label << " score: " << q_score; + for (auto &slot : feas) { + for (auto &fea : slot.second) { + LOG(ERROR) << "slot: " << slot.first << " fea: " << fea; + } + } + } + + bool from_string(const std::string& input, + const std::set& slots, + Gauc& gauc) { // NOLINT + size_t end = input.find_first_of(' '); + if (end == std::string::npos) { + LOG(ERROR) << "[ERROR] Fail in parsing:" << input; + return false; + } + + label = input[end + 3] - '0'; + CHECK(label == 0 || label == 1) << "invalid label:" << label; + gauc.show = 1; + gauc.click = label; + gauc.lineid = input.substr(0, end); + gauc.fea = 0; + size_t dnn_start = input.find("*"); + if (dnn_start == std::string::npos) { + q_score = -1.0; + } else { + dnn_start += 1; + size_t dnn_end = input.find(' ', dnn_start); + q_score = static_cast( + atof(input.substr(dnn_start, dnn_end - dnn_start).c_str())); + } + + size_t head_pos = input.find("\t"); + std::string head = input.substr(0, head_pos); + std::stringstream ss(head); + + std::string token; + uint16_t slot_id = 0; + uint64_t feature_id = 0; + int num_nonfeas_token = 0; + std::ostringstream os; + while (ss >> token) { + size_t end = token.find_first_of(':'); + if (end == std::string::npos) { + ++num_nonfeas_token; + continue; + } + + try { + slot_id = stoi(token.substr(end + 1)); + } catch (...) { + LOG(ERROR) << "Error in parsing slot id:" << token; + return false; + } + + try { + feature_id = stoull(token.substr(0, end)); + } catch (...) { + LOG(ERROR) << "Error in parsing feature id:" << token; + return false; + } + + if (slot_id <= 0) { + LOG(ERROR) << "invalid slot:" << slot_id << " feasign:" << feature_id + << " line:" << input; + return false; + } + + if (slots.find(slot_id) == slots.end()) { + continue; + } + + if (slot_id == 6048) { + gauc.fea = feature_id; + } + feas[slot_id].push_back(feature_id); + } + + if (num_nonfeas_token != 4) { + LOG(ERROR) << "Format error. Invalid number of non-feasign token:" + << num_nonfeas_token; + return false; + } + return true; + } +}; + +class DataFeed { + public: + DataFeed() {} + virtual ~DataFeed() {} + virtual void init(const datafeed::DataFeedParameter& feed_param) = 0; + /* + * This function will be used to check file format. + * Considering that this function may be used alone, + * it does not check anything. + * */ + virtual bool check_file(const char* filename) = 0; + virtual bool set_file(const char* filename) = 0; + virtual bool read_batch() = 0; + virtual const std::vector& get_all_slot_ids() { + return _all_slot_ids; + } + + virtual const std::vector& get_use_slot_ids() { + return _use_slot_ids; + } + + virtual const std::vector& get_use_slot_alias() { + return _use_slot_alias; + } + + virtual void add_feed_var(Variable* var, + const std::string& name) = 0; + virtual void bind_scope(Scope* scope) = 0; + virtual void set_batch_size(int batch) { _default_batch_size = batch; } + virtual int get_batch_size() { return _batch_size; } + virtual void set_buffer_size(int buffer_size) {} + + std::vector& get_feed_vec() { + return _feed_vec; + } + + virtual std::vector& get_feed_vec(const Instance& ins) { + LOG(ERROR) << "use defalut get_feed_vec"; + return _feed_vec; + } + + protected: + std::vector _all_slot_ids; + std::vector _use_slot_ids; + std::vector _use_slot_alias; + std::vector _feed_vec; + int _default_batch_size; + int _batch_size; +}; + +class TextClassDataFeed : public DataFeed { + public: + virtual ~TextClassDataFeed() {} + virtual void init(const datafeed::DataFeedParameter& feed_param); + virtual bool read_batch(); + virtual void add_feed_var(Variable* feed, const std::string& name); + virtual void bind_scope(Scope* scope) {} + virtual bool set_file(const char* filename); + + virtual bool check_file(const char* filename) { + // TODO(xxx) + return false; + } + + void set_batch_size(int batch) {_batch_size = batch;} + + private: + int read_whole_file(const std::string& filename, char* buffer); + char* _file_content_buffer; + char* _file_content_buffer_ptr; + int* _batch_id_buffer; + int* _label_ptr; + int _file_size; + std::vector _names; + std::shared_ptr _file_content_buffer_host; + std::shared_ptr _batch_id_host; + std::shared_ptr _label_host; +}; + +} // namespace framework +} // namespace paddle + +#endif // PADDLE_FLUID_FRAMEWORK_DATA_FEED_H_ +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ diff --git a/paddle/fluid/framework/datafeed_creator.cc b/paddle/fluid/framework/datafeed_creator.cc new file mode 100644 index 00000000000..5dd83292532 --- /dev/null +++ b/paddle/fluid/framework/datafeed_creator.cc @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/datafeed_creator.h" + +std::shared_ptr create_datafeed( + const char* datafeed_class) { + if (strcmp(datafeed_class, "TextClass") == 0) { + return std::shared_ptr( + new paddle::framework::TextClassDataFeed); + } + + return NULL; +} +/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ diff --git a/paddle/fluid/framework/datafeed_creator.h b/paddle/fluid/framework/datafeed_creator.h new file mode 100644 index 00000000000..8fd95ba8812 --- /dev/null +++ b/paddle/fluid/framework/datafeed_creator.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_ +#define PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_ +#include +#include "paddle/fluid/framework/data_feed.h" + +std::shared_ptr create_datafeed( + const char* datafeed_class); +#endif // PADDLE_FLUID_FRAMEWORK_DATAFEED_CREATOR_H_ diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index e7f634c4a62..92b5d0dc398 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,5 +1,5 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder) +set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) if(NOT WIN32) list(APPEND PYBIND_DEPS parallel_executor profiler) diff --git a/proto/FeedDataParameter.proto b/proto/FeedDataParameter.proto new file mode 100644 index 00000000000..e165ce605c6 --- /dev/null +++ b/proto/FeedDataParameter.proto @@ -0,0 +1,48 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +syntax = "proto2"; +package datafeed; + +message DataFeedParameter { + optional FeedDataParameter feed_data_param = 1; + optional JointOneHotParameter joint_onehot_data_param = 2; + optional ACDXParameter acdx_data_param = 3; +} + +message FeedDataParameter { + repeated int32 slot_id = 1; + repeated int32 use_slot_id = 2; + repeated string use_slot_alias = 3; + repeated uint64 use_slot_mod = 4; + repeated int32 use_slot_type = 5; + optional int32 max_batch_num = 6 [default = 128]; + optional int32 max_feasign_num = 7 [default = 1000]; +} + +message JointOneHotParameter { + optional int32 max_batch_num = 1 [default = 128]; + optional int32 max_title_num = 2 [default = 400]; + optional int32 max_term_num = 3 [default = 1024]; + required float sampling_rate = 4; + repeated int32 slot_id = 5; + repeated int32 use_slot_id = 6; + repeated string use_slot_alias = 7; + repeated uint64 use_slot_mod = 8; + repeated int32 use_slot_type = 9; +} + +message ACDXParameter { + optional int32 max_batch_num = 1 [default = 128]; + optional int32 max_term_num = 3 [default = 512]; +} -- GitLab