提交 64c139e8 编写于 作者: Y Yu Yang

Using constructor for VarHandle

上级 64bf3df0
...@@ -77,14 +77,9 @@ struct TestBroadcastOpHandle { ...@@ -77,14 +77,9 @@ struct TestBroadcastOpHandle {
local_scopes_[input_scope_idx]->Var("input"); local_scopes_[input_scope_idx]->Var("input");
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_)); op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
auto* in_var_handle =
vars_.emplace_back(new VarHandle()); new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get()); vars_.emplace_back(in_var_handle);
in_var_handle->place_ = gpu_list_[input_scope_idx];
in_var_handle->name_ = "input";
in_var_handle->version_ = 1;
in_var_handle->scope_idx_ = input_scope_idx;
in_var_handle->generated_op_ = nullptr;
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
// add dummy var // add dummy var
...@@ -96,12 +91,8 @@ struct TestBroadcastOpHandle { ...@@ -96,12 +91,8 @@ struct TestBroadcastOpHandle {
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get(); op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
vars_.emplace_back(new VarHandle()); VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get()); vars_.emplace_back(out_var_handle);
out_var_handle->place_ = gpu_list_[j];
out_var_handle->name_ = "out";
out_var_handle->version_ = 2;
out_var_handle->scope_idx_ = j;
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
} }
......
...@@ -79,13 +79,8 @@ struct TestGatherOpHandle { ...@@ -79,13 +79,8 @@ struct TestGatherOpHandle {
// add input // add input
for (size_t j = 0; j < gpu_list_.size(); ++j) { for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get(); op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
vars_.emplace_back(new VarHandle()); auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get()); vars_.emplace_back(in_var_handle);
in_var_handle->place_ = gpu_list_[j];
in_var_handle->name_ = "input";
in_var_handle->version_ = 1;
in_var_handle->scope_idx_ = j;
in_var_handle->generated_op_ = nullptr;
op_handle_->AddInput(in_var_handle); op_handle_->AddInput(in_var_handle);
} }
...@@ -97,12 +92,9 @@ struct TestGatherOpHandle { ...@@ -97,12 +92,9 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle); op_handle_->AddInput(in_dummy_var_handle);
// add output // add output
vars_.emplace_back(new VarHandle()); auto* out_var_handle =
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get()); new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
out_var_handle->place_ = gpu_list_[input_scope_idx]; vars_.emplace_back(out_var_handle);
out_var_handle->name_ = "out";
out_var_handle->version_ = 2;
out_var_handle->scope_idx_ = input_scope_idx;
op_handle_->AddOutput(out_var_handle); op_handle_->AddOutput(out_var_handle);
// add dummy var // add dummy var
......
...@@ -177,13 +177,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -177,13 +177,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto &prev_grad = vars[vars.size() - 1]; auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad.get()); op_handle->AddInput(prev_grad.get());
vars.emplace_back(new VarHandle); auto var = new VarHandle(vars.size() - 1, i, og, p);
auto &var = vars.back(); vars.emplace_back(var);
var->place_ = p; op_handle->AddOutput(var);
var->name_ = og;
var->version_ = vars.size() - 1;
op_handle->AddOutput(var.get());
} }
#else #else
PADDLE_ENFORCE("Not implemented"); PADDLE_ENFORCE("Not implemented");
......
...@@ -54,13 +54,8 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ...@@ -54,13 +54,8 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name]; auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr; VarHandle *var = nullptr;
if (var_holder.empty()) { if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle); var = new VarHandle(0, place_offset, each_var_name, place);
auto &init_var = var_holder[0]; var_holder.emplace_back(var);
init_var->place_ = place;
init_var->name_ = each_var_name;
init_var->generated_op_ = nullptr;
init_var->version_ = 0;
var = init_var.get();
} else { } else {
var = var_holder.rbegin()->get(); var = var_holder.rbegin()->get();
} }
...@@ -73,12 +68,9 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle, ...@@ -73,12 +68,9 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) { size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name]; auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size(); size_t version = vars.size();
vars.emplace_back(new VarHandle()); auto var = new VarHandle(version, place_offset, each_var_name, place);
auto &var = vars.back(); vars.emplace_back(var);
var->version_ = version; op_handle->AddOutput(var);
var->name_ = each_var_name;
var->place_ = place;
op_handle->AddOutput(var.get());
} }
template <typename Callback> template <typename Callback>
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
#include <utility>
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -33,10 +34,10 @@ struct VarHandleBase { ...@@ -33,10 +34,10 @@ struct VarHandleBase {
// The operator who generate this variable. nullptr if the variable // The operator who generate this variable. nullptr if the variable
// is a root node. // is a root node.
OpHandleBase *generated_op_; OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready. // Operators which depend on this variable ready.
std::unordered_set<OpHandleBase *> pending_ops_; std::unordered_set<OpHandleBase*> pending_ops_;
}; };
// VarHandle is actually a single version of Runtime Variable. // VarHandle is actually a single version of Runtime Variable.
...@@ -47,6 +48,13 @@ struct VarHandleBase { ...@@ -47,6 +48,13 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase { struct VarHandle : public VarHandleBase {
std::string DebugString() const override; std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name,
platform::Place place)
: version_(version),
scope_idx_(scope_index),
name_(std::move(name)),
place_(std::move(place)) {}
// version field currently is not used, however, just store the version to // version field currently is not used, however, just store the version to
// debug easily. // debug easily.
size_t version_; size_t version_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册