diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6b5ed1024449bdd5e73d69b3d6531e1d4e86126f..26113ee7e90bb9112a607e43c59aac1b5c21a4bb 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -32,6 +32,13 @@ paddle.fluid.BuildStrategy.ReduceStrategy.__init__ __init__(self: paddle.fluid.c paddle.fluid.BuildStrategy.__init__ __init__(self: paddle.fluid.core.ParallelExecutor.BuildStrategy) -> None paddle.fluid.create_lod_tensor ArgSpec(args=['data', 'recursive_seq_lens', 'place'], varargs=None, keywords=None, defaults=None) paddle.fluid.create_random_int_lodtensor ArgSpec(args=['recursive_seq_lens', 'base_shape', 'place', 'low', 'high'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DataFeedDesc.__init__ ArgSpec(args=['self', 'proto_file'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DataFeedDesc.desc ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DataFeedDesc.set_batch_size ArgSpec(args=['self', 'batch_size'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DataFeedDesc.set_dense_slots ArgSpec(args=['self', 'dense_slots_name'], varargs=None, keywords=None, defaults=None) +paddle.fluid.DataFeedDesc.set_use_slots ArgSpec(args=['self', 'use_slots_name'], varargs=None, keywords=None, defaults=None) +paddle.fluid.AsyncExecutor.__init__ ArgSpec(args=['self', 'place'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.AsyncExecutor.run ArgSpec(args=['self', 'program', 'data_feed', 'filelist', 'thread_num', 'fetch', 'debug'], varargs=None, keywords=None, defaults=(False,)) paddle.fluid.io.save_vars ArgSpec(args=['executor', 'dirname', 'main_program', 'vars', 'predicate', 'filename'], varargs=None, keywords=None, defaults=(None, None, None, None)) paddle.fluid.io.save_params ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.io.save_persistables ArgSpec(args=['executor', 'dirname', 'main_program', 'filename'], varargs=None, keywords=None, defaults=(None, None)) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 52946c7f11f90490b1af1347f20db236a8fe24af..9f5631b87cba62aa984f27b13418d61e12e86c8a 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -34,6 +34,7 @@ add_subdirectory(ir) add_subdirectory(details) # ddim lib proto_library(framework_proto SRCS framework.proto) +proto_library(async_executor_proto SRCS data_feed.proto) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) @@ -135,7 +136,7 @@ endif(NOT WIN32) cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc) nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) -py_proto_compile(framework_py_proto SRCS framework.proto) +py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto) # Generate an empty __init__.py to make framework_py_proto as a valid python module. add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py) add_dependencies(framework_py_proto framework_py_proto_init) @@ -157,18 +158,19 @@ endif(NOT WIN32) cc_library(lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor) cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glog) +cc_library(variable_helper SRCS variable_helper.cc DEPS lod_tensor) -cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) +cc_library(naive_executor SRCS naive_executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) if(WITH_DISTRIBUTE) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method sendrecvop_grpc cares grpc++_unsecure grpc_unsecure gpr graph_to_program_pass variable_helper) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) else() if(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper) else(NOT WIN32) - cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass) + cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper) endif(NOT WIN32) cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op) endif() @@ -176,8 +178,11 @@ endif() cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor graph build_strategy - fast_threaded_ssa_graph_executor) + fast_threaded_ssa_graph_executor variable_helper) +cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper) + +cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor) cc_library(prune SRCS prune.cc DEPS framework_proto) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc new file mode 100644 index 0000000000000000000000000000000000000000..afb2dd2f064384da39904f6aceead4fa915a80f2 --- /dev/null +++ b/paddle/fluid/framework/async_executor.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/async_executor.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +#include "gflags/gflags.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/executor_thread_worker.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/pybind/pybind.h" + +namespace paddle { +namespace framework { +AsyncExecutor::AsyncExecutor(Scope* scope, const platform::Place& place) + : root_scope_(scope), place_(place) {} + +void AsyncExecutor::CreateThreads( + ExecutorThreadWorker* worker, const ProgramDesc& main_program, + const std::shared_ptr& reader, + const std::vector& fetch_var_names, Scope* root_scope, + const int thread_index, const bool debug) { + worker->SetThreadId(thread_index); + worker->SetDebug(debug); + worker->SetRootScope(root_scope); + worker->CreateThreadResource(main_program, place_); + worker->SetDataFeed(reader); + worker->SetFetchVarNames(fetch_var_names); + worker->BindingDataFeedMemory(); +} + +void PrepareReaders(std::vector>& readers, // NOLINT + const int thread_num, const DataFeedDesc& data_feed_desc, + const std::vector& filelist) { + readers.resize(thread_num); + for (size_t i = 0; i < readers.size(); ++i) { + readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name()); + readers[i]->Init(data_feed_desc); // set batch_size and queue_size here + } + readers[0]->SetFileList(filelist); +} + +void AsyncExecutor::RunFromFile(const ProgramDesc& main_program, + const std::string& data_feed_desc_str, + const std::vector& filelist, + const int thread_num, + const std::vector& fetch_var_names, + const bool debug) { + std::vector threads; + + auto& block = main_program.Block(0); + for (auto var_name : fetch_var_names) { + auto var_desc = block.FindVar(var_name); + auto shapes = var_desc->GetShape(); + PADDLE_ENFORCE(shapes[shapes.size() - 1] == 1, + "var %s: Fetched var has wrong shape, " + "only variables with the last dimension size 1 supported", + var_name); + } + + DataFeedDesc data_feed_desc; + google::protobuf::TextFormat::ParseFromString(data_feed_desc_str, + &data_feed_desc); + + int actual_thread_num = thread_num; + int file_cnt = filelist.size(); + PADDLE_ENFORCE(file_cnt > 0, "File list cannot be empty"); + + if (actual_thread_num > file_cnt) { + VLOG(1) << "Thread num = " << thread_num << ", file num = " << file_cnt + << ". Changing thread_num = " << file_cnt; + actual_thread_num = file_cnt; + } + + /* + readerDesc: protobuf description for reader initlization + argument: class_name, batch_size, use_slot, queue_size, buffer_size, + padding_index + + reader: + 1) each thread has a reader, reader will read input data and + put it into input queue + 2) each reader has a Next() iterface, that can fetch an instance + from the input queue + */ + // todo: should be factory method for creating datafeed + std::vector> readers; + PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist); + + std::vector> workers; + workers.resize(actual_thread_num); + for (auto& worker : workers) { + worker.reset(new ExecutorThreadWorker); + } + + // prepare thread resource here + for (int thidx = 0; thidx < actual_thread_num; ++thidx) { + CreateThreads(workers[thidx].get(), main_program, readers[thidx], + fetch_var_names, root_scope_, thidx, debug); + } + + // start executing ops in multiple threads + for (int thidx = 0; thidx < actual_thread_num; ++thidx) { + threads.push_back( + std::thread(&ExecutorThreadWorker::TrainFiles, workers[thidx].get())); + } + + for (auto& th : threads) { + th.join(); + } + + root_scope_->DropKids(); + + return; +} + +} // einit_modelnd namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..f4d2a79ac592e02f49ec0b988c824dc98883fbf6 --- /dev/null +++ b/paddle/fluid/framework/async_executor.h @@ -0,0 +1,58 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include +#include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/executor_thread_worker.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +class AsyncExecutor { + public: + AsyncExecutor(Scope* scope, const platform::Place& place); + virtual ~AsyncExecutor() {} + void RunFromFile(const ProgramDesc& main_program, + const std::string& data_feed_desc_str, + const std::vector& filelist, + const int thread_num, + const std::vector& fetch_names, + const bool debug = false); + + private: + void CreateThreads(ExecutorThreadWorker* worker, + const ProgramDesc& main_program, + const std::shared_ptr& reader, + const std::vector& fetch_var_names, + Scope* root_scope, const int thread_index, + const bool debug); + + public: + Scope* root_scope_; + platform::Place place_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc new file mode 100644 index 0000000000000000000000000000000000000000..851c7eda89e87b8a8e40b344b589ac3176ed5211 --- /dev/null +++ b/paddle/fluid/framework/data_feed.cc @@ -0,0 +1,375 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +#include "gflags/gflags.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" + +namespace paddle { +namespace framework { + +std::vector DataFeed::filelist_; +size_t DataFeed::file_idx_; +std::mutex DataFeed::mutex_for_pick_file_; +bool DataFeed::finish_set_filelist_; + +void DataFeed::AddFeedVar(Variable* var, const std::string& name) { + CheckInit(); + for (size_t i = 0; i < use_slots_.size(); ++i) { + if (name == use_slots_[i]) { + if (use_slots_is_dense_[i]) { + feed_vec_[i] = MixTensor(var->GetMutable()); + } else { + feed_vec_[i] = MixTensor(var->GetMutable()); + } + } + } +} + +bool DataFeed::SetFileList(const std::vector& files) { + std::unique_lock lock(mutex_for_pick_file_); + CheckInit(); + if (finish_set_filelist_) { + VLOG(3) << "info: you have set the filelist."; + return false; + } + PADDLE_ENFORCE(files.size(), "You have set an empty filelist."); + filelist_.assign(files.begin(), files.end()); + file_idx_ = 0; + + finish_set_filelist_ = true; + return true; +} + +void DataFeed::SetBatchSize(int batch_size) { + PADDLE_ENFORCE(batch_size > 0, "Illegal batch size: %d.", batch_size); + default_batch_size_ = batch_size; +} + +bool DataFeed::PickOneFile(std::string* filename) { + std::unique_lock lock(mutex_for_pick_file_); + if (file_idx_ == filelist_.size()) { + return false; + } + *filename = filelist_[file_idx_++]; + return true; +} + +void DataFeed::CheckInit() { + PADDLE_ENFORCE(finish_init_, "Initialization did not succeed."); +} + +void DataFeed::CheckSetFileList() { + PADDLE_ENFORCE(finish_set_filelist_, "Set filelist did not succeed."); +} + +void DataFeed::CheckStart() { + PADDLE_ENFORCE(finish_start_, "Datafeed has not started running yet."); +} + +template +void PrivateQueueDataFeed::SetQueueSize(int queue_size) { + PADDLE_ENFORCE(queue_size > 0, "Illegal queue size: %d.", queue_size); + queue_size_ = queue_size; + queue_ = std::unique_ptr>( + new paddle::operators::reader::BlockingQueue(queue_size_)); +} + +template +bool PrivateQueueDataFeed::Start() { + CheckSetFileList(); + read_thread_ = std::thread(&PrivateQueueDataFeed::ReadThread, this); + read_thread_.detach(); + + finish_start_ = true; + return true; +} + +template +void PrivateQueueDataFeed::ReadThread() { + std::string filename; + while (PickOneFile(&filename)) { + file_.open(filename.c_str()); // is_text_feed + PADDLE_ENFORCE(file_.good(), "Open file<%s> fail.", filename.c_str()); + T instance; + while (ParseOneInstance(&instance)) { + queue_->Send(instance); + } + file_.close(); + } + queue_->Close(); +} + +template +int PrivateQueueDataFeed::Next() { + CheckStart(); + int index = 0; + T instance; + T ins_vec; + while (index < default_batch_size_) { + if (!queue_->Receive(&instance)) { + break; + } + AddInstanceToInsVec(&ins_vec, instance, index++); + } + batch_size_ = index; + if (batch_size_ != 0) { + PutToFeedVec(ins_vec); + } + return batch_size_; +} + +#ifdef _WIN32 +template class PrivateQueueDataFeed>; +#endif + +void MultiSlotDataFeed::Init( + const paddle::framework::DataFeedDesc& data_feed_desc) { + finish_init_ = false; + finish_set_filelist_ = false; + finish_start_ = false; + + PADDLE_ENFORCE(data_feed_desc.has_multi_slot_desc(), + "Multi_slot_desc has not been set."); + paddle::framework::MultiSlotDesc multi_slot_desc = + data_feed_desc.multi_slot_desc(); + SetBatchSize(data_feed_desc.batch_size()); + SetQueueSize(data_feed_desc.batch_size()); + size_t all_slot_num = multi_slot_desc.slots_size(); + all_slots_.resize(all_slot_num); + all_slots_type_.resize(all_slot_num); + use_slots_index_.resize(all_slot_num); + use_slots_.clear(); + use_slots_is_dense_.clear(); + for (size_t i = 0; i < all_slot_num; ++i) { + const auto& slot = multi_slot_desc.slots(i); + all_slots_[i] = slot.name(); + all_slots_type_[i] = slot.type(); + use_slots_index_[i] = slot.is_used() ? use_slots_.size() : -1; + if (slot.is_used()) { + use_slots_.push_back(all_slots_[i]); + use_slots_is_dense_.push_back(slot.is_dense()); + } + } + feed_vec_.resize(use_slots_.size()); + finish_init_ = true; +} + +bool MultiSlotDataFeed::CheckFile(const char* filename) { + CheckInit(); // get info of slots + std::ifstream fin(filename); + if (!fin.good()) { + VLOG(1) << "error: open file<" << filename << "> fail"; + return false; + } + std::string line; + int instance_cout = 0; + std::string all_slots_alias = ""; + for (const auto& alias : all_slots_) { + all_slots_alias += alias + " "; + } + std::string use_slots_alias = ""; + for (const auto& alias : use_slots_) { + use_slots_alias += alias + " "; + } + VLOG(3) << "total slots num: " << all_slots_.size(); + VLOG(3) << "total slots alias: " << all_slots_alias; + VLOG(3) << "used slots num: " << use_slots_.size(); + VLOG(3) << "used slots alias: " << use_slots_alias; + while (getline(fin, line)) { + ++instance_cout; + const char* str = line.c_str(); + char* endptr = const_cast(str); + int len = line.length(); + for (size_t i = 0; i < all_slots_.size(); ++i) { + int num = strtol(endptr, &endptr, 10); + if (num < 0) { + VLOG(1) << "error: the number of ids is a negative number: " << num; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } else if (num == 0) { + VLOG(1) + << "error: the number of ids can not be zero, you need " + "padding it in data generator; or if there is something wrong" + " with the data, please check if the data contains unresolvable " + "characters."; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } else if (errno == ERANGE || num > INT_MAX) { + VLOG(1) << "error: the number of ids greater than INT_MAX"; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } + if (all_slots_type_[i] == "float") { + for (int i = 0; i < num; ++i) { + strtof(endptr, &endptr); + if (errno == ERANGE) { + VLOG(1) << "error: the value is out of the range of " + "representable values for float"; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } + if (i + 1 != num && endptr - str == len) { + VLOG(1) << "error: there is a wrong with the number of ids."; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } + } + } else if (all_slots_type_[i] == "uint64") { + for (int i = 0; i < num; ++i) { + strtoull(endptr, &endptr, 10); + if (errno == ERANGE) { + VLOG(1) << "error: the value is out of the range of " + "representable values for uint64_t"; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } + if (i + 1 != num && endptr - str == len) { + VLOG(1) << "error: there is a wrong with the number of ids."; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } + } + } else { + VLOG(1) << "error: this type<" << all_slots_type_[i] + << "> is not supported"; + return false; + } + } + if (endptr - str != len) { + VLOG(1) << "error: there is some data at the end of the line."; + VLOG(1) << "please check line<" << instance_cout << "> in file<" + << filename << ">"; + return false; + } + } + VLOG(3) << "instances cout: " << instance_cout; + VLOG(3) << "The file format is correct"; + return true; +} + +bool MultiSlotDataFeed::ParseOneInstance(std::vector* instance) { + std::string line; + if (getline(file_, line)) { + int use_slots_num = use_slots_.size(); + instance->resize(use_slots_num); + // parse line + const char* str = line.c_str(); + char* endptr = const_cast(str); + int pos = 0; + for (size_t i = 0; i < use_slots_index_.size(); ++i) { + int idx = use_slots_index_[i]; + int num = strtol(&str[pos], &endptr, 10); + PADDLE_ENFORCE( + num, + "The number of ids can not be zero, you need padding " + "it in data generator; or if there is something wrong with " + "the data, please check if the data contains unresolvable " + "characters.\nplease check this error line: %s", + str); + if (idx != -1) { + (*instance)[idx].Init(all_slots_type_[i]); + if ((*instance)[idx].GetType()[0] == 'f') { // float + for (int j = 0; j < num; ++j) { + float feasign = strtof(endptr, &endptr); + (*instance)[idx].AddValue(feasign); + } + } else if ((*instance)[idx].GetType()[0] == 'u') { // uint64 + for (int j = 0; j < num; ++j) { + uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10); + (*instance)[idx].AddValue(feasign); + } + } + pos = endptr - str; + } else { + for (int j = 0; j <= num; ++j) { + pos = line.find_first_of(' ', pos + 1); + } + } + } + } else { + return false; + } + return true; +} + +void MultiSlotDataFeed::AddInstanceToInsVec( + std::vector* ins_vec, + const std::vector& instance, int index) { + if (index == 0) { + ins_vec->resize(instance.size()); + for (size_t i = 0; i < instance.size(); ++i) { + (*ins_vec)[i].Init(instance[i].GetType()); + (*ins_vec)[i].InitOffset(); + } + } + for (size_t i = 0; i < instance.size(); ++i) { + (*ins_vec)[i].AddIns(instance[i]); + } +} + +void MultiSlotDataFeed::PutToFeedVec( + const std::vector& ins_vec) { + for (size_t i = 0; i < use_slots_.size(); ++i) { + const auto& type = ins_vec[i].GetType(); + const auto& offset = ins_vec[i].GetOffset(); + int total_instance = static_cast(offset.back()); + if (type[0] == 'f') { // float + const auto& feasign = ins_vec[i].GetFloatData(); + if (feed_vec_[i].IsDense()) { + int size_in_each_batch = total_instance / batch_size_; + float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data( + {batch_size_, size_in_each_batch}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); + } else { + float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data( + {total_instance, 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float)); + LoD data_lod{offset}; + feed_vec_[i].GetLoDTensor()->set_lod(data_lod); + } + } else if (type[0] == 'u') { // uint64 + // no uint64_t type in paddlepaddle + const auto& feasign = ins_vec[i].GetUint64Data(); + if (feed_vec_[i].IsDense()) { + int size_in_each_batch = total_instance / batch_size_; + int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data( + {batch_size_, size_in_each_batch}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); + } else { + int64_t* tensor_ptr = + feed_vec_[i].GetLoDTensor()->mutable_data( + {total_instance, 1}, platform::CPUPlace()); + memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t)); + LoD data_lod{offset}; + feed_vec_[i].GetLoDTensor()->set_lod(data_lod); + } + } + } +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h new file mode 100644 index 0000000000000000000000000000000000000000..a7f8d1d31752af200145bc7934e7880910338e9d --- /dev/null +++ b/paddle/fluid/framework/data_feed.h @@ -0,0 +1,269 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/operators/reader/blocking_queue.h" + +namespace paddle { +namespace framework { + +// Pack Tensor type and LoDTensor type into MixTensor type, in order +// to record either Tensor or LoDTensor information at the same time. +class MixTensor { + public: + MixTensor() {} + explicit MixTensor(LoDTensor* lodtensor) { + is_dense_ = false; + lodtensor_ = lodtensor; + } + explicit MixTensor(Tensor* tensor) { + is_dense_ = true; + tensor_ = tensor; + } + bool IsDense() { return is_dense_; } + LoDTensor* GetLoDTensor() { + PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr."); + return lodtensor_; + } + Tensor* GetTensor() { + PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr."); + return tensor_; + } + + private: + bool is_dense_; + LoDTensor* lodtensor_; + Tensor* tensor_; +}; + +// DataFeed is the base virtual class for all ohther DataFeeds. +// It is used to read files and parse the data for subsequent trainer. +// Example: +// DataFeed* reader = +// paddle::framework::DataFeedFactory::CreateDataFeed(data_feed_name); +// reader->Init(data_feed_desc); // data_feed_desc is a protobuf object +// reader->SetFileList(filelist); +// const std::vector & use_slot_alias = +// reader->GetUseSlotAlias(); +// for (auto name: use_slot_alias){ // for binding memory +// reader->AddFeedVar(scope->Var(name), name); +// } +// reader->Start(); +// while (reader->Next()) { +// // trainer do something +// } +class DataFeed { + public: + DataFeed() {} + virtual ~DataFeed() {} + virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0; + virtual bool CheckFile(const char* filename) { + PADDLE_THROW("This function(CheckFile) is not implemented."); + } + // Set filelist for DataFeed. + // Pay attention that it must init all readers before call this function. + // Otherwise, Init() function will init finish_set_filelist_ flag. + virtual bool SetFileList(const std::vector& files); + virtual bool Start() = 0; + // The trainer calls the Next() function, and the DataFeed will load a new + // batch to the feed_vec. The return value of this function is the batch + // size of the current batch. + virtual int Next() = 0; + // Get all slots' alias which defined in protofile + virtual const std::vector& GetAllSlotAlias() { + return all_slots_; + } + // Get used slots' alias which defined in protofile + virtual const std::vector& GetUseSlotAlias() { + return use_slots_; + } + // This function is used for binding feed_vec memory + virtual void AddFeedVar(Variable* var, const std::string& name); + + protected: + // The following three functions are used to check if it is executed in this + // order: + // Init() -> SetFileList() -> Start() -> Next() + virtual void CheckInit(); + virtual void CheckSetFileList(); + virtual void CheckStart(); + virtual void SetBatchSize( + int batch); // batch size will be set in Init() function + // This function is used to pick one file from the global filelist(thread + // safe). + virtual bool PickOneFile(std::string* filename); + + static std::vector filelist_; + static size_t file_idx_; + static std::mutex mutex_for_pick_file_; + + // the alias of used slots, and its order is determined by + // data_feed_desc(proto object) + std::vector use_slots_; + std::vector use_slots_is_dense_; + + // the alias of all slots, and its order is determined by data_feed_desc(proto + // object) + std::vector all_slots_; + std::vector all_slots_type_; + std::vector + use_slots_index_; // -1: not used; >=0: the index of use_slots_ + + // The data read by DataFeed will be stored here + std::vector feed_vec_; + + // the batch size defined by user + int default_batch_size_; + // current batch size + int batch_size_; + + bool finish_init_; + static bool finish_set_filelist_; + bool finish_start_; +}; + +// PrivateQueueDataFeed is the base virtual class for ohther DataFeeds. +// It use a read-thread to read file and parse data to a private-queue +// (thread level), and get data from this queue when trainer call Next(). +template +class PrivateQueueDataFeed : public DataFeed { + public: + PrivateQueueDataFeed() {} + virtual ~PrivateQueueDataFeed() {} + virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc) = 0; + virtual bool Start(); + virtual int Next(); + + protected: + // The thread implementation function for reading file and parse. + virtual void ReadThread(); + // This function is used to set private-queue size, and the most + // efficient when the queue size is close to the batch size. + virtual void SetQueueSize(int queue_size); + // The reading and parsing method called in the ReadThread. + virtual bool ParseOneInstance(T* instance) = 0; + // This function is used to put instance to vec_ins + virtual void AddInstanceToInsVec(T* vec_ins, const T& instance, + int index) = 0; + // This function is used to put ins_vec to feed_vec + virtual void PutToFeedVec(const T& ins_vec) = 0; + + // The thread for read files + std::thread read_thread_; + // using ifstream one line and one line parse is faster + // than using fread one buffer and one buffer parse. + // for a 601M real data: + // ifstream one line and one line parse: 6034 ms + // fread one buffer and one buffer parse: 7097 ms + std::ifstream file_; + size_t queue_size_; + // The queue for store parsed data + std::unique_ptr> queue_; +}; + +// This class define the data type of instance(ins_vec) in MultiSlotDataFeed +class MultiSlotType { + public: + MultiSlotType() {} + ~MultiSlotType() {} + void Init(const std::string& type) { + CheckType(type); + if (type_[0] == 'f') { + float_feasign_.clear(); + } else if (type_[0] == 'u') { + uint64_feasign_.clear(); + } + type_ = type; + } + void InitOffset() { + offset_.resize(1); + // LoDTensor' lod is counted from 0, the size of lod + // is one size larger than the size of data. + offset_[0] = 0; + } + const std::vector& GetOffset() const { return offset_; } + void AddValue(const float v) { + CheckFloat(); + float_feasign_.push_back(v); + } + void AddValue(const uint64_t v) { + CheckUint64(); + uint64_feasign_.push_back(v); + } + void AddIns(const MultiSlotType& ins) { + if (ins.GetType()[0] == 'f') { // float + CheckFloat(); + auto& vec = ins.GetFloatData(); + offset_.push_back(offset_.back() + vec.size()); + float_feasign_.insert(float_feasign_.end(), vec.begin(), vec.end()); + } else if (ins.GetType()[0] == 'u') { // uint64 + CheckUint64(); + auto& vec = ins.GetUint64Data(); + offset_.push_back(offset_.back() + vec.size()); + uint64_feasign_.insert(uint64_feasign_.end(), vec.begin(), vec.end()); + } + } + const std::vector& GetFloatData() const { return float_feasign_; } + const std::vector& GetUint64Data() const { return uint64_feasign_; } + const std::string& GetType() const { return type_; } + + private: + void CheckType(const std::string& type) const { + PADDLE_ENFORCE((type == "uint64") || (type == "float"), + "There is no this type<%s>.", type); + } + void CheckFloat() const { + PADDLE_ENFORCE(type_[0] == 'f', "Add %s value to float slot.", type_); + } + void CheckUint64() const { + PADDLE_ENFORCE(type_[0] == 'u', "Add %s value to uint64 slot.", type_); + } + std::vector float_feasign_; + std::vector uint64_feasign_; + std::string type_; + std::vector offset_; +}; + +// This DataFeed is used to feed multi-slot type data. +// The format of multi-slot type data: +// [n feasign_0 feasign_1 ... feasign_n]* +class MultiSlotDataFeed + : public PrivateQueueDataFeed> { + public: + MultiSlotDataFeed() {} + virtual ~MultiSlotDataFeed() {} + virtual void Init(const paddle::framework::DataFeedDesc& data_feed_desc); + virtual bool CheckFile(const char* filename); + + protected: + virtual void AddInstanceToInsVec(std::vector* vec_ins, + const std::vector& instance, + int index); + virtual bool ParseOneInstance(std::vector* instance); + virtual void PutToFeedVec(const std::vector& ins_vec); +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto new file mode 100644 index 0000000000000000000000000000000000000000..489fec08d86ccf61ece29bbba6d0204f25530b0f --- /dev/null +++ b/paddle/fluid/framework/data_feed.proto @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +syntax = "proto2"; +package paddle.framework; + +message Slot { + required string name = 1; + required string type = 2; + optional bool is_dense = 3 [ default = false ]; + optional bool is_used = 4 [ default = false ]; +} + +message MultiSlotDesc { repeated Slot slots = 1; } + +message DataFeedDesc { + optional string name = 1; + optional int32 batch_size = 2 [ default = 32 ]; + optional MultiSlotDesc multi_slot_desc = 3; +} diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc new file mode 100644 index 0000000000000000000000000000000000000000..72148b9f7d343e19d60bb2be44d8270ad78d1412 --- /dev/null +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_feed_factory.h" +#include +#include +#include + +#include "paddle/fluid/framework/data_feed.h" + +namespace paddle { +namespace framework { +typedef std::shared_ptr (*Createdata_feedFunction)(); +typedef std::unordered_map data_feedMap; +data_feedMap g_data_feed_map; + +#define REGISTER_DATAFEED_CLASS(data_feed_class) \ + namespace { \ + std::shared_ptr Creator_##data_feed_class() { \ + return std::shared_ptr(new data_feed_class); \ + } \ + class __Registerer_##data_feed_class { \ + public: \ + __Registerer_##data_feed_class() { \ + g_data_feed_map[#data_feed_class] = &Creator_##data_feed_class; \ + } \ + }; \ + __Registerer_##data_feed_class g_registerer_##data_feed_class; \ + } // namespace + +std::string DataFeedFactory::DataFeedTypeList() { + std::string data_feed_types; + for (auto iter = g_data_feed_map.begin(); iter != g_data_feed_map.end(); + ++iter) { + if (iter != g_data_feed_map.begin()) { + data_feed_types += ", "; + } + data_feed_types += iter->first; + } + return data_feed_types; +} + +std::shared_ptr DataFeedFactory::CreateDataFeed( + std::string data_feed_class) { + if (g_data_feed_map.count(data_feed_class) < 1) { + exit(-1); + } + return g_data_feed_map[data_feed_class](); +} + +REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_feed_factory.h b/paddle/fluid/framework/data_feed_factory.h new file mode 100644 index 0000000000000000000000000000000000000000..13678edb0b8d084a0b3016d93f6e1bc32ce0169a --- /dev/null +++ b/paddle/fluid/framework/data_feed_factory.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/data_feed.h" + +namespace paddle { +namespace framework { +class DataFeedFactory { + public: + static std::string DataFeedTypeList(); + static std::shared_ptr CreateDataFeed(std::string data_feed_class); +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_feed_test.cc b/paddle/fluid/framework/data_feed_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..3974f8dbadf332801a822618d77f140db440b29d --- /dev/null +++ b/paddle/fluid/framework/data_feed_test.cc @@ -0,0 +1,337 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/data_feed.h" +#include +#include // NOLINT +#include +#include +#include +#include // NOLINT +#include +#include // NOLINT +#include +#include +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" +#include "paddle/fluid/framework/data_feed_factory.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/scope.h" + +paddle::framework::DataFeedDesc load_datafeed_param_from_file( + const char* filename) { + paddle::framework::DataFeedDesc data_feed_desc; + int file_descriptor = open(filename, O_RDONLY); + PADDLE_ENFORCE(file_descriptor != -1, "Can not open %s.", filename); + google::protobuf::io::FileInputStream fileInput(file_descriptor); + google::protobuf::TextFormat::Parse(&fileInput, &data_feed_desc); + close(file_descriptor); + return data_feed_desc; +} + +const std::vector load_filelist_from_file(const char* filename) { + std::vector filelist; + std::ifstream fin(filename); + PADDLE_ENFORCE(fin.good(), "Can not open %s.", filename); + std::string line; + while (getline(fin, line)) { + filelist.push_back(line); + } + fin.close(); + return filelist; +} + +void GenerateFileForTest(const char* protofile, const char* filelist) { + std::ofstream w_protofile(protofile); + w_protofile << "name: \"MultiSlotDataFeed\"\n" + "batch_size: 2\n" + "multi_slot_desc {\n" + " slots {\n" + " name: \"uint64_sparse_slot\"\n" + " type: \"uint64\"\n" + " is_dense: false\n" + " is_used: true\n" + " }\n" + " slots {\n" + " name: \"float_sparse_slot\"\n" + " type: \"float\"\n" + " is_dense: false\n" + " is_used: true\n" + " }\n" + " slots {\n" + " name: \"uint64_dense_slot\"\n" + " type: \"uint64\"\n" + " is_dense: true\n" + " is_used: true\n" + " }\n" + " slots {\n" + " name: \"float_dense_slot\"\n" + " type: \"float\"\n" + " is_dense: true\n" + " is_used: true\n" + " }\n" + " slots {\n" + " name: \"not_used_slot\"\n" + " type: \"uint64\"\n" + " is_dense: false\n" + " is_used: false\n" + " }\n" + "}"; + w_protofile.close(); + std::ofstream w_filelist(filelist); + int total_file = 4; + for (int i = 0; i < total_file; ++i) { + std::string filename = "TestMultiSlotDataFeed.data." + std::to_string(i); + w_filelist << filename; + if (i + 1 != total_file) { + w_filelist << std::endl; + } + std::ofstream w_datafile(filename.c_str()); + w_datafile << "3 3978 620 82 1 1926.08 1 1926 1 6.02 1 1996\n" + "2 1300 2983353 1 985.211 1 8 1 0.618 1 12\n" + "1 19260827 2 3.14 2.718 1 27 1 2.236 1 28\n"; + w_datafile.close(); + } + w_filelist.close(); +} + +class MultiTypeSet { + public: + MultiTypeSet() { + uint64_set_.clear(); + float_set_.clear(); + } + ~MultiTypeSet() {} + void AddValue(uint64_t v) { uint64_set_.insert(v); } + void AddValue(float v) { float_set_.insert(v); } + const std::set& GetUint64Set() const { return uint64_set_; } + const std::set& GetFloatSet() const { return float_set_; } + + private: + std::set uint64_set_; + std::set float_set_; +}; + +void GetElemSetFromReader(std::vector* reader_elem_set, + const paddle::framework::DataFeedDesc& data_feed_desc, + const std::vector& filelist, + const int thread_num) { + int used_slot_num = 0; + for (auto i = 0; i < data_feed_desc.multi_slot_desc().slots_size(); ++i) { + if (data_feed_desc.multi_slot_desc().slots(i).is_used()) { + ++used_slot_num; + } + } + reader_elem_set->resize(used_slot_num); + std::vector threads; + std::vector> readers; + readers.resize(thread_num); + for (int i = 0; i < thread_num; ++i) { + readers[i] = paddle::framework::DataFeedFactory::CreateDataFeed( + data_feed_desc.name()); + readers[i]->Init(data_feed_desc); + } + readers[0]->SetFileList(filelist); + std::mutex mu; + for (int idx = 0; idx < thread_num; ++idx) { + threads.emplace_back(std::thread([&, idx] { + std::unique_ptr scope( + new paddle::framework::Scope()); + const auto& multi_slot_desc = data_feed_desc.multi_slot_desc(); + std::map + lodtensor_targets; + std::map tensor_targets; + for (int i = 0; i < multi_slot_desc.slots_size(); ++i) { + const auto& slot = multi_slot_desc.slots(i); + if (slot.is_used()) { + const auto& name = slot.name(); + readers[idx]->AddFeedVar(scope->Var(name), name); + if (slot.is_dense()) { + tensor_targets[name] = + &scope->FindVar(name)->Get(); + } else { + lodtensor_targets[name] = + &scope->FindVar(name)->Get(); + } + } + } + readers[idx]->Start(); + while (readers[idx]->Next()) { + int index = 0; + for (int k = 0; k < multi_slot_desc.slots_size(); ++k) { + const auto& slot = multi_slot_desc.slots(k); + if (!slot.is_used()) { + continue; + } + if (slot.is_dense()) { // dense branch + const paddle::framework::Tensor* tens = tensor_targets[slot.name()]; + if (slot.type() == "uint64") { + const int64_t* data = tens->data(); + int batch_size = tens->dims()[0]; + int dim = tens->dims()[1]; + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < dim; ++j) { + std::lock_guard lock(mu); + (*reader_elem_set)[index].AddValue( + (uint64_t)data[i * dim + j]); + } + } + } else if (slot.type() == "float") { + const float* data = tens->data(); + int batch_size = tens->dims()[0]; + int dim = tens->dims()[1]; + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < dim; ++j) { + std::lock_guard lock(mu); + (*reader_elem_set)[index].AddValue(data[i * dim + j]); + } + } + } else { + PADDLE_THROW("Error type in proto file."); + } + } else { // sparse branch + const paddle::framework::LoDTensor* tens = + lodtensor_targets[slot.name()]; + if (slot.type() == "uint64") { + const int64_t* data = tens->data(); + for (size_t i = 0; i < tens->NumElements(); ++i) { + std::pair element = tens->lod_element(0, i); + for (size_t j = element.first; j < element.second; ++j) { + std::lock_guard lock(mu); + (*reader_elem_set)[index].AddValue((uint64_t)data[j]); + } + } + } else if (slot.type() == "float") { + const float* data = tens->data(); + for (size_t i = 0; i < tens->NumElements(); ++i) { + std::pair element = tens->lod_element(0, i); + for (size_t j = element.first; j < element.second; ++j) { + std::lock_guard lock(mu); + (*reader_elem_set)[index].AddValue(data[j]); + } + } + } else { + PADDLE_THROW("Error type in proto file."); + } + } // end sparse branch + ++index; + } // end slots loop + } // end while Next() + })); // end anonymous function + } + for (auto& th : threads) { + th.join(); + } +} + +void CheckIsUnorderedSame(const std::vector& s1, + const std::vector& s2) { + EXPECT_EQ(s1.size(), s2.size()); + for (size_t i = 0; i < s1.size(); ++i) { + // check for uint64 + const std::set& uint64_s1 = s1[i].GetUint64Set(); + const std::set& uint64_s2 = s2[i].GetUint64Set(); + EXPECT_EQ(uint64_s1.size(), uint64_s2.size()); + auto uint64_it1 = uint64_s1.begin(); + auto uint64_it2 = uint64_s2.begin(); + while (uint64_it1 != uint64_s1.end()) { + EXPECT_EQ(*uint64_it1, *uint64_it2); + ++uint64_it1; + ++uint64_it2; + } + // check for float + const std::set& float_s1 = s1[i].GetFloatSet(); + const std::set& float_s2 = s2[i].GetFloatSet(); + EXPECT_EQ(float_s1.size(), float_s2.size()); + auto float_it1 = float_s1.begin(); + auto float_it2 = float_s2.begin(); + while (float_it1 != float_s1.end()) { + EXPECT_EQ(*float_it1, *float_it2); + ++float_it1; + ++float_it2; + } + } +} + +void GetElemSetFromFile(std::vector* file_elem_set, + const paddle::framework::DataFeedDesc& data_feed_desc, + const std::vector& filelist) { + int used_slot_num = 0; + for (auto i = 0; i < data_feed_desc.multi_slot_desc().slots_size(); ++i) { + if (data_feed_desc.multi_slot_desc().slots(i).is_used()) { + ++used_slot_num; + } + } + file_elem_set->resize(used_slot_num); + for (const auto& file : filelist) { + std::ifstream fin(file.c_str()); + PADDLE_ENFORCE(fin.good(), "Can not open %s.", file.c_str()); + while (1) { + bool end_flag = false; + int index = 0; + for (auto i = 0; i < data_feed_desc.multi_slot_desc().slots_size(); ++i) { + int num; + if (fin >> num) { + auto slot = data_feed_desc.multi_slot_desc().slots(i); + auto type = slot.type(); + if (type == "uint64") { + while (num--) { + uint64_t feasign; + fin >> feasign; + if (slot.is_used()) { + (*file_elem_set)[index].AddValue(feasign); + } + } + } else if (type == "float") { + while (num--) { + float feasign; + fin >> feasign; + if (slot.is_used()) { + (*file_elem_set)[index].AddValue(feasign); + } + } + } else { + PADDLE_THROW("Error type in proto file."); + } + if (slot.is_used()) { + ++index; + } + } else { + end_flag = true; + break; + } + } + if (end_flag) { + break; + } + } + fin.close(); + } +} + +TEST(DataFeed, MultiSlotUnitTest) { + const char* protofile = "data_feed_desc.prototxt"; + const char* filelist_name = "filelist.txt"; + GenerateFileForTest(protofile, filelist_name); + const std::vector filelist = + load_filelist_from_file(filelist_name); + paddle::framework::DataFeedDesc data_feed_desc = + load_datafeed_param_from_file(protofile); + std::vector reader_elem_set; + std::vector file_elem_set; + GetElemSetFromReader(&reader_elem_set, data_feed_desc, filelist, 4); + GetElemSetFromFile(&file_elem_set, data_feed_desc, filelist); + CheckIsUnorderedSame(reader_elem_set, file_elem_set); +} diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index e5b1eaa7318aecde1dbf89de8fe242a3008db97c..499246a9856bb3ba67a155c6f00c3ad06af50edf 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -16,7 +16,7 @@ #include #include #include -#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/profiler.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/framework/details/reference_count_op_handle.h" diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 96132a2c18233ca10d7bad4e26dfabadd39d84db..73cec21e20f2fd26e144872f1f7b5bb7065adb74 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/transfer_scope_cache.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/detail/macros.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" @@ -114,36 +115,6 @@ void Executor::Close() { #endif } -void InitializeVariable(Variable* var, proto::VarType::Type var_type) { - if (var_type == proto::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == proto::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else if (var_type == proto::VarType::FEED_MINIBATCH) { - var->GetMutable(); - } else if (var_type == proto::VarType::FETCH_LIST) { - var->GetMutable(); - } else if (var_type == proto::VarType::STEP_SCOPES) { - var->GetMutable>(); - } else if (var_type == proto::VarType::LOD_RANK_TABLE) { - var->GetMutable(); - } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { - var->GetMutable(); - } else if (var_type == proto::VarType::PLACE_LIST) { - var->GetMutable(); - } else if (var_type == proto::VarType::READER) { - var->GetMutable(); - } else if (var_type == proto::VarType::RAW) { - // GetMutable will be called in operator - } else { - PADDLE_THROW( - "Variable type %d is not in " - "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " - "LOD_RANK_TABLE, PLACE_LIST, READER, RAW]", - var_type); - } -} - void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id) { auto& global_block = pdesc.Block(block_id); diff --git a/paddle/fluid/framework/executor.h b/paddle/fluid/framework/executor.h index 36b36d49c2728dbef93042158dffa26d8f56d529..2d47903ffbd8d821b7c31386b225fe5e65ca2720 100644 --- a/paddle/fluid/framework/executor.h +++ b/paddle/fluid/framework/executor.h @@ -26,7 +26,6 @@ limitations under the License. */ namespace paddle { namespace framework { -extern void InitializeVariable(Variable* var, proto::VarType::Type var_type); template std::unordered_map GetNonPersistableReferenceCount( diff --git a/paddle/fluid/framework/executor_thread_worker.cc b/paddle/fluid/framework/executor_thread_worker.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e4001e979fdd0774779fa288402c7847af90637 --- /dev/null +++ b/paddle/fluid/framework/executor_thread_worker.cc @@ -0,0 +1,223 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/executor_thread_worker.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/message.h" +#include "google/protobuf/text_format.h" + +#include "gflags/gflags.h" +#include "paddle/fluid/framework/feed_fetch_method.h" +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/variable_helper.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/pybind/pybind.h" +namespace paddle { +namespace framework { + +void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) { + auto& block = program.Block(0); + op_names_.clear(); + for (auto& op_desc : block.AllOps()) { + std::unique_ptr local_op = OpRegistry::CreateOp(*op_desc); + op_names_.push_back(op_desc->Type()); + OperatorBase* local_op_ptr = local_op.release(); + ops_.push_back(local_op_ptr); + continue; + } +} + +void ExecutorThreadWorker::CreateThreadResource( + const framework::ProgramDesc& program, + const paddle::platform::Place& place) { + CreateThreadScope(program); + CreateThreadOperators(program); + SetMainProgram(program); + SetPlace(place); +} + +void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) { + auto& block = program.Block(0); + + PADDLE_ENFORCE_NOT_NULL( + root_scope_, "root_scope should be set before creating thread scope"); + + thread_scope_ = &root_scope_->NewScope(); + for (auto& var : block.AllVars()) { + if (var->Persistable()) { + auto* ptr = root_scope_->Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + } else { + auto* ptr = thread_scope_->Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + } + } +} + +void ExecutorThreadWorker::SetDataFeed( + const std::shared_ptr& datafeed) { + thread_reader_ = datafeed; +} + +void ExecutorThreadWorker::BindingDataFeedMemory() { + const std::vector& input_feed = + thread_reader_->GetUseSlotAlias(); + for (auto name : input_feed) { + thread_reader_->AddFeedVar(thread_scope_->Var(name), name); + } +} + +void ExecutorThreadWorker::SetFetchVarNames( + const std::vector& fetch_var_names) { + fetch_var_names_.clear(); + fetch_var_names_.insert(fetch_var_names_.end(), fetch_var_names.begin(), + fetch_var_names.end()); +} + +void ExecutorThreadWorker::SetDevice() { +#if defined _WIN32 || defined __APPLE__ + return; +#else + static unsigned concurrency_cap = std::thread::hardware_concurrency(); + int thread_id = this->thread_id_; + + if (thread_id < concurrency_cap) { + unsigned proc = thread_id; + + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(proc, &mask); + + if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) { + VLOG(1) << "WARNING: Failed to set thread affinity for thread " + << thread_id; + } else { + CPU_ZERO(&mask); + if ((0 != sched_getaffinity(0, sizeof(mask), &mask)) || + (CPU_ISSET(proc, &mask) == 0)) { + VLOG(3) << "WARNING: Failed to set thread affinity for thread " + << thread_id; + } + } + } else { + VLOG(1) << "WARNING: Failed to set thread affinity for thread " + << thread_id; + } +#endif +} + +template +void print_lod_tensor(std::string var_name, const LoDTensor& lod_tensor) { + auto inspect = lod_tensor.data(); + auto element_num = lod_tensor.numel(); + + std::ostringstream sstream; + sstream << var_name << " (element num " << element_num << "): ["; + sstream << inspect[0]; + for (int j = 1; j < element_num; ++j) { + sstream << " " << inspect[j]; + } + sstream << "]"; + + std::cout << sstream.str() << std::endl; +} + +void print_fetch_var(Scope* scope, std::string var_name) { + const LoDTensor& tensor = scope->FindVar(var_name)->Get(); + + if (std::type_index(tensor.type()) == + std::type_index(typeid(platform::float16))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == std::type_index(typeid(float))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == + std::type_index(typeid(double))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == std::type_index(typeid(int))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == + std::type_index(typeid(int64_t))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == std::type_index(typeid(bool))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == + std::type_index(typeid(uint8_t))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == + std::type_index(typeid(int16_t))) { + print_lod_tensor(var_name, tensor); + } else if (std::type_index(tensor.type()) == + std::type_index(typeid(int8_t))) { + print_lod_tensor(var_name, tensor); + } else { + VLOG(1) << "print_fetch_var: unrecognized data type:" + << tensor.type().name(); + } + + return; +} + +void ExecutorThreadWorker::TrainFiles() { + // todo: configurable + SetDevice(); + + int fetch_var_num = fetch_var_names_.size(); + fetch_values_.clear(); + fetch_values_.resize(fetch_var_num); + + thread_reader_->Start(); + + int cur_batch; + int batch_cnt = 0; + while ((cur_batch = thread_reader_->Next()) > 0) { + // executor run here + for (auto& op : ops_) { + op->Run(*thread_scope_, place_); + } + + ++batch_cnt; + thread_scope_->DropKids(); + + if (debug_ == false || thread_id_ != 0) { + continue; + } + + for (int i = 0; i < fetch_var_num; ++i) { + print_fetch_var(thread_scope_, fetch_var_names_[i]); + } // end for (int i = 0...) + } // end while () +} + +void ExecutorThreadWorker::SetThreadId(int tid) { thread_id_ = tid; } + +void ExecutorThreadWorker::SetPlace(const platform::Place& place) { + place_ = place; +} + +void ExecutorThreadWorker::SetMainProgram( + const ProgramDesc& main_program_desc) { + main_program_.reset(new ProgramDesc(main_program_desc)); +} + +void ExecutorThreadWorker::SetRootScope(Scope* g_scope) { + root_scope_ = g_scope; +} + +} // einit_modelnd namespace framework +} // end namespace paddle diff --git a/paddle/fluid/framework/executor_thread_worker.h b/paddle/fluid/framework/executor_thread_worker.h new file mode 100644 index 0000000000000000000000000000000000000000..13ec2442c46459116320236bf98f23c91340f389 --- /dev/null +++ b/paddle/fluid/framework/executor_thread_worker.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include // NOLINT +#include +#include +#include // NOLINT +#include +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/scope.h" + +namespace paddle { +namespace framework { +void CreateTensor(Variable* var, proto::VarType::Type var_type); + +class ExecutorThreadWorker { + public: + ExecutorThreadWorker() + : thread_id_(-1), root_scope_(NULL), thread_scope_(NULL), debug_(false) {} + ~ExecutorThreadWorker() {} + + void CreateThreadResource(const framework::ProgramDesc& program, + const paddle::platform::Place& place); + void SetThreadId(int tid); + void SetDebug(const bool debug) { debug_ = debug; } + void SetRootScope(Scope* g_scope); + // set cpu device in this function + // cpu binding is used by default + void SetDevice(); + // since we read data into memory that can not be accessed by program + // we need to bind memory of data with corresponding variables in program + // this function should be called after data feed is set + void BindingDataFeedMemory(); + // set data feed declared in executor + void SetDataFeed(const std::shared_ptr& datafeed); + // A multi-thread training function + void TrainFiles(); + // set fetch variable names from python interface assigned by users + void SetFetchVarNames(const std::vector& fetch_var_names); + + private: + void CreateThreadScope(const framework::ProgramDesc& program); + void CreateThreadOperators(const framework::ProgramDesc& program); + void SetMainProgram(const ProgramDesc& main_program_desc); + void SetPlace(const paddle::platform::Place& place); + + protected: + // thread index + std::shared_ptr thread_reader_; // shared queue, thread buffer + int thread_id_; + // operator name + std::vector op_names_; + // thread level, local operators for forward and backward + std::vector ops_; + // main program for training + std::unique_ptr main_program_; + // execution place + platform::Place place_; + // root scope for model parameters + Scope* root_scope_; + // a thread scope, father scope is global score which is shared + Scope* thread_scope_; + + private: + std::vector fetch_var_names_; + std::vector> fetch_values_; + bool debug_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index e8295639520b5838dce3c9c9e443cc846bd9c1ec..f1642bc0d2b10f97295e80ee201db8f83bfd06ef 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -21,42 +21,11 @@ #include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace framework { - -// These code can be shared with Executor. -static void InitializeVariable(Variable *var, proto::VarType::Type var_type) { - if (var_type == proto::VarType::LOD_TENSOR) { - var->GetMutable(); - } else if (var_type == proto::VarType::SELECTED_ROWS) { - var->GetMutable(); - } else if (var_type == proto::VarType::FEED_MINIBATCH) { - var->GetMutable(); - } else if (var_type == proto::VarType::FETCH_LIST) { - var->GetMutable(); - } else if (var_type == proto::VarType::STEP_SCOPES) { - var->GetMutable>(); - } else if (var_type == proto::VarType::LOD_RANK_TABLE) { - var->GetMutable(); - } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { - var->GetMutable(); - } else if (var_type == proto::VarType::PLACE_LIST) { - var->GetMutable(); - } else if (var_type == proto::VarType::READER) { - var->GetMutable(); - } else if (var_type == proto::VarType::RAW) { - // GetMutable will be called in operator - } else { - PADDLE_THROW( - "Variable type %d is not in " - "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " - "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]", - var_type); - } -} - void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc, int block_id, bool with_feed_fetch_ops) { if (!scope) { diff --git a/paddle/fluid/framework/variable_helper.cc b/paddle/fluid/framework/variable_helper.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc4525549caeebb06dea766ccb123b5ebc6d5b13 --- /dev/null +++ b/paddle/fluid/framework/variable_helper.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/variable_helper.h" + +#include + +#include "paddle/fluid/framework/feed_fetch_type.h" +#include "paddle/fluid/framework/lod_rank_table.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/lod_tensor_array.h" +#include "paddle/fluid/framework/reader.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +void InitializeVariable(Variable* var, proto::VarType::Type var_type) { + if (var_type == proto::VarType::LOD_TENSOR) { + var->GetMutable(); + } else if (var_type == proto::VarType::SELECTED_ROWS) { + var->GetMutable(); + } else if (var_type == proto::VarType::FEED_MINIBATCH) { + var->GetMutable(); + } else if (var_type == proto::VarType::FETCH_LIST) { + var->GetMutable(); + } else if (var_type == proto::VarType::STEP_SCOPES) { + var->GetMutable>(); + } else if (var_type == proto::VarType::LOD_RANK_TABLE) { + var->GetMutable(); + } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) { + var->GetMutable(); + } else if (var_type == proto::VarType::PLACE_LIST) { + var->GetMutable(); + } else if (var_type == proto::VarType::READER) { + var->GetMutable(); + } else if (var_type == proto::VarType::RAW) { + // GetMutable will be called in operator + } else { + PADDLE_THROW( + "Variable type %d is not in " + "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " + "LOD_RANK_TABLE, PLACE_LIST, READER, RAW]", + var_type); + } +} +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/variable_helper.h b/paddle/fluid/framework/variable_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..0e0c72c3621dce0a6b372f9a9110a63fbc0a1d71 --- /dev/null +++ b/paddle/fluid/framework/variable_helper.h @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/framework/variable.h" +namespace paddle { +namespace framework { +void InitializeVariable(Variable *var, proto::VarType::Type var_type); +} +} diff --git a/paddle/fluid/operators/distributed/request_handler_impl.cc b/paddle/fluid/operators/distributed/request_handler_impl.cc index 0258f8f2384669c8a7466fd3c60f9b55a0fde9fd..9722f8c96e91d2dfbe929dcc11645a40c44afb4e 100644 --- a/paddle/fluid/operators/distributed/request_handler_impl.cc +++ b/paddle/fluid/operators/distributed/request_handler_impl.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/operators/distributed/request_handler_impl.h" #include #include #include @@ -20,7 +21,7 @@ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" -#include "paddle/fluid/operators/distributed/request_handler_impl.h" +#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/distributed/rpc_server.h" #include "paddle/fluid/string/printf.h" diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 25d241d9768c16e1da304a78f259d5a626f702fc..d602613fc82223e14f48830a87533880696eb550 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ -set(PYBIND_DEPS pybind python proto_desc memory executor prune feed_fetch_method pass_builder parallel_executor profiler) -set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc) +set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler) +set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc) if(WITH_PYTHON) if(WITH_AMD_GPU) hip_library(paddle_pybind SHARED diff --git a/paddle/fluid/pybind/async_executor_py.cc b/paddle/fluid/pybind/async_executor_py.cc new file mode 100644 index 0000000000000000000000000000000000000000..470e8b050808295d49728bbdb757b6a612df9a01 --- /dev/null +++ b/paddle/fluid/pybind/async_executor_py.cc @@ -0,0 +1,53 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3) +#ifdef _POSIX_C_SOURCE +#undef _POSIX_C_SOURCE +#endif + +#ifdef _XOPEN_SOURCE +#undef _XOPEN_SOURCE +#endif +#include +#include + +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" +#include "paddle/fluid/framework/async_executor.h" +#include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/data_feed.pb.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/inference/io.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/fluid/platform/variant.h" +#include "paddle/fluid/pybind/async_executor_py.h" + +namespace py = pybind11; +namespace pd = paddle::framework; + +namespace paddle { +namespace pybind { +using set_name_func = void (pd::DataFeedDesc::*)(const std::string&); +void BindAsyncExecutor(py::module* m) { + py::class_(*m, "AsyncExecutor") + .def(py::init([](framework::Scope* scope, const platform::Place& place) { + return std::unique_ptr( + new framework::AsyncExecutor(scope, place)); + })) + .def("run_from_files", &framework::AsyncExecutor::RunFromFile); +} // end BindAsyncExecutor +} // end namespace pybind +} // end namespace paddle diff --git a/paddle/fluid/pybind/async_executor_py.h b/paddle/fluid/pybind/async_executor_py.h new file mode 100644 index 0000000000000000000000000000000000000000..a99d6e04218c9310ede00de7d9bdfc015889bd22 --- /dev/null +++ b/paddle/fluid/pybind/async_executor_py.h @@ -0,0 +1,28 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" + +namespace py = pybind11; + +namespace paddle { +namespace pybind { + +void BindAsyncExecutor(py::module* m); + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 1835c064055635a4284fc64f4ca4dd8728f933ca..fc7991d2974c9262e6225de1537025944c1068c1 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -42,6 +42,7 @@ limitations under the License. */ #include "paddle/fluid/platform/init.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/pybind/async_executor_py.h" #include "paddle/fluid/pybind/const_value.h" #include "paddle/fluid/pybind/exception.h" #include "paddle/fluid/pybind/protobuf.h" @@ -932,6 +933,7 @@ All parameter, weight, gradient are variables in Paddle. }); BindRecordIOWriter(&m); + BindAsyncExecutor(&m); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/fluid/__init__.py b/python/paddle/fluid/__init__.py index f7fefb3e5b767e25373665058d4fd6a298fb3d60..a1ffbf42622bcda44ec038edf20811fb0032891f 100644 --- a/python/paddle/fluid/__init__.py +++ b/python/paddle/fluid/__init__.py @@ -20,6 +20,13 @@ from .framework import * # import all class inside executor into fluid module from . import executor from .executor import * + +from . import data_feed_desc +from .data_feed_desc import * + +from . import async_executor +from .async_executor import * + from . import trainer from . import inferencer @@ -54,7 +61,8 @@ Tensor = LoDTensor __all__ = framework.__all__ + executor.__all__ + \ trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \ - parallel_executor.__all__ + lod_tensor.__all__ + [ + parallel_executor.__all__ + lod_tensor.__all__ + \ + data_feed_desc.__all__ + async_executor.__all__ + [ 'io', 'initializer', 'layers', diff --git a/python/paddle/fluid/async_executor.py b/python/paddle/fluid/async_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..2664a7301db3bf471126ff26504e7042f02b7d84 --- /dev/null +++ b/python/paddle/fluid/async_executor.py @@ -0,0 +1,151 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import numpy as np +import contextlib +import six +from .framework import Program, default_main_program, Variable +from . import core +from .executor import global_scope, Executor +from paddle.fluid.proto import data_feed_pb2 +from google.protobuf import text_format +from . import io +from .data_feed_desc import DataFeedDesc + +__all__ = ['AsyncExecutor'] + + +class AsyncExecutor(object): + """ + An asynchronous Executor in Python. Through exploiting the power of + multi-core processor and data queueing, AsyncExecutor makes data reading + and cosuming decoupled, each run in multiple threads in parallel. + + Instead of reading data in python side, AsyncExecutor accepts a training + file list, which will be retrieved in C++, then training inputs will be + read, parsed and fed to training network within C++ code. + + AsyncExecutor is in active development and the API might change in the near + future. + + Example: + >>> data_feed = fluid.DataFeedDesc('data.proto') + >>> startup_program = fluid.default_startup_program() + >>> main_program = fluid.default_main_program() + >>> filelist = ["train_data/part-%d" % i for i in range(100)] + >>> thread_num = len(filelist) / 4 + >>> + >>> place = fluid.CPUPlace() + >>> async_executor = fluid.AsyncExecutor(place) + >>> + >>> async_executor.run_startup_program(startup_program) + >>> + >>> epoch = 10 + >>> for i in range(epoch): + >>> async_executor.run(main_program, + >>> data_feed, + >>> filelist, + >>> thread_num, + >>> [acc], + >>> debug=False) + + Args: + place(fluid.CPUPlace|None): indicate the executor run on which device. + Only CPUPlace supported + + Note: + For debugging complicated network in parallel-GPUs, you can test it + on the executor. They has the exactly same arguments, and expected + the same results. + + Note: Only running on CPUPlace supported. + """ + + def __init__(self, place=None): + if place is None: + place = core.CPUPlace() + if not isinstance(place, core.CPUPlace): + raise ValueError("AsyncExecutor only supports CPU device") + + p = core.Place() + p.set_place(place) + + scope = global_scope() + self.executor = core.AsyncExecutor(scope, p) + + def run(self, program, data_feed, filelist, thread_num, fetch, debug=False): + """ + Run program by this AsyncExecutor. Training dataset will be in filelist. + Users can also inspect certain variables by naming them in parameter + :code:`fetch`, like in fluid.Executor. Unlike fluid.Executor, however, + AsyncExecutor doesn't return fetched variables, instead, it will dump + the values of each fetched variable to stdandard output. + + Running the dataset will be on multiple threads, within each a thread + local scope will be created, then all OPs also created in that scope. + Parameters are updated by all the OPs simultaneously. + + Args: + program(Program): the program that need to run, if not provied, + then default_main_program will be used. + data_feed(DataFeedDesc): A DataFeedDesc object + filelist(str): a file containing the training dataset file list + thread_num(int): number of concurrent training threads. See + :code:`Note` for how to set this properly + fetch(str|list): the var name or a list of var names to inspect + debug(bool): When set to True, fetch vars will be printed to + standard output after each minibatch + + Note: + the executor will run all operators in the program but not only + the operators dependent by the fetch_list. + + Note: + Running AsyncExecutor will be on multiple threads, each bound to a + CPU core. To achieve best performance, it's suggested to set thread + num to be equal or slightly less than that of CPU cores. + """ + if program is None: + program = default_main_program() + program_desc = program.desc + + if data_feed is None: + raise ValueError('ValueError: data_feed should be provided') + + if filelist is None: + raise ValueError('ValueError: filelist should be provided') + + if isinstance(filelist, str): + filelist = [filelist] + + if not isinstance(thread_num, int): + raise TypeError('TypeError: thread_num should be a positive number') + + if fetch is not None: + if isinstance(fetch, Variable): + fetch = [fetch] + fetch_var_names = [var.name for var in fetch] + for fetch_var in fetch: + shape = fetch_var.shape + if shape[len(shape) - 1] != 1: + raise AssertionError( + "%s: Fetch variable has wrong shape. Only varibles " + "with the last dimension size 1 supported." % + (fetch_var.name)) + + self.executor.run_from_files(program_desc, + data_feed.desc(), filelist, thread_num, + fetch_var_names, debug) diff --git a/python/paddle/fluid/data_feed_desc.py b/python/paddle/fluid/data_feed_desc.py new file mode 100644 index 0000000000000000000000000000000000000000..d2ec74d6cfdeb34c1f48c086a3aa30d5100c3efb --- /dev/null +++ b/python/paddle/fluid/data_feed_desc.py @@ -0,0 +1,152 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.fluid.proto import data_feed_pb2 +from google.protobuf import text_format + +__all__ = ['DataFeedDesc'] + + +class DataFeedDesc(object): + """ + Datafeed descriptor, describing input training data format. This class is + currently only used for AsyncExecutor (See comments for class AsyncExecutor + for a brief introduction) + + DataFeedDesc shall be initialized from a valid protobuf message from disk: + >>> data_feed = fluid.DataFeedDesc('data.proto') + + See :code:`paddle/fluid/framework/data_feed.proto` for message definition. + A typical message might look like: + + >>> name: "MultiSlotDataFeed" + >>> batch_size: 2 + >>> multi_slot_desc { + >>> slots { + >>> name: "words" + >>> type: "uint64" + >>> is_dense: false + >>> is_used: true + >>> } + >>> slots { + >>> name: "label" + >>> type: "uint64" + >>> is_dense: false + >>> is_used: true + >>> } + >>> } + + However, users usually shouldn't care about the message format; instead, + they are encouragd to use :code:`Data Generator` as a tool to generate a + valid data description, in the process of converting their raw log files to + training files acceptable to AsyncExecutor. + + DataFeedDesc can also be changed during runtime. Once you got familiar with + what each field mean, you can modify it to better suit your need. E.g.: + >>> data_feed.set_batch_size(128) + >>> data_feed.set_dense_slots('wd') # The slot named 'wd' will be dense + >>> data_feed.set_use_slots('wd') # The slot named 'wd' will be used + + Finally, the content can be dumped out for debugging purpose: + >>> print(data_feed.desc()) + + Args: + proto_file(string): Disk file containing a data feed description. + + """ + + def __init__(self, proto_file): + self.proto_desc = data_feed_pb2.DataFeedDesc() + with open(proto_file, 'r') as f: + text_format.Parse(f.read(), self.proto_desc) + if self.proto_desc.name == "MultiSlotDataFeed": + self.__name_to_index = { + slot.name: i + for i, slot in enumerate(self.proto_desc.multi_slot_desc.slots) + } + + def set_batch_size(self, batch_size): + """ + Set batch size. Will be effective during training + + Example: + >>> data_feed = fluid.DataFeedDesc('data.proto') + >>> data_feed.set_batch_size(128) + + Args: + batch_size: batch size + + """ + self.proto_desc.batch_size = batch_size + + def set_dense_slots(self, dense_slots_name): + """ + Set if a specific slot will be dense. Will be effective during training. + features for a dense slot will be fed into a Tensor, while those for a + sparse slot will be fed into a LoDTensor + + Example: + >>> data_feed = fluid.DataFeedDesc('data.proto') + >>> data_feed.set_dense_slots(['words']) + + Args: + dense_slots_name: a list of slot names which will be set dense + + Note: + Default is sparse for all slots + """ + if self.proto_desc.name != "MultiSlotDataFeed": + raise ValueError( + "Only MultiSlotDataFeed need set_dense_slots, pls check your datafeed.proto" + ) + for name in dense_slots_name: + self.proto_desc.multi_slot_desc.slots[self.__name_to_index[ + name]].is_dense = True + + def set_use_slots(self, use_slots_name): + """ + Set if a specific slot will be used for training. A dataset shall + contain a lot of features, through this function one can select which + ones will be used for a specific model. + + Example: + >>> data_feed = fluid.DataFeedDesc('data.proto') + >>> data_feed.set_use_slots(['words']) + + Args: + use_slots_name: a list of slot names which will be used in training + + Note: + Default is not used for all slots + """ + if self.proto_desc.name != "MultiSlotDataFeed": + raise ValueError( + "Only MultiSlotDataFeed need set_use_slots, pls check your datafeed.proto" + ) + for name in use_slots_name: + self.proto_desc.multi_slot_desc.slots[self.__name_to_index[ + name]].is_used = True + + def desc(self): + """ + Returns a protobuf message for this DataFeedDesc + + Example: + >>> data_feed = fluid.DataFeedDesc('data.proto') + >>> print(data_feed.desc()) + + Returns: + A string message + """ + return text_format.MessageToString(self.proto_desc) diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index 288951cd7cd32155f136125fb817c35dd2ec6444..42c2484b284844a1f1acf53f79296e13da72676a 100644 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -278,6 +278,7 @@ class Executor(object): p = core.Place() p.set_place(place) self.executor = core.Executor(p) + self.program_caches = dict() self._closed = False diff --git a/python/paddle/fluid/tests/demo/async_executor.py b/python/paddle/fluid/tests/demo/async_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..fe8da0aab74bd5fc6219666236a04423a6d60489 --- /dev/null +++ b/python/paddle/fluid/tests/demo/async_executor.py @@ -0,0 +1,100 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tarfile +import paddle.fluid as fluid +import paddle +from paddle.fluid import core + +URL = 'http://paddle-unittest-data.gz.bcebos.com/python_paddle_fluid_tests_demo_async-executor/train_data.tar.gz' +MD5 = '2a405a31508969b3ab823f42c0f522ca' + + +def bow_net(data, + label, + dict_dim=89528, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + BOW net + This model is from https://github.com/PaddlePaddle/models: + models/fluid/PaddleNLP/text_classification/nets.py + """ + # embedding + emb = fluid.layers.embedding( + input=data, size=[dict_dim, emb_dim], is_sparse=True) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bowh = fluid.layers.tanh(bow) + # fc layer after conv + fc_1 = fluid.layers.fc(input=bowh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + # probability of each class + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + # cross entropy loss + cost = fluid.layers.cross_entropy(input=prediction, label=label) + # mean loss + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + return avg_cost, acc, prediction + + +def train(): + # Download data + with tarfile.open(paddle.dataset.common.download(URL, "imdb", MD5)) as tarf: + tarf.extractall(path='./') + tarf.close() + + # Initialize dataset description + dataset = fluid.DataFeedDesc('train_data/data.prototxt') + dataset.set_batch_size(128) # See API doc for how to change other fields + print dataset.desc() # Debug purpose: see what we get + + # define network + # input text data + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + # label data + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost, acc, prediction = bow_net(data, label) + sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002) + opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost) + + # Run startup program + startup_program = fluid.default_startup_program() + place = fluid.CPUPlace() + executor = fluid.Executor(place) + executor.run(startup_program) + + async_executor = fluid.AsyncExecutor(place) + main_program = fluid.default_main_program() + epochs = 10 + filelist = ["train_data/part-%d" % i for i in range(12)] + for i in range(epochs): + thread_num = 4 + async_executor.run( + main_program, # This can be changed during iteration + dataset, # This can be changed during iteration + filelist, # This can be changed during iteration + thread_num, # This can be changed during iteration + [data, acc], # Multiple fetch targets can be specified + debug=False) + fluid.io.save_inference_model('imdb/epoch%d.model' % i, + [data.name, label.name], [acc], executor) + + +if __name__ == "__main__": + train() diff --git a/python/paddle/fluid/tests/unittests/test_async_executor.py b/python/paddle/fluid/tests/unittests/test_async_executor.py new file mode 100644 index 0000000000000000000000000000000000000000..43855b95f9e3096d58ca3e8acfdb25f034bab175 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_async_executor.py @@ -0,0 +1,142 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import paddle +import unittest +import tarfile +import os +import shutil + +proto_str = ('name: "MultiSlotDataFeed"\n' + 'batch_size: 2\n' + 'multi_slot_desc {\n' + ' slots {\n' + ' name: "words"\n' + ' type: "uint64"\n' + ' is_dense: false\n' + ' is_used: true\n' + ' }\n' + ' slots {\n' + ' name: "label"\n' + ' type: "uint64"\n' + ' is_dense: false\n' + ' is_used: true\n' + ' }\n' + '}') + +URL = 'http://paddle-unittest-data.gz.bcebos.com/python_paddle_fluid_tests_demo_async-executor/train_data.tar.gz' +MD5 = '2a405a31508969b3ab823f42c0f522ca' + + +def bow_net(data, + label, + dict_dim=89528, + emb_dim=128, + hid_dim=128, + hid_dim2=96, + class_dim=2): + """ + BOW net + This model is from https://github.com/PaddlePaddle/models: + models/fluid/PaddleNLP/text_classification/nets.py + """ + # embedding + emb = fluid.layers.embedding( + input=data, size=[dict_dim, emb_dim], is_sparse=True) + bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') + bowh = fluid.layers.tanh(bow) + # fc layer after conv + fc_1 = fluid.layers.fc(input=bowh, size=hid_dim, act="tanh") + fc_2 = fluid.layers.fc(input=fc_1, size=hid_dim2, act="tanh") + # probability of each class + prediction = fluid.layers.fc(input=[fc_2], size=class_dim, act="softmax") + # cross entropy loss + cost = fluid.layers.cross_entropy(input=prediction, label=label) + # mean loss + avg_cost = fluid.layers.mean(x=cost) + acc = fluid.layers.accuracy(input=prediction, label=label) + return avg_cost, acc, prediction + + +class TestAsyncExecutor(unittest.TestCase): + def setUp(self): + with open('./data.prototxt', 'w+') as f: + f.write(proto_str) + f.close() + + with tarfile.open(paddle.dataset.common.download(URL, "imdb", + MD5)) as tarf: + tarf.extractall(path='./') + tarf.close() + + def test_data_feed_desc(self): + data_feed = fluid.DataFeedDesc('./data.prototxt') + # assertEqueal(data_feed.proto_desc.batch, 2) + # assertEqual(len(data_feed.proto_desc.multi_slot_desc), 2) + self.assertEqual(" ".join(data_feed.desc().split()), + " ".join(proto_str.split())) + + def test_run(self): + # Initialize dataset description + data_feed = fluid.DataFeedDesc('train_data/data.prototxt') + data_feed.set_batch_size( + 128) # See API doc for how to change other fields + + # define network + # input text data + data = fluid.layers.data( + name="words", shape=[1], dtype="int64", lod_level=1) + # label data + label = fluid.layers.data(name="label", shape=[1], dtype="int64") + + avg_cost, acc, prediction = bow_net(data, label) + sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=0.002) + opt_ops, weight_and_grad = sgd_optimizer.minimize(avg_cost) + + # Run startup program + startup_program = fluid.default_startup_program() + place = fluid.CPUPlace() + executor = fluid.Executor(place) + executor.run(startup_program) + + main_program = fluid.default_main_program() + async_executor = fluid.AsyncExecutor(place) + + self.assertRaises(TypeError, async_executor.run) + self.assertRaises(TypeError, async_executor.run, main_program) + self.assertRaises(TypeError, async_executor.run, main_program, + data_feed) + + filelist = ['train_data/part-%d' % i for i in range(10)] + self.assertRaises(TypeError, async_executor.run, main_program, + data_feed, filelist) + + thread_num = 4 + self.assertRaises(TypeError, async_executor.run, main_program, + data_feed, filelist, thread_num) + + async_executor.run(main_program, data_feed, filelist, thread_num, [acc]) + fluid.io.save_inference_model("imdb.model", [data.name, label.name], + [acc], executor) + statinfo = os.stat('imdb.model/__model__') + self.assertGreater(statinfo.st_size, 0) + + os.remove('./data.prototxt') + shutil.rmtree('./train_data') + shutil.rmtree('./imdb.model') + + +if __name__ == '__main__': + unittest.main()