提交 70d4b812 编写于 作者: D dongdaxiang

split async executor into executor_thread_worker and async_executor, refactor...

split async executor into executor_thread_worker and async_executor, refactor pybind, add datafeed and corresponding proto
上级 b57c0bf5
......@@ -28,6 +28,8 @@ limitations under the License. */
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
......@@ -40,167 +42,22 @@ limitations under the License. */
namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<Scope>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
var_type);
}
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
}
}
}
void ExecutorThreadWorker::SetDataFeed(
const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed;
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void ExecutorThreadWorker::SetDevice() {
// at most 48 threads binding currently
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47
};
unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i];
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(proc, &mask);
if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i;
} else {
CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i <<
" is running on processor " << proc << "...";
}
}
}
}
void ExecutorThreadWorker::TrainFiles() {
// todo: configurable
SetDevice();
thread_reader_->Start();
while (int cur_batch = thread_reader_->Next()) {
// executor run here
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
thread_scope_->DropKids();
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
thread_id_ = tid;
}
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
AsyncExecutor::AsyncExecutor(const platform::Place& place) {
place_ = place;
}
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
main_program_.reset(new ProgramDesc(main_program_desc));
}
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_ = max_epoch;
}
AsyncExecutor::AsyncExecutor(const platform::Place& place) : place_(place) {}
void AsyncExecutor::CreateThreads(const ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
const DataFeed& reader,
const Scope& root_scope,
const int thread_index) {
worker->SetThreadid(thread_index);
worker->CreateThreadOperators(main_program);
worker->CreateThreadScope(main_program);
worker->CreateThreadResource(main_program, place_);
worker->SetDataFeed(reader);
worker->BindingDataFeedMemory(reader);
worker->SetMainProgram(main_program);
worker->SetRootScope(root_scope);
}
shared_ptr<DataFeed> AsyncExecutor::CreateDataFeed(const char * feed_name) {
if (g_datafeed_map.count(string(feed_name)) < 1) {
return NULL;
}
return g_datafeed_map[feed_name]();
}
void AsyncExecutor::CheckFiles(
const std::vector<std::string>& files) {
// function for user to check file formats
......@@ -247,11 +104,19 @@ void AsyncExecutor::RunFromFiles(
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed> > readers;
readers.resize(thread_num);
for (auto& reader : readers) {
// create by factory name
reader.reset(new DataFeed);
reader.SetFileList(files);
for (int i = 0; i < readers.size(); ++i) {
readers[i] = DataFeedFactory::CreateDataFeed(data_feed_desc.name());
}
/*
std::vector<std::shared_ptr<ExecutorStrategy> > workers;
workers.resize(thread_num);
std::string str_name = strategy_.name;
for (auto& worker : workers) {
worker.reset(
ExecutorStrategyFactory::CreateExecutorStrategy(str_name));
}
*/
std::vector<std::shared_ptr<ExecutorThreadWorker> > workers;
workers.resize(thread_num);
......
......@@ -30,62 +30,17 @@ limitations under the License. */
namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
/**
* Create thread level scope which is a child of root scope
*/
void CreateThreadScope(const framework::ProgramDesc& program);
void SetThreadId(int tid);
/**
* Create
*/
void CreateThreadOperators(const framework::ProgramDesc& program);
/**
* Set current root scope
*/
void SetRootScope(Scope* g_scope);
void SetDevice();
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
/**
* current DataFeed is defined in class
**/
void BindingDataFeedMemory();
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
protected:
// thread index
std::shared_ptr<DataFeed> thread_reader_; // shared queue, thread buffer
int thread_id_;
// operator name
std::vector<std::string> op_names_;
// thread level, local operators for forward and backward
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<framework::ProgramDesc> main_program_;
// execution place
platform::Place place_;
// root scope for model parameters
Scope* root_scope_;
// a thread scope, father scope is global score which is shared
Scope* thread_scope_;
};
class AsyncExecutor {
public:
explicit AsyncExecutor(const platform::Place& place);
virtual ~AsyncExecutor() {}
void SetRootScope(const Scope* root_scope);
Scope* GetRootScope() { return root_scope_; }
void CheckFiles(
const std::vector<std::string>& files);
void RunFromFiles(
const ProgramDesc& main_program,
const DataFeedDesc& data_feed_desc,
const std::vector<std::string>& files,
const int thread_num);
......
......@@ -47,6 +47,21 @@ struct Instance {
std::vector<Gauc> gauc_vec;
};
class DataFeed {
DataFeed() {}
virtual ~DataFeed() {}
};
class BlockingQueueDataFeed : DataFeed {
BlockingQueueDataFeed() {}
virtual ~BlockingQueueDataFeed() {}
};
class ThreadedDataFeed : DataFeed {
ThreadedDataFeed() {}
virtual ~ThreadedDataFeed() {}
};
class DataFeed {
public:
DataFeed() {}
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
message DataFeedDesc {
optional string name = 1;
optional int32 batch = 2 [default = 32];
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/executor_thread_worker.h"
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <map>
#include <algorithm>
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/pybind/pybind.h"
namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
if (var_type == proto::VarType::LOD_TENSOR) {
var->GetMutable<LoDTensor>();
} else if (var_type == proto::VarType::SELECTED_ROWS) {
var->GetMutable<SelectedRows>();
} else if (var_type == proto::VarType::FEED_MINIBATCH) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::FETCH_LIST) {
var->GetMutable<FeedFetchList>();
} else if (var_type == proto::VarType::STEP_SCOPES) {
var->GetMutable<std::vector<Scope>>();
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarType::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarType::READER) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]",
var_type);
}
}
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
op_names_.clear();
for (auto& op_desc : block.AllOps()) {
std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
op_names_.push_back(op_desc->Type());
OperatorBase* local_op_ptr = local_op.release();
ops_.push_back(local_op_ptr);
continue;
}
}
void ExecutorThreadWorker::CreateThreadResource(
const framework::ProgramDesc& program,
const paddle::platform::Place& place) {
CreateThreadScope(program);
CreateThreadOperators(program);
SetMainProgram(program);
SetPlace(place);
}
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
auto& block = program.Block(0);
thread_scope_ = &root_scope_->NewScope();
for (auto& var : block.AllVars()) {
if (var->Persistable()) {
auto* ptr = root_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
} else {
auto* ptr = thread_scope_->Var(var->Name());
CreateTensor(ptr, var->GetType());
}
}
}
void ExecutorThreadWorker::SetDataFeed(
const std::shared_ptr<DataFeed>& datafeed) {
local_reader_ = datafeed;
}
void ExecutorThreadWorker::BindingDataFeedMemory() {
const std::vector<std::string>& input_feed =
thread_reader_->GetUseSlotAlias();
for (auto name : input_feed) {
local_reader_->AddFeedVar(thread_scope_->Var(name), name);
}
}
void ExecutorThreadWorker::SetDevice() {
// at most 48 threads binding currently
static unsigned priority[] = {
0, 1, 2, 3, 4, 5,
6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35,
36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47
};
unsigned int i = this->thread_id_;
if (i < sizeof(priority) / sizeof(unsigned)) {
unsigned proc = priority[i];
cpu_set_t mask;
CPU_ZERO(&mask);
CPU_SET(proc, &mask);
if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i;
} else {
CPU_ZERO(&mask);
if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
&& CPU_ISSET(proc, &mask)) {
LOG(ERROR) << "TRACE: Thread " << i <<
" is running on processor " << proc << "...";
}
}
}
}
void ExecutorThreadWorker::TrainFiles() {
// todo: configurable
SetDevice();
thread_reader_->Start();
while (int cur_batch = thread_reader_->Next()) {
// executor run here
for (auto& op : ops_) {
op->Run(*thread_scope_, place_);
}
thread_scope_->DropKids();
}
}
void ExecutorThreadWorker::SetThreadId(int tid) {
thread_id_ = tid;
}
void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
place_ = place;
}
void ExecutorThreadWorker::SetMainProgram(
const ProgramDesc& main_program_desc) {
main_program_.reset(new ProgramDesc(main_program_desc));
}
void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
} // einit_modelnd namespace framework
} // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_FRAMEWORK_EXECUTOR_THREAD_WORKER_H_
#define PADDLE_FLUID_FRAMEWORK_EXECUTOR_THREAD_WORKER_H_
#include <memory>
#include <mutex> // NOLINT
#include <set>
#include <map>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker {
public:
ExecutorThreadWorker() {}
~ExecutorThreadWorker() {}
void CreateThreadResource(const framework::ProgramDesc& program,
const paddle::platform::Place& place);
void SetThreadId(int tid);
void SetRootScope(Scope* g_scope);
void SetDevice();
void BindingDataFeedMemory();
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
private:
void CreateThreadScope(const framework::ProgramDesc& program);
void CreateThreadOperators(const framework::ProgramDesc& program);
void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place);
protected:
// thread index
std::shared_ptr<DataFeed> thread_reader_; // shared queue, thread buffer
int thread_id_;
// operator name
std::vector<std::string> op_names_;
// thread level, local operators for forward and backward
std::vector<OperatorBase *> ops_;
// main program for training
std::unique_ptr<framework::ProgramDesc> main_program_;
// execution place
platform::Place place_;
// root scope for model parameters
Scope* root_scope_;
// a thread scope, father scope is global score which is shared
Scope* thread_scope_;
};
} // namespace framework
} // namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <string>
#include <vector>
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/pybind/async_executor_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindAsyncExecutor(py::module* m) {
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def("run_from_files", &framework::AsyncExecutor::RunFromFiles)
.def("check_files", &framework::AsyncExecutor::CheckFiles);
} // end BindAsyncExecutor
} // end namespace pybind
} // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册