Commit f410622f authored by Yang Yang

Merge follow-up changes addressing review comments

...@@ -32,7 +32,68 @@ namespace framework { ...@@ -32,7 +32,68 @@ namespace framework {
// Well-known operator type names for the feed and fetch operators.
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
std::vector<bool> Prune(const ProgramDesc& pdesc, int block_id) { Executor::Executor(const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_[i] = new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i]));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i]));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
// Releases the per-place device contexts allocated in the constructor.
Executor::~Executor() {
  for (size_t i = 0; i < device_contexts_.size(); ++i) {
    delete device_contexts_[i];
  }
}
// Executes block `block_id` of the program `pdesc` inside `scope`.
//
// Variables declared by the block are created in the caller-supplied
// scope, while operator outputs are created in a child scope.  Prune()
// returns a per-op mask deciding which ops actually run (its criteria
// are defined elsewhere in this file and not visible here).
//
// NOTE(review): only device_contexts_[0] is used, so execution happens
// on a single device regardless of how many places were configured.
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
  // TODO(tonyyang-svail):
  //   - only runs on the first device (i.e. no interdevice communication)
  //   - will change to use multiple blocks for RNN op and Cond Op
  // Guard against an out-of-range block index before touching it.
  PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
  auto& block = pdesc.blocks(block_id);
  auto& device = device_contexts_[0];
  // Instantiate all the vars in the global scope
  for (auto& var : block.vars()) {
    scope->NewVar(var.name());
  }
  // Operator outputs live in a child scope so they do not pollute the
  // caller's scope.
  Scope& local_scope = scope->NewScope();
  // should_run[i] tells whether op i survived pruning; it must line up
  // one-to-one with the block's op list.
  std::vector<bool> should_run = Prune(pdesc, block_id);
  PADDLE_ENFORCE_EQ(should_run.size(), static_cast<size_t>(block.ops_size()));
  for (size_t i = 0; i < should_run.size(); ++i) {
    if (should_run[i]) {
      // Create any output variable that does not yet exist in the local
      // scope before running the op that writes it.
      for (auto& var : block.ops(i).outputs()) {
        for (auto& argu : var.arguments()) {
          if (local_scope.FindVar(argu) == nullptr) {
            local_scope.NewVar(argu);
          }
        }
      }
      auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i));
      op->Run(local_scope, *device);
    }
  }
  // TODO(tonyyang-svail):
  //   - Destroy local_scope
  // NOTE(review): local_scope is never explicitly destroyed here —
  // presumably it is reclaimed when the parent scope dies; confirm
  // Scope's ownership semantics before relying on that.
}
std::vector<bool> Executor::Prune(const ProgramDesc& pdesc, int block_id) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - will change to use multiple blocks for RNN op and Cond Op // - will change to use multiple blocks for RNN op and Cond Op
......
...@@ -66,7 +66,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, ...@@ -66,7 +66,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
template <typename T> template <typename T>
void SetFeedVariable(const std::vector<std::vector<T>>& inputs, void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
const std::vector<std::vector<int64_t>>& dims) { const std::vector<std::vector<int64_t>>& dims) {
Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value"); Variable* g_feed_value = GetGlobalScope().FindVar("feed_value");
auto& feed_inputs = auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>()); *(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
size_t size = inputs.size(); size_t size = inputs.size();
...@@ -81,7 +81,7 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs, ...@@ -81,7 +81,7 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
// So we can memcpy the data from fetch_value to vector<T> // So we can memcpy the data from fetch_value to vector<T>
template <typename T> template <typename T>
std::vector<std::vector<T>> GetFetchVariable() { std::vector<std::vector<T>> GetFetchVariable() {
Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value"); Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value");
auto& fetch_outputs = auto& fetch_outputs =
*(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>()); *(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
...@@ -231,8 +231,9 @@ TEST_F(ExecutorTesterRandom, CPU) { ...@@ -231,8 +231,9 @@ TEST_F(ExecutorTesterRandom, CPU) {
std::unique_ptr<Executor> executor(new Executor(places)); std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope(), 0); executor->Run(init_pdesc_, &GetGlobalScope(), 0);
executor->Run(pdesc_, GetGlobalScope(), 0); SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
} }
...@@ -251,7 +252,7 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) { ...@@ -251,7 +252,7 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
for (int batch_id = 0; batch_id < 3; batch_id++) { for (int batch_id = 0; batch_id < 3; batch_id++) {
SetFeedVariable<float>(inputs_, dims_); SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, GetGlobalScope(), 0); executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) { for (size_t i = 0; i < result.size(); ++i) {
...@@ -279,10 +280,10 @@ TEST_F(ExecutorTesterRandom, GPU) { ...@@ -279,10 +280,10 @@ TEST_F(ExecutorTesterRandom, GPU) {
std::unique_ptr<Executor> executor(new Executor(places)); std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, GetGlobalScope(), 0); executor->Run(init_pdesc_, &GetGlobalScope(), 0);
for (int batch_id = 0; batch_id < 3; batch_id++) { for (int batch_id = 0; batch_id < 3; batch_id++) {
SetFeedVariable<float>(inputs_, dims_); SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, GetGlobalScope(), 0); executor->Run(pdesc_, &GetGlobalScope(), 0);
} }
} }
...@@ -303,7 +304,7 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) { ...@@ -303,7 +304,7 @@ TEST_F(ExecutorTesterFeedAndFetch, GPU) {
for (int batch_id = 0; batch_id < 3; batch_id++) { for (int batch_id = 0; batch_id < 3; batch_id++) {
SetFeedVariable<float>(inputs_, dims_); SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, GetGlobalScope(), 0); executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>(); std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) { for (size_t i = 0; i < result.size(); ++i) {
......
...@@ -67,14 +67,14 @@ void Scope::DropKids() { ...@@ -67,14 +67,14 @@ void Scope::DropKids() {
// Protects the one-time initialization of the global scope below.
std::once_flag feed_variable_flag;

// Returns the process-wide singleton Scope.
//
// The scope is created lazily and exactly once (thread-safe via
// std::call_once); the "feed_value" and "fetch_value" variables are
// registered in it at creation time so the feed/fetch kernels can rely
// on their existence.
framework::Scope& GetGlobalScope() {
  static std::unique_ptr<framework::Scope> g_scope{nullptr};
  std::call_once(feed_variable_flag, [&]() {
    g_scope.reset(new framework::Scope());
    g_scope->NewVar("feed_value");
    g_scope->NewVar("fetch_value");
  });
  return *g_scope;
}
} // namespace framework } // namespace framework
......
...@@ -73,7 +73,7 @@ class Scope { ...@@ -73,7 +73,7 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
}; };
framework::Scope* GetGlobalScope(); framework::Scope& GetGlobalScope();
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,7 @@ class FeedKernel : public framework::OpKernel<T> { ...@@ -26,7 +26,7 @@ class FeedKernel : public framework::OpKernel<T> {
framework::Tensor* out = ctx.Output<framework::Tensor>("Out"); framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable = framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value"); framework::GetGlobalScope().FindVar("feed_value");
const auto& tensors = const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>(); g_feed_variable->Get<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col"); int col = ctx.template Attr<int>("col");
......
...@@ -25,7 +25,7 @@ class FetchKernel : public framework::OpKernel<T> { ...@@ -25,7 +25,7 @@ class FetchKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input"); const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
framework::Variable* g_fetch_variable = framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value"); framework::GetGlobalScope().FindVar("fetch_value");
auto* tensors = auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>(); g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col"); int col = ctx.template Attr<int>("col");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册