提交 64c139e8 编写于 作者: Y Yu Yang

Using constructor for VarHandle

上级 64bf3df0
......@@ -77,14 +77,9 @@ struct TestBroadcastOpHandle {
local_scopes_[input_scope_idx]->Var("input");
op_handle_.reset(new BroadcastOpHandle(local_scopes_, gpu_list_));
vars_.emplace_back(new VarHandle());
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get());
in_var_handle->place_ = gpu_list_[input_scope_idx];
in_var_handle->name_ = "input";
in_var_handle->version_ = 1;
in_var_handle->scope_idx_ = input_scope_idx;
in_var_handle->generated_op_ = nullptr;
auto* in_var_handle =
new VarHandle(1, input_scope_idx, "input", gpu_list_[input_scope_idx]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
// add dummy var
......@@ -96,12 +91,8 @@ struct TestBroadcastOpHandle {
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
vars_.emplace_back(new VarHandle());
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get());
out_var_handle->place_ = gpu_list_[j];
out_var_handle->name_ = "out";
out_var_handle->version_ = 2;
out_var_handle->scope_idx_ = j;
VarHandle* out_var_handle = new VarHandle(2, j, "out", gpu_list_[j]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
}
......
......@@ -79,13 +79,8 @@ struct TestGatherOpHandle {
// add input
for (size_t j = 0; j < gpu_list_.size(); ++j) {
op_handle_->dev_ctxes_[gpu_list_[j]] = ctxs_[j].get();
vars_.emplace_back(new VarHandle());
VarHandle* in_var_handle = static_cast<VarHandle*>(vars_.back().get());
in_var_handle->place_ = gpu_list_[j];
in_var_handle->name_ = "input";
in_var_handle->version_ = 1;
in_var_handle->scope_idx_ = j;
in_var_handle->generated_op_ = nullptr;
auto* in_var_handle = new VarHandle(1, j, "input", gpu_list_[j]);
vars_.emplace_back(in_var_handle);
op_handle_->AddInput(in_var_handle);
}
......@@ -97,12 +92,9 @@ struct TestGatherOpHandle {
op_handle_->AddInput(in_dummy_var_handle);
// add output
vars_.emplace_back(new VarHandle());
VarHandle* out_var_handle = static_cast<VarHandle*>(vars_.back().get());
out_var_handle->place_ = gpu_list_[input_scope_idx];
out_var_handle->name_ = "out";
out_var_handle->version_ = 2;
out_var_handle->scope_idx_ = input_scope_idx;
auto* out_var_handle =
new VarHandle(2, input_scope_idx, "out", gpu_list_[input_scope_idx]);
vars_.emplace_back(out_var_handle);
op_handle_->AddOutput(out_var_handle);
// add dummy var
......
......@@ -177,13 +177,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto &prev_grad = vars[vars.size() - 1];
op_handle->AddInput(prev_grad.get());
vars.emplace_back(new VarHandle);
auto &var = vars.back();
var->place_ = p;
var->name_ = og;
var->version_ = vars.size() - 1;
op_handle->AddOutput(var.get());
auto var = new VarHandle(vars.size() - 1, i, og, p);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
#else
PADDLE_ENFORCE("Not implemented");
......
......@@ -54,13 +54,8 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto &var_holder = var_holders[each_var_name];
VarHandle *var = nullptr;
if (var_holder.empty()) {
var_holder.emplace_back(new VarHandle);
auto &init_var = var_holder[0];
init_var->place_ = place;
init_var->name_ = each_var_name;
init_var->generated_op_ = nullptr;
init_var->version_ = 0;
var = init_var.get();
var = new VarHandle(0, place_offset, each_var_name, place);
var_holder.emplace_back(var);
} else {
var = var_holder.rbegin()->get();
}
......@@ -73,12 +68,9 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t place_offset) {
auto &vars = graph->vars_[place_offset][each_var_name];
size_t version = vars.size();
vars.emplace_back(new VarHandle());
auto &var = vars.back();
var->version_ = version;
var->name_ = each_var_name;
var->place_ = place;
op_handle->AddOutput(var.get());
auto var = new VarHandle(version, place_offset, each_var_name, place);
vars.emplace_back(var);
op_handle->AddOutput(var);
}
template <typename Callback>
......
......@@ -16,6 +16,7 @@
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/platform/place.h"
......@@ -33,10 +34,10 @@ struct VarHandleBase {
// The operator who generate this variable. nullptr if the variable
// is a root node.
OpHandleBase *generated_op_;
OpHandleBase* generated_op_{nullptr};
// Operators which depend on this variable ready.
std::unordered_set<OpHandleBase *> pending_ops_;
std::unordered_set<OpHandleBase*> pending_ops_;
};
// VarHandle is actually a single version of Runtime Variable.
......@@ -47,6 +48,13 @@ struct VarHandleBase {
struct VarHandle : public VarHandleBase {
std::string DebugString() const override;
VarHandle(size_t version, size_t scope_index, std::string name,
platform::Place place)
: version_(version),
scope_idx_(scope_index),
name_(std::move(name)),
place_(std::move(place)) {}
// version field currently is not used, however, just store the version to
// debug easily.
size_t version_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册