提交 15f5f10e 编写于 作者: Y Yu Yang

AddInput/AddOutput for OpHandle

上级 5368e50d
...@@ -88,7 +88,8 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo ...@@ -88,7 +88,8 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method) framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle) framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle
fetch_op_handle)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
cc_library(var_handle SRCS var_handle.cc DEPS place) cc_library(var_handle SRCS var_handle.cc DEPS place)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context) cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory) cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
// Constructs a fetch op that writes the merged result into slot `offset`
// of `data`, reading the fetched variable from every scope in
// `local_scopes`. Both pointers are borrowed; the caller keeps ownership.
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
                             std::vector<Scope *> *local_scopes)
    : data_(data), offset_(offset), local_scopes_(local_scopes) {}
// Unregisters this op from each input variable's pending-op set so the
// dependency graph does not keep a dangling pointer after this
// short-lived op (created per Run call) is destroyed.
FetchOpHandle::~FetchOpHandle() {
  for (auto *input_var : inputs_) {
    input_var->pending_ops_.erase(this);
  }
}
// Fetch ops are sink nodes of the dependency graph: no other op may
// depend on (and therefore wait for) a fetch op, so this always throws.
// Fix: corrected the misspelled "Unexpceted" in the error message.
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
  PADDLE_THROW("Nobody should wait FetchOp. Unexpected Error");
}
void FetchOpHandle::WaitAndMergeCPUTensors() const {
// Wait fetch stream done.
for (auto &ctx : dev_ctx_) {
ctx.second->Wait();
}
std::vector<const LoDTensor *> tensors_ptr;
tensors_ptr.reserve(tensors_.size());
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&t);
}
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
}
// Stages the fetched variable from every local scope into tensors_, one
// entry per scope, so WaitAndMergeCPUTensors() can merge them afterwards.
void FetchOpHandle::RunImpl() {
  // Wait until the producer of each input has finished work on the
  // device context this fetch op uses for that input's place.
  for (auto *input : inputs_) {
    auto *var = static_cast<VarHandle *>(input);
    var->generated_op_->Wait(this->dev_ctx_[var->place_]);
  }
  tensors_.resize(inputs_.size());
  // NOTE(review): assumes inputs_ is non-empty and that every input
  // refers to the same variable name -- confirm at the call site.
  auto *var = static_cast<VarHandle *>(inputs_[0]);
  auto &var_name = var->name_;
  platform::CPUPlace cpu;
  auto &scopes = *local_scopes_;
  for (size_t i = 0; i < scopes.size(); ++i) {
    auto &scope = scopes[i];
    auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
    if (platform::is_gpu_place(var->place_)) {
#ifdef PADDLE_WITH_CUDA
      // Asynchronous device-to-host copy; completion is awaited later in
      // WaitAndMergeCPUTensors(). Note the context is looked up by the
      // tensor's own place (t.place()), not var->place_.
      TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
#endif
      // NOTE(review): without CUDA this branch silently leaves
      // tensors_[i] empty -- confirm GPU places cannot occur here then.
    } else {
      // CPU tensors are cheap to fetch: share the buffer instead of
      // copying, preserving the LoD information.
      tensors_[i].ShareDataWith(t);
      tensors_[i].set_lod(t.lod());
    }
  }
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
// Op handle that collects one fetched variable from every device-local
// scope, copies GPU tensors to the CPU, and merges the per-device pieces
// into a single LoDTensor stored in one slot of a FeedFetchList.
struct FetchOpHandle : public OpHandleBase {
  FeedFetchList *data_;  // borrowed output list; slot offset_ is ours
  size_t offset_;        // index of our slot inside *data_
  std::vector<Scope *> *local_scopes_;  // borrowed; one scope per device
  std::vector<LoDTensor> tensors_;      // per-scope staging tensors

  FetchOpHandle(FeedFetchList *data, size_t offset,
                std::vector<Scope *> *local_scopes);

  ~FetchOpHandle();

  // Fetch ops are sink nodes of the graph; calling Wait on one throws.
  void Wait(platform::DeviceContext *waited_dev) override;

  // Waits for the fetch streams, then merges tensors_ into
  // (*data_)[offset_] on the CPU.
  void WaitAndMergeCPUTensors() const;

 protected:
  void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -79,6 +79,17 @@ void OpHandleBase::Wait(platform::DeviceContext *waited_dev) { ...@@ -79,6 +79,17 @@ void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
} }
#endif #endif
} }
// Registers `in` as an input of this op and, symmetrically, this op as a
// pending consumer of `in`, keeping both sides of the edge consistent.
void OpHandleBase::AddInput(VarHandleBase *in) {
  inputs_.emplace_back(in);
  in->pending_ops_.insert(this);
}
// Registers `out` as an output of this op and records this op as the
// variable's producer, keeping both sides of the edge consistent.
void OpHandleBase::AddOutput(VarHandleBase *out) {
  out->generated_op_ = this;
  this->outputs_.emplace_back(out);
}
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -39,6 +39,10 @@ struct OpHandleBase { ...@@ -39,6 +39,10 @@ struct OpHandleBase {
virtual void Wait(platform::DeviceContext *waited_dev); virtual void Wait(platform::DeviceContext *waited_dev);
void AddInput(VarHandleBase *in);
void AddOutput(VarHandleBase *out);
protected: protected:
virtual void RunImpl() = 0; virtual void RunImpl() = 0;
}; };
......
...@@ -18,8 +18,11 @@ namespace paddle { ...@@ -18,8 +18,11 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope, ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place) platform::Place place,
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {} platform::DeviceContext *dev_ctx)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
dev_ctx_[place_] = dev_ctx;
}
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {} ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
...@@ -26,7 +27,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase { ...@@ -26,7 +27,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
Scope *scope_; Scope *scope_;
platform::Place place_; platform::Place place_;
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place); ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
platform::DeviceContext *context);
~ScaleLossGradOpHandle() final; ~ScaleLossGradOpHandle() final;
......
...@@ -17,77 +17,22 @@ limitations under the License. */ ...@@ -17,77 +17,22 @@ limitations under the License. */
#include "lod_tensor.h" #include "lod_tensor.h"
#include "lod_tensor_array.h" #include "lod_tensor_array.h"
#include "op_registry.h" #include "op_registry.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h" #include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/nccl_helper.h" #include "paddle/fluid/platform/nccl_helper.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using details::DummyVarHandle; using details::DummyVarHandle;
using details::FetchOpHandle;
using details::OpHandleBase; using details::OpHandleBase;
using details::ScaleLossGradOpHandle; using details::ScaleLossGradOpHandle;
using details::VarHandle; using details::VarHandle;
using details::VarHandleBase; using details::VarHandleBase;
struct FetchOpHandle : public OpHandleBase {
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<LoDTensor> tensors_;
~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
}
}
void Wait(platform::DeviceContext *waited_dev) override {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
void WaitAndMergeCPUTensors() const {
// Wait fetch stream done.
for (auto &ctx : dev_ctx_) {
ctx.second->Wait();
}
std::vector<const LoDTensor *> tensors_ptr;
tensors_ptr.reserve(tensors_.size());
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&t);
}
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
}
protected:
void RunImpl() override {
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(this->dev_ctx_[var->place_]);
}
tensors_.resize(inputs_.size());
auto *var = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var->name_;
platform::CPUPlace cpu;
auto &scopes = *local_scopes_;
for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i];
auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
if (platform::is_gpu_place(var->place_)) {
TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
} else {
tensors_[i].ShareDataWith(t);
tensors_[i].set_lod(t.lod());
}
}
}
};
class ParallelExecutorPrivate { class ParallelExecutorPrivate {
public: public:
explicit ParallelExecutorPrivate(size_t num_threads) explicit ParallelExecutorPrivate(size_t num_threads)
...@@ -99,19 +44,9 @@ class ParallelExecutorPrivate { ...@@ -99,19 +44,9 @@ class ParallelExecutorPrivate {
Scope *global_scope_; Scope *global_scope_;
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) { platform::PlaceHash>
if (platform::is_cpu_place(place) || local_scopes_.size() == 1) { fetch_dev_ctxs_;
return const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
} else {
#ifdef PADDLE_WITH_CUDA
return nccl_ctxs_->DevCtx(place);
#else
PADDLE_THROW("Not compiled with CUDA")
#endif
}
}
platform::Place main_place_; platform::Place main_place_;
...@@ -119,6 +54,7 @@ class ParallelExecutorPrivate { ...@@ -119,6 +54,7 @@ class ParallelExecutorPrivate {
std::unordered_map<std::string, std::map<int, VarHandle>>, std::unordered_map<std::string, std::map<int, VarHandle>>,
platform::PlaceHash> platform::PlaceHash>
vars_; vars_;
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_; std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_; std::vector<std::unique_ptr<OpHandleBase>> ops_;
...@@ -183,10 +119,9 @@ class ParallelExecutorPrivate { ...@@ -183,10 +119,9 @@ class ParallelExecutorPrivate {
size_t version = vars.size(); size_t version = vars.size();
auto &var = vars[version]; auto &var = vars[version];
var.version_ = version; var.version_ = version;
var.generated_op_ = op_handle;
var.name_ = each_var_name; var.name_ = each_var_name;
var.place_ = place; var.place_ = place;
op_handle->outputs_.emplace_back(&var); op_handle->AddOutput(&var);
} }
}; // namespace framework }; // namespace framework
...@@ -198,7 +133,11 @@ struct NCCLAllReduceOpHandle : public OpHandleBase { ...@@ -198,7 +133,11 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
explicit NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes, explicit NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs) const platform::NCCLContextMap &ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {} : local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
for (auto &p : places_) {
this->dev_ctx_[p] = nccl_ctxs_.DevCtx(p);
}
}
void Wait(platform::DeviceContext *waited_dev) override { void Wait(platform::DeviceContext *waited_dev) override {
OpHandleBase::Wait(waited_dev); OpHandleBase::Wait(waited_dev);
...@@ -283,6 +222,17 @@ ParallelExecutor::ParallelExecutor( ...@@ -283,6 +222,17 @@ ParallelExecutor::ParallelExecutor(
: member_(new ParallelExecutorPrivate(num_threads)) { : member_(new ParallelExecutorPrivate(num_threads)) {
member_->places_ = places; member_->places_ = places;
member_->global_scope_ = scope; member_->global_scope_ = scope;
if (platform::is_cpu_place(places[0])) {
member_->fetch_dev_ctxs_[places[0]] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(places[0]));
} else {
for (auto &p : member_->places_) {
member_->fetch_dev_ctxs_[p] =
new platform::CUDADeviceContext(boost::get<platform::CUDAPlace>(p));
}
}
// Step 1. RunStartupProgram and Bcast the params to devs. // Step 1. RunStartupProgram and Bcast the params to devs.
Executor exe(places[0]); Executor exe(places[0]);
exe.Run(startup_program, scope, 0); exe.Run(startup_program, scope, 0);
...@@ -348,8 +298,7 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -348,8 +298,7 @@ void ParallelExecutor::ConstructDependencyGraph(
for (auto &each_var_name : var_names) { for (auto &each_var_name : var_names) {
VarHandle *var = member_->GetVarHandle(each_var_name, p); VarHandle *var = member_->GetVarHandle(each_var_name, p);
op_handle->inputs_.emplace_back(var); op_handle->AddInput(var);
var->pending_ops_.emplace(op_handle);
} }
var_names = op->OutputArgumentNames(); var_names = op->OutputArgumentNames();
...@@ -360,11 +309,10 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -360,11 +309,10 @@ void ParallelExecutor::ConstructDependencyGraph(
if (is_forwarding) { if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name) { if (var_names.size() == 1 && var_names[0] == loss_var_name) {
// Insert ScaleCost OpHandle // Insert ScaleCost OpHandle
member_->ops_.emplace_back(new ScaleLossGradOpHandle( op_handle =
this->member_->local_scopes_.size(), s, p)); new ScaleLossGradOpHandle(this->member_->local_scopes_.size(), s,
op_handle = member_->ops_.back().get(); p, member_->nccl_ctxs_->DevCtx(p));
member_->ops_.emplace_back(op_handle);
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
// FIXME: Currently ScaleLossGradOp only use device_count as scale // FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators. // factor. So it does not depend on any other operators.
...@@ -399,15 +347,14 @@ void ParallelExecutor::ConstructDependencyGraph( ...@@ -399,15 +347,14 @@ void ParallelExecutor::ConstructDependencyGraph(
continue; continue;
} }
auto *prev_grad = &vars[vars.size() - 1]; auto *prev_grad = &vars[vars.size() - 1];
op_handle->inputs_.emplace_back(prev_grad); op_handle->AddInput(prev_grad);
prev_grad->pending_ops_.emplace(op_handle);
auto &var = vars[vars.size()]; auto &var = vars[vars.size()];
var.place_ = p; var.place_ = p;
var.generated_op_ = op_handle;
var.name_ = og; var.name_ = og;
var.version_ = vars.size() - 1; var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var);
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p); op_handle->AddOutput(&var);
} }
} }
} }
...@@ -454,12 +401,8 @@ void ParallelExecutor::PolishGraphToSupportDataHazards() const { ...@@ -454,12 +401,8 @@ void ParallelExecutor::PolishGraphToSupportDataHazards() const {
} }
auto *dep_var = new DummyVarHandle(); auto *dep_var = new DummyVarHandle();
read_op->AddOutput(dep_var);
dep_var->generated_op_ = read_op; write_op->AddInput(dep_var);
read_op->outputs_.emplace_back(dep_var);
dep_var->pending_ops_.emplace(write_op);
write_op->inputs_.emplace_back(dep_var);
member_->dep_vars_.emplace(dep_var); member_->dep_vars_.emplace(dep_var);
} }
} }
...@@ -561,24 +504,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, ...@@ -561,24 +504,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
for (size_t i = 0; i < fetch_tensors.size(); ++i) { for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i]; auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars[var_name]; auto &vars = fetched_vars[var_name];
fetch_ops.emplace_back(); fetch_ops.emplace_back(&fetched_data, i, &member_->local_scopes_);
FetchOpHandle *op = &fetch_ops.back(); FetchOpHandle *op = &fetch_ops.back();
op->data_ = &fetched_data;
op->offset_ = i; // FIXME: Use new device context
op->local_scopes_ = &member_->local_scopes_;
for (auto &p : member_->places_) { for (auto &p : member_->places_) {
op->dev_ctx_[p] = member_->nccl_ctxs_->DevCtx(p); op->dev_ctx_[p] = member_->fetch_dev_ctxs_[p];
} }
for (auto *var : vars) { for (auto *var : vars) {
var->pending_ops_.emplace(op); op->AddInput(var);
op->inputs_.emplace_back(var);
} }
dummy_vars.emplace_back(); dummy_vars.emplace_back();
auto *var = &dummy_vars.back(); auto *var = &dummy_vars.back();
op->outputs_.emplace_back(var); op->AddOutput(var);
var->generated_op_ = op;
pending_vars[var] = false; pending_vars[var] = false;
pending_ops.insert({op, op->inputs_.size()}); pending_ops.insert({op, op->inputs_.size()});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册