Commit 15f5f10e authored by Yu Yang

AddInput/AddOutput for OpHandle

Parent 5368e50d
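
The heart of this commit is the new pair of helpers OpHandleBase::AddInput / OpHandleBase::AddOutput, which replace the manual bookkeeping of inputs_, outputs_, pending_ops_ and generated_op_ that was previously repeated at every call site in parallel_executor.cc. Below is a minimal, self-contained sketch of the pattern; the VarHandleBase/OpHandleBase types here are simplified stand-ins for illustration, not the actual Paddle classes.

// Simplified mock of the AddInput/AddOutput bookkeeping introduced by this commit.
#include <cassert>
#include <unordered_set>
#include <vector>

struct OpHandleBase;

struct VarHandleBase {
  OpHandleBase *generated_op_{nullptr};             // op that produces this variable
  std::unordered_set<OpHandleBase *> pending_ops_;  // ops that consume this variable
};

struct OpHandleBase {
  std::vector<VarHandleBase *> inputs_;
  std::vector<VarHandleBase *> outputs_;

  // Record `in` as an input and register this op as one of its consumers.
  void AddInput(VarHandleBase *in) {
    inputs_.emplace_back(in);
    in->pending_ops_.insert(this);
  }

  // Record `out` as an output and mark this op as its producer.
  void AddOutput(VarHandleBase *out) {
    outputs_.emplace_back(out);
    out->generated_op_ = this;
  }
};

int main() {
  VarHandleBase grad, scaled_grad;
  OpHandleBase scale_op;
  // One call per graph edge instead of touching the members by hand.
  scale_op.AddInput(&grad);
  scale_op.AddOutput(&scaled_grad);
  assert(grad.pending_ops_.count(&scale_op) == 1);
  assert(scaled_grad.generated_op_ == &scale_op);
  return 0;
}

Centralizing the two edge updates in one place keeps the producer/consumer links of the dependency graph consistent, which is what the rewritten call sites in the diff below rely on.
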
......@@ -88,7 +88,8 @@ cc_library(feed_fetch_method SRCS feed_fetch_method.cc DEPS lod_tensor scope glo
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS op_registry device_context scope
framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle)
framework_proto backward glog lod_rank_table simple_threadpool scale_loss_grad_op_handle
fetch_op_handle)
cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
cc_library(var_handle SRCS var_handle.cc DEPS place)
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes)
: data_(data), offset_(offset), local_scopes_(local_scopes) {}
FetchOpHandle::~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
}
}
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
void FetchOpHandle::WaitAndMergeCPUTensors() const {
// Wait until the fetch streams have finished.
for (auto &ctx : dev_ctx_) {
ctx.second->Wait();
}
std::vector<const LoDTensor *> tensors_ptr;
tensors_ptr.reserve(tensors_.size());
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&t);
}
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
}
void FetchOpHandle::RunImpl() {
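  // Make the fetch device context for each input's place wait until the op
  // that generated that input has finished.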
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(this->dev_ctx_[var->place_]);
}
tensors_.resize(inputs_.size());
auto *var = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var->name_;
platform::CPUPlace cpu;
auto &scopes = *local_scopes_;
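  // Gather the tensor from every local scope: GPU tensors are copied to a
  // CPU tensor, CPU tensors are shared directly.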
for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i];
auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
if (platform::is_gpu_place(var->place_)) {
#ifdef PADDLE_WITH_CUDA
TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
#endif
} else {
tensors_[i].ShareDataWith(t);
tensors_[i].set_lod(t.lod());
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FetchOpHandle : public OpHandleBase {
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<LoDTensor> tensors_;
FetchOpHandle(FeedFetchList *data, size_t offset,
std::vector<Scope *> *local_scopes);
~FetchOpHandle();
void Wait(platform::DeviceContext *waited_dev) override;
void WaitAndMergeCPUTensors() const;
protected:
void RunImpl() override;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -79,6 +79,17 @@ void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
}
#endif
}
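// Record `in` as an input of this op and register the op as a pending
// consumer of `in`.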
void OpHandleBase::AddInput(VarHandleBase *in) {
this->inputs_.emplace_back(in);
in->pending_ops_.insert(this);
}
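// Record `out` as an output of this op and mark the op as its producer.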
void OpHandleBase::AddOutput(VarHandleBase *out) {
outputs_.emplace_back(out);
out->generated_op_ = this;
}
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -39,6 +39,10 @@ struct OpHandleBase {
virtual void Wait(platform::DeviceContext *waited_dev);
void AddInput(VarHandleBase *in);
void AddOutput(VarHandleBase *out);
protected:
virtual void RunImpl() = 0;
};
......
......@@ -18,8 +18,11 @@ namespace paddle {
namespace framework {
namespace details {
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
platform::Place place)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {}
platform::Place place,
platform::DeviceContext *dev_ctx)
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
dev_ctx_[place_] = dev_ctx;
}
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
......@@ -26,7 +27,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
Scope *scope_;
platform::Place place_;
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place);
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
platform::DeviceContext *context);
~ScaleLossGradOpHandle() final;
......
......@@ -17,77 +17,22 @@ limitations under the License. */
#include "lod_tensor.h"
#include "lod_tensor_array.h"
#include "op_registry.h"
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/var_handle.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/nccl_helper.h"
namespace paddle {
namespace framework {
using details::DummyVarHandle;
using details::FetchOpHandle;
using details::OpHandleBase;
using details::ScaleLossGradOpHandle;
using details::VarHandle;
using details::VarHandleBase;
struct FetchOpHandle : public OpHandleBase {
FeedFetchList *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<LoDTensor> tensors_;
~FetchOpHandle() {
for (auto *input_var : inputs_) {
input_var->pending_ops_.erase(this);
}
}
void Wait(platform::DeviceContext *waited_dev) override {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
}
void WaitAndMergeCPUTensors() const {
// Wait fetch stream done.
for (auto &ctx : dev_ctx_) {
ctx.second->Wait();
}
std::vector<const LoDTensor *> tensors_ptr;
tensors_ptr.reserve(tensors_.size());
for (auto &t : tensors_) {
tensors_ptr.emplace_back(&t);
}
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
}
protected:
void RunImpl() override {
for (auto *input : inputs_) {
auto *var = static_cast<VarHandle *>(input);
var->generated_op_->Wait(this->dev_ctx_[var->place_]);
}
tensors_.resize(inputs_.size());
auto *var = static_cast<VarHandle *>(inputs_[0]);
auto &var_name = var->name_;
platform::CPUPlace cpu;
auto &scopes = *local_scopes_;
for (size_t i = 0; i < scopes.size(); ++i) {
auto &scope = scopes[i];
auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
if (platform::is_gpu_place(var->place_)) {
TensorCopy(t, cpu, *dev_ctx_[t.place()], &tensors_[i]);
} else {
tensors_[i].ShareDataWith(t);
tensors_[i].set_lod(t.lod());
}
}
}
};
class ParallelExecutorPrivate {
public:
explicit ParallelExecutorPrivate(size_t num_threads)
......@@ -99,19 +44,9 @@ class ParallelExecutorPrivate {
Scope *global_scope_;
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
platform::DeviceContext *CommunicationDevCtx(const platform::Place &place) {
if (platform::is_cpu_place(place) || local_scopes_.size() == 1) {
return const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(place));
} else {
#ifdef PADDLE_WITH_CUDA
return nccl_ctxs_->DevCtx(place);
#else
PADDLE_THROW("Not compiled with CUDA")
#endif
}
}
std::unordered_map<platform::Place, platform::DeviceContext *,
platform::PlaceHash>
fetch_dev_ctxs_;
platform::Place main_place_;
......@@ -119,6 +54,7 @@ class ParallelExecutorPrivate {
std::unordered_map<std::string, std::map<int, VarHandle>>,
platform::PlaceHash>
vars_;
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
std::vector<std::unique_ptr<OpHandleBase>> ops_;
......@@ -183,10 +119,9 @@ class ParallelExecutorPrivate {
size_t version = vars.size();
auto &var = vars[version];
var.version_ = version;
var.generated_op_ = op_handle;
var.name_ = each_var_name;
var.place_ = place;
op_handle->outputs_.emplace_back(&var);
op_handle->AddOutput(&var);
}
}; // namespace framework
......@@ -198,7 +133,11 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
explicit NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places,
const platform::NCCLContextMap &ctxs)
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {}
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
for (auto &p : places_) {
this->dev_ctx_[p] = nccl_ctxs_.DevCtx(p);
}
}
void Wait(platform::DeviceContext *waited_dev) override {
OpHandleBase::Wait(waited_dev);
......@@ -283,6 +222,17 @@ ParallelExecutor::ParallelExecutor(
: member_(new ParallelExecutorPrivate(num_threads)) {
member_->places_ = places;
member_->global_scope_ = scope;
if (platform::is_cpu_place(places[0])) {
member_->fetch_dev_ctxs_[places[0]] = const_cast<platform::DeviceContext *>(
platform::DeviceContextPool::Instance().Get(places[0]));
} else {
for (auto &p : member_->places_) {
member_->fetch_dev_ctxs_[p] =
new platform::CUDADeviceContext(boost::get<platform::CUDAPlace>(p));
}
}
// Step 1. RunStartupProgram and Bcast the params to devs.
Executor exe(places[0]);
exe.Run(startup_program, scope, 0);
......@@ -348,8 +298,7 @@ void ParallelExecutor::ConstructDependencyGraph(
for (auto &each_var_name : var_names) {
VarHandle *var = member_->GetVarHandle(each_var_name, p);
op_handle->inputs_.emplace_back(var);
var->pending_ops_.emplace(op_handle);
op_handle->AddInput(var);
}
var_names = op->OutputArgumentNames();
......@@ -360,11 +309,10 @@ void ParallelExecutor::ConstructDependencyGraph(
if (is_forwarding) {
if (var_names.size() == 1 && var_names[0] == loss_var_name) {
// Insert ScaleCost OpHandle
member_->ops_.emplace_back(new ScaleLossGradOpHandle(
this->member_->local_scopes_.size(), s, p));
op_handle = member_->ops_.back().get();
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
op_handle =
new ScaleLossGradOpHandle(this->member_->local_scopes_.size(), s,
p, member_->nccl_ctxs_->DevCtx(p));
member_->ops_.emplace_back(op_handle);
// FIXME: Currently ScaleLossGradOp only use device_count as scale
// factor. So it does not depend on any other operators.
......@@ -399,15 +347,14 @@ void ParallelExecutor::ConstructDependencyGraph(
continue;
}
auto *prev_grad = &vars[vars.size() - 1];
op_handle->inputs_.emplace_back(prev_grad);
prev_grad->pending_ops_.emplace(op_handle);
op_handle->AddInput(prev_grad);
auto &var = vars[vars.size()];
var.place_ = p;
var.generated_op_ = op_handle;
var.name_ = og;
var.version_ = vars.size() - 1;
op_handle->outputs_.emplace_back(&var);
op_handle->dev_ctx_[p] = member_->CommunicationDevCtx(p);
op_handle->AddOutput(&var);
}
}
}
......@@ -454,12 +401,8 @@ void ParallelExecutor::PolishGraphToSupportDataHazards() const {
}
auto *dep_var = new DummyVarHandle();
dep_var->generated_op_ = read_op;
read_op->outputs_.emplace_back(dep_var);
dep_var->pending_ops_.emplace(write_op);
write_op->inputs_.emplace_back(dep_var);
read_op->AddOutput(dep_var);
write_op->AddInput(dep_var);
member_->dep_vars_.emplace(dep_var);
}
}
......@@ -561,24 +504,21 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
for (size_t i = 0; i < fetch_tensors.size(); ++i) {
auto &var_name = fetch_tensors[i];
auto &vars = fetched_vars[var_name];
fetch_ops.emplace_back();
fetch_ops.emplace_back(&fetched_data, i, &member_->local_scopes_);
FetchOpHandle *op = &fetch_ops.back();
op->data_ = &fetched_data;
op->offset_ = i;
op->local_scopes_ = &member_->local_scopes_;
// FIXME: Use new device context
for (auto &p : member_->places_) {
op->dev_ctx_[p] = member_->nccl_ctxs_->DevCtx(p);
op->dev_ctx_[p] = member_->fetch_dev_ctxs_[p];
}
for (auto *var : vars) {
var->pending_ops_.emplace(op);
op->inputs_.emplace_back(var);
op->AddInput(var);
}
dummy_vars.emplace_back();
auto *var = &dummy_vars.back();
op->outputs_.emplace_back(var);
var->generated_op_ = op;
op->AddOutput(var);
pending_vars[var] = false;
pending_ops.insert({op, op->inputs_.size()});
......