From b57c0bf5758ff0f377e96cae67ad9f2da1c5a967 Mon Sep 17 00:00:00 2001 From: dongdaxiang Date: Tue, 13 Nov 2018 10:45:50 +0800 Subject: [PATCH] clear async executor interface and add data feed factory --- .../framework/async_executor_refactor.cc | 7 +++ .../fluid/framework/async_executor_refactor.h | 21 +++++-- paddle/fluid/framework/data_feed_factory.cc | 60 +++++++++++++++++++ paddle/fluid/framework/data_feed_factory.h | 31 ++++++++++ 4 files changed, 114 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/framework/data_feed_factory.cc create mode 100644 paddle/fluid/framework/data_feed_factory.h diff --git a/paddle/fluid/framework/async_executor_refactor.cc b/paddle/fluid/framework/async_executor_refactor.cc index d59b1355b56..36397ff2e8e 100644 --- a/paddle/fluid/framework/async_executor_refactor.cc +++ b/paddle/fluid/framework/async_executor_refactor.cc @@ -194,6 +194,13 @@ void AsyncExecutor::CreateThreads(const ExecutorThreadWorker* worker, worker->SetRootScope(root_scope); } +shared_ptr AsyncExecutor::CreateDataFeed(const char * feed_name) { + if (g_datafeed_map.count(string(feed_name)) < 1) { + return NULL; + } + return g_datafeed_map[feed_name](); +} + void AsyncExecutor::CheckFiles( const std::vector& files) { // function for user to check file formats diff --git a/paddle/fluid/framework/async_executor_refactor.h b/paddle/fluid/framework/async_executor_refactor.h index 6245e75517f..90812f2de85 100644 --- a/paddle/fluid/framework/async_executor_refactor.h +++ b/paddle/fluid/framework/async_executor_refactor.h @@ -36,25 +36,35 @@ class ExecutorThreadWorker { public: ExecutorThreadWorker() {} ~ExecutorThreadWorker() {} + /** + * Create thread level scope which is a child of root scope + */ void CreateThreadScope(const framework::ProgramDesc& program); - void SetDataFeed(const DataFeed& datafeed); void SetThreadId(int tid); + /** + * Create + */ void CreateThreadOperators(const framework::ProgramDesc& program); + /** + * Set current root scope + */ void SetRootScope(Scope* g_scope); void SetDevice(); void SetMainProgram(const ProgramDesc& main_program_desc); void SetPlace(const paddle::platform::Place& place); + /** + * current DataFeed is defined in class + **/ void BindingDataFeedMemory(); - void SetSparseCommData(const std::map& param_names); void SetDataFeed(const std::shared_ptr& datafeed); protected: // thread index std::shared_ptr thread_reader_; // shared queue, thread buffer int thread_id_; - // op name + // operator name std::vector op_names_; - // local ops for forward and backward + // thread level, local operators for forward and backward std::vector ops_; // main program for training std::unique_ptr main_program_; @@ -72,7 +82,8 @@ class AsyncExecutor { virtual ~AsyncExecutor() {} void SetRootScope(const Scope* root_scope); Scope* GetRootScope() { return root_scope_; } - void CheckFiles(const std::vector& files); + void CheckFiles( + const std::vector& files); void RunFromFiles( const ProgramDesc& main_program, const std::vector& files, diff --git a/paddle/fluid/framework/data_feed_factory.cc b/paddle/fluid/framework/data_feed_factory.cc new file mode 100644 index 00000000000..b07f770a584 --- /dev/null +++ b/paddle/fluid/framework/data_feed_factory.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/data_feed_factory.h" + +namespace paddle { +namespace framework { +typedef shared_ptr (*Createdata_feedFunction)(); +typedef std::unordered_map data_feedMap; +data_feedMap g_data_feed_map; + +#define REGISTER_DATAFEED_CLASS(data_feed_class) \ + namespace { \ + shared_ptr Creator_##data_feed_class() { \ + return shared_ptr(new data_feed_class); \ + } \ + class __Registerer_##data_feed_class { \ + public: \ + __Registerer_##data_feed_class() { \ + g_data_feed_map[#data_feed_class] = &Creator_##data_feed_class; \ + } \ + }; \ + __Registerer_##data_feed_class g_registerer_##data_feed_class; \ + } // namespace + + +string DataFeedFactory::DataFeedTypeList() { + string data_feed_types; + for (auto iter = g_data_feed_map.begin(); + iter != g_data_feed_map.end(); ++iter) { + if (iter != g_data_feed_map.begin()) { + data_feed_types += ", "; + } + data_feed_types += iter->first; + } + return data_feed_types; +} + +shared_ptr DataFeedFactory::CreateDataFeed( + const char* data_feed_class) { + if (g_data_feed_map.count(string(data_feed_class)) < 1) { + exit(-1); + } + return g_data_feed_map[data_feed_class](); +} + +REGISTER_DATAFEED_CLASS(MultiSlotDataFeed); +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/data_feed_factory.h b/paddle/fluid/framework/data_feed_factory.h new file mode 100644 index 00000000000..af203001c54 --- /dev/null +++ b/paddle/fluid/framework/data_feed_factory.h @@ -0,0 +1,31 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_ +#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_ + +#include +#include "paddle/framework/data_feed.h" + +namespace paddle { +namespace framework { +class DataFeedFactory { + public: + static std::string DataFeedTypeList(); + static shared_ptr CreateDataFeed(const char* data_feed_class); +}; +} // namespace framework +} // namespace paddle + +#endif // PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_ -- GitLab