提交 b57c0bf5 编写于 作者: D dongdaxiang

clear async executor interface and add data feed factory

上级 38798281
......@@ -194,6 +194,13 @@ void AsyncExecutor::CreateThreads(const ExecutorThreadWorker* worker,
worker->SetRootScope(root_scope);
}
shared_ptr<DataFeed> AsyncExecutor::CreateDataFeed(const char * feed_name) {
if (g_datafeed_map.count(string(feed_name)) < 1) {
return NULL;
}
return g_datafeed_map[feed_name]();
}
void AsyncExecutor::CheckFiles(
const std::vector<std::string>& files) {
// function for user to check file formats
......
......@@ -36,25 +36,35 @@ class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
/**
* Create thread level scope which is a child of root scope
*/
void CreateThreadScope(const framework::ProgramDesc& program);
void SetDataFeed(const DataFeed& datafeed);
void SetThreadId(int tid);
/**
* Create
*/
void CreateThreadOperators(const framework::ProgramDesc& program);
/**
* Set current root scope
*/
void SetRootScope(Scope* g_scope);
void SetDevice();
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
/**
* current DataFeed is defined in class
**/
void BindingDataFeedMemory();
void SetSparseCommData(const std::map<std::string, int>& param_names);
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
protected:
// thread index
std::shared_ptr<DataFeed> thread_reader_; // shared queue, thread buffer
int thread_id_;
// op name
// operator name
std::vector<std::string> op_names_;
// local ops for forward and backward
// thread level, local operators for forward and backward
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<framework::ProgramDesc> main_program_;
......@@ -72,7 +82,8 @@ class AsyncExecutor {
virtual ~AsyncExecutor() {}
void SetRootScope(const Scope* root_scope);
Scope* GetRootScope() { return root_scope_; }
void CheckFiles(const std::vector<std::string>& files);
void CheckFiles(
const std::vector<std::string>& files);
void RunFromFiles(
const ProgramDesc& main_program,
const std::vector<std::string>& files,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_feed_factory.h"
namespace paddle {
namespace framework {
typedef shared_ptr<DataFeed> (*Createdata_feedFunction)();
typedef std::unordered_map<std::string, Createdata_feedFunction> data_feedMap;
data_feedMap g_data_feed_map;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \
shared_ptr<DataFeed> Creator_##data_feed_class() { \
return shared_ptr<DataFeed>(new data_feed_class); \
} \
class __Registerer_##data_feed_class { \
public: \
__Registerer_##data_feed_class() { \
g_data_feed_map[#data_feed_class] = &Creator_##data_feed_class; \
} \
}; \
__Registerer_##data_feed_class g_registerer_##data_feed_class; \
} // namespace
string DataFeedFactory::DataFeedTypeList() {
string data_feed_types;
for (auto iter = g_data_feed_map.begin();
iter != g_data_feed_map.end(); ++iter) {
if (iter != g_data_feed_map.begin()) {
data_feed_types += ", ";
}
data_feed_types += iter->first;
}
return data_feed_types;
}
shared_ptr<DataFeed> DataFeedFactory::CreateDataFeed(
const char* data_feed_class) {
if (g_data_feed_map.count(string(data_feed_class)) < 1) {
exit(-1);
}
return g_data_feed_map[data_feed_class]();
}
REGISTER_DATAFEED_CLASS(MultiSlotDataFeed);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#include <string>
#include "paddle/framework/data_feed.h"
namespace paddle {
namespace framework {
class DataFeedFactory {
public:
static std::string DataFeedTypeList();
static shared_ptr<DataFeed> CreateDataFeed(const char* data_feed_class);
};
} // namespace framework
} // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册