Commit 98c12b1a authored by Xin Pan, committed by Yang Yang(Tony)

Clean up C++ codes. (#10022)

* Privatize OpHandleBase

* Clean up a few private members
Parent 777cb55c
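The theme of this change is encapsulation: fields that were public members of the op-handle structs become private (or protected on the base class), and callers migrate to a narrow accessor surface. A minimal sketch of the pattern, using hypothetical `WidgetBefore`/`WidgetAfter` types rather than anything from the patch:

```cpp
#include <string>
#include <vector>

// Before: state is a public member, so any caller may mutate it directly.
struct WidgetBefore {
  std::vector<std::string> inputs_;
};

// After: state is private; callers go through a small, auditable API.
class WidgetAfter {
 public:
  void AddInput(const std::string &in) { inputs_.push_back(in); }
  const std::vector<std::string> &Inputs() const { return inputs_; }

 private:
  std::vector<std::string> inputs_;  // only WidgetAfter may mutate this
};
```

The hunks below apply exactly this move to BroadcastOpHandle, ComputationOpHandle, FetchOpHandle, GatherOpHandle, NCCLAllReduceOpHandle, OpHandleBase, ScaleLossGradOpHandle, and SendOpHandle.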
@@ -29,9 +29,7 @@ namespace framework {
 namespace details {

 struct BroadcastOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-
+ public:
   BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                     const std::vector<platform::Place> &places);
@@ -41,6 +39,10 @@ struct BroadcastOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
 };

 }  // namespace details
...
@@ -90,7 +90,7 @@ struct TestBroadcastOpHandle {
     op_handle_->AddInput(dummy_var_handle);

     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
+      op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
       vars_.emplace_back(out_var_handle);
       op_handle_->AddOutput(out_var_handle);
...
@@ -28,8 +28,8 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
 void ComputationOpHandle::RunImpl() {
   auto *cur_ctx = dev_ctxes_[place_];
   for (auto *in : inputs_) {
-    bool need_wait =
-        in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
+    bool need_wait = in->generated_op_ &&
+                     in->generated_op_->DeviceContext(place_) != cur_ctx;
     if (need_wait) {
       in->generated_op_->Wait(cur_ctx);
     }
...
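For readers unfamiliar with the rule encoded in `need_wait` above: an input only forces synchronization when the op that produced it ran under a different device context than the one the current op runs on. A simplified sketch with hypothetical `Ctx`/`Producer`/`Input` types (the real types are `platform::DeviceContext`, `OpHandleBase`, and `VarHandleBase`):

```cpp
#include <vector>

struct Ctx {};  // stands in for platform::DeviceContext

struct Producer {
  Ctx *ctx = nullptr;             // context the producing op ran on
  void Wait(Ctx * /*waiter*/) {}  // make `waiter` wait for completion
};

struct Input {
  Producer *generated_op = nullptr;  // null for graph-level inputs (feeds)
};

// Mirrors RunImpl's loop: inputs produced on the same context are already
// ordered by that context, so only cross-context inputs need an explicit wait.
void SyncInputs(Ctx *cur_ctx, const std::vector<Input> &inputs) {
  for (const Input &in : inputs) {
    bool need_wait = in.generated_op && in.generated_op->ctx != cur_ctx;
    if (need_wait) in.generated_op->Wait(cur_ctx);
  }
}
```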
@@ -14,6 +14,9 @@

 #pragma once

+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
@@ -24,10 +27,7 @@ namespace paddle {
 namespace framework {
 namespace details {
 struct ComputationOpHandle : public OpHandleBase {
-  std::unique_ptr<OperatorBase> op_;
-  Scope *scope_;
-  platform::Place place_;
-
+ public:
   ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
                       platform::Place place);
@@ -35,6 +35,11 @@ struct ComputationOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  Scope *scope_;
+  platform::Place place_;
 };

 }  // namespace details
 }  // namespace framework
...
@@ -14,6 +14,9 @@

 #pragma once

+#include <string>
+#include <vector>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/scope.h"
@@ -24,11 +27,7 @@ namespace framework {
 namespace details {

 struct FetchOpHandle : public OpHandleBase {
-  FeedFetchList *data_;
-  size_t offset_;
-  std::vector<Scope *> *local_scopes_;
-  std::vector<LoDTensor> tensors_;
-
+ public:
   FetchOpHandle(FeedFetchList *data, size_t offset,
                 std::vector<Scope *> *local_scopes);
@@ -42,6 +41,12 @@ struct FetchOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  FeedFetchList *data_;
+  size_t offset_;
+  std::vector<Scope *> *local_scopes_;
+  std::vector<LoDTensor> tensors_;
 };

 }  // namespace details
...
@@ -29,9 +29,7 @@ namespace framework {
 namespace details {

 struct GatherOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-
+ public:
   GatherOpHandle(const std::vector<Scope *> &local_scopes,
                  const std::vector<platform::Place> &places);
@@ -41,6 +39,10 @@ struct GatherOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
 };

 }  // namespace details
...
@@ -78,7 +78,7 @@ struct TestGatherOpHandle {
     op_handle_.reset(new GatherOpHandle(local_scopes_, gpu_list_));
     // add input
     for (size_t j = 0; j < gpu_list_.size(); ++j) {
-      op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
+      op_handle_->SetDeviceContext(gpu_list_[j], ctxs_[j].get());
       auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
       vars_.emplace_back(in_var_handle);
       op_handle_->AddInput(in_var_handle);
...
@@ -60,7 +60,8 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
                                                 const platform::Place &p,
                                                 const size_t &i) const {
   auto *op_handle = result->ops_.back().get();
-  op_handle->dev_ctxes_[p] = platform::DeviceContextPool::Instance().Get(p);
+  op_handle->SetDeviceContext(p,
+                              platform::DeviceContextPool::Instance().Get(p));

   auto var_names = op.InputArgumentNames();
...
@@ -27,10 +27,6 @@ namespace framework {
 namespace details {

 struct NCCLAllReduceOpHandle : public OpHandleBase {
-  const std::vector<Scope *> &local_scopes_;
-  const std::vector<platform::Place> &places_;
-  const platform::NCCLContextMap &nccl_ctxs_;
-
   NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
                         const std::vector<platform::Place> &places,
                         const platform::NCCLContextMap &ctxs);
@@ -43,6 +39,11 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
+  const platform::NCCLContextMap &nccl_ctxs_;
 };

 }  // namespace details
...
@@ -27,28 +27,15 @@ namespace details {
 constexpr char kLocalExecScopeName[] = "@LCOAL_SCOPE@";

 class OpHandleBase {
- private:
-  DISABLE_COPY_AND_ASSIGN(OpHandleBase);
-
  public:
-  std::vector<VarHandleBase *> inputs_;
-  std::vector<VarHandleBase *> outputs_;
-  std::unordered_map<platform::Place, platform::DeviceContext *,
-                     platform::PlaceHash>
-      dev_ctxes_;
-
-#ifdef PADDLE_WITH_CUDA
-  std::unordered_map<int, cudaEvent_t> events_;
-#endif
-
   OpHandleBase() {}

+  virtual ~OpHandleBase();
+
   std::string DebugString() const;

   virtual std::string Name() const = 0;

-  virtual ~OpHandleBase();
-
   void Run(bool use_event);

   virtual void Wait(platform::DeviceContext *waited_dev);
@@ -61,6 +48,18 @@ class OpHandleBase {
   // will likely block other computations.
   virtual bool IsMultiDeviceTransfer() { return false; }

+  const platform::DeviceContext *DeviceContext(platform::Place place) {
+    return dev_ctxes_[place];
+  }
+
+  void SetDeviceContext(platform::Place place, platform::DeviceContext *ctx_) {
+    dev_ctxes_[place] = ctx_;
+  }
+
+  const std::vector<VarHandleBase *> &Inputs() const { return inputs_; }
+
+  const std::vector<VarHandleBase *> &Outputs() const { return outputs_; }
+
  protected:
   void RunAndRecordEvent(const std::function<void()> &callback);
@@ -68,6 +67,18 @@ class OpHandleBase {
                         const std::function<void()> &callback);

   virtual void RunImpl() = 0;
+
+  std::vector<VarHandleBase *> inputs_;
+  std::vector<VarHandleBase *> outputs_;
+  std::unordered_map<platform::Place, platform::DeviceContext *,
+                     platform::PlaceHash>
+      dev_ctxes_;
+
+#ifdef PADDLE_WITH_CUDA
+  std::unordered_map<int, cudaEvent_t> events_;
+#endif
+
+  DISABLE_COPY_AND_ASSIGN(OpHandleBase);
 };

 }  // namespace details
...
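After this hunk, subclasses still reach `inputs_`, `outputs_`, and `dev_ctxes_` directly (the members are now protected), while external code must go through `Inputs()`, `Outputs()`, `DeviceContext()`, and `SetDeviceContext()`. A sketch of what a new handle would look like under this layout; `MyOpHandle` and its `name_` field are hypothetical, not part of the patch:

```cpp
#include <string>
#include <utility>

#include "paddle/fluid/framework/details/op_handle_base.h"

namespace paddle {
namespace framework {
namespace details {

// Hypothetical subclass following the structure this commit establishes:
// public API first, RunImpl and per-op state kept non-public.
struct MyOpHandle : public OpHandleBase {
 public:
  explicit MyOpHandle(std::string name) : name_(std::move(name)) {}

  std::string Name() const override { return name_; }

 protected:
  void RunImpl() override {
    // inputs_, outputs_, and dev_ctxes_ remain visible here because they
    // are protected members of OpHandleBase after this change.
  }

 private:
  std::string name_;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle
```

Test code, as the broadcast and gather test hunks show, now wires contexts with `op_handle->SetDeviceContext(place, ctx)` instead of poking `dev_ctxes_` directly.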
@@ -14,6 +14,8 @@

 #pragma once

+#include <string>
+
 #include "paddle/fluid/framework/details/op_handle_base.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
@@ -23,10 +25,6 @@ namespace framework {
 namespace details {

 struct ScaleLossGradOpHandle : public OpHandleBase {
-  float coeff_;
-  Scope *scope_;
-  platform::Place place_;
-
   ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
                         platform::DeviceContext *context);
@@ -36,6 +34,11 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  float coeff_;
+  Scope *scope_;
+  platform::Place place_;
 };

 }  // namespace details
...
@@ -28,10 +28,6 @@ namespace framework {
 namespace details {

 struct SendOpHandle : public OpHandleBase {
-  std::unique_ptr<OperatorBase> op_;
-  const Scope* local_scope_;
-  const platform::Place& place_;
-
   SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
                const platform::Place& place);
@@ -43,6 +39,11 @@ struct SendOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
+
+ private:
+  std::unique_ptr<OperatorBase> op_;
+  const Scope* local_scope_;
+  const platform::Place& place_;
 };

 }  // namespace details
...
@@ -117,12 +117,12 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
     std::string op_name = "op_" + std::to_string(op_id++);
     sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
          << std::endl;
-    for (auto in : op->inputs_) {
+    for (auto in : op->Inputs()) {
       std::string var_name = "var_" + std::to_string(vars[in]);
       sout << var_name << " -> " << op_name << std::endl;
     }
-    for (auto out : op->outputs_) {
+    for (auto out : op->Outputs()) {
       std::string var_name = "var_" + std::to_string(vars[out]);
       sout << op_name << " -> " << var_name << std::endl;
     }
@@ -133,7 +133,7 @@ void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {

 void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
   for (auto &op : graph->ops_) {
-    if (!op->outputs_.empty()) {
+    if (!op->Outputs().empty()) {
       continue;
     }
     auto *dummy_leaf = new DummyVarHandle();
...
@@ -53,7 +53,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   };

   auto InsertPendingOp = [&pending_ops](OpHandleBase &op_instance) {
-    pending_ops.insert({&op_instance, op_instance.inputs_.size()});
+    pending_ops.insert({&op_instance, op_instance.Inputs().size()});
   };

   // Transform SSAGraph to pending_ops & pending_vars
@@ -69,7 +69,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
   }

   for (auto &op : graph_->ops_) {
-    if (op->inputs_.empty()) {  // Special case, Op has no input.
+    if (op->Inputs().empty()) {  // Special case, Op has no input.
       ready_ops.insert(op.get());
     } else {
       InsertPendingOp(*op);
@@ -99,7 +99,7 @@ FeedFetchList ThreadedSSAGraphExecutor::Run(
     fetch_ops.emplace_back(op);

     for (auto &p : places_) {
-      op->dev_ctxes_[p] = fetch_ctxs_.Get(p);
+      op->SetDeviceContext(p, fetch_ctxs_.Get(p));
     }

     for (auto *var : vars) {
@@ -180,7 +180,7 @@ void ThreadedSSAGraphExecutor::RunOp(
     op->Run(use_event_);
     VLOG(10) << op << " " << op->Name() << " Done ";
     running_ops_--;
-    ready_var_q->Extend(op->outputs_);
+    ready_var_q->Extend(op->Outputs());
     VLOG(10) << op << " " << op->Name() << "Signal posted";
   } catch (platform::EnforceNotMet ex) {
     exception_.reset(new platform::EnforceNotMet(ex));
...
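The executor hunks above all feed one scheduling rule: an op is pending on `Inputs().size()` variables, ops with no inputs start ready, and a finished op releases its `Outputs()` to downstream consumers. A toy, single-threaded model of that dependency counting (hypothetical `Op` type; the real executor is multi-threaded and uses a blocking queue):

```cpp
#include <cstddef>
#include <queue>
#include <unordered_map>
#include <vector>

struct Op {
  std::vector<int> inputs;   // var ids this op consumes
  std::vector<int> outputs;  // var ids this op produces
};

void RunAll(std::vector<Op> &ops) {
  std::unordered_map<Op *, size_t> pending;            // op -> inputs missing
  std::unordered_map<int, std::vector<Op *>> readers;  // var -> consumers
  std::queue<Op *> ready;

  for (Op &op : ops) {
    if (op.inputs.empty()) {
      ready.push(&op);  // special case: op with no input is ready at once
    } else {
      pending[&op] = op.inputs.size();
      for (int v : op.inputs) readers[v].push_back(&op);
    }
  }

  while (!ready.empty()) {
    Op *op = ready.front();
    ready.pop();
    // ... run *op here ...
    for (int v : op->outputs) {    // each output becomes a ready var
      for (Op *r : readers[v]) {
        if (--pending[r] == 0) ready.push(r);  // all inputs have arrived
      }
    }
  }
}
```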