Unverified commit 7cd4dd7c, authored by gongweibao, committed by GitHub

Hide VarHandle data members behind accessor methods (version(), scope_idx(), name(), place()). (#15382)

Parent commit: 236201c2
...@@ -82,13 +82,13 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl( ...@@ -82,13 +82,13 @@ std::unique_ptr<ir::Graph> AllReduceDepsPass::ApplyImpl(
PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error", PADDLE_ENFORCE(i0 != nullptr && i1 != nullptr, "%s convert to %s error",
op1->DebugString(), op2->DebugString()); op1->DebugString(), op2->DebugString());
auto l_it = vars.find(i0->name_); auto l_it = vars.find(i0->name());
auto r_it = vars.find(i1->name_); auto r_it = vars.find(i1->name());
if (l_it->second < r_it->second) return true; if (l_it->second < r_it->second) return true;
if (l_it->second == r_it->second) { if (l_it->second == r_it->second) {
return i0->name_ < i1->name_; return i0->name() < i1->name();
} }
return false; return false;
......
...@@ -70,9 +70,9 @@ void AllReduceOpHandle::RunImpl() { ...@@ -70,9 +70,9 @@ void AllReduceOpHandle::RunImpl() {
auto *s = local_scopes_[i]; auto *s = local_scopes_[i];
auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &lod_tensor = auto &lod_tensor =
local_scope.FindVar(in_var_handles[i]->name_)->Get<LoDTensor>(); local_scope.FindVar(in_var_handles[i]->name())->Get<LoDTensor>();
lod_tensors.emplace_back(&lod_tensor); lod_tensors.emplace_back(&lod_tensor);
PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_, PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal."); "The name of input and output should be equal.");
} }
...@@ -134,7 +134,7 @@ void AllReduceOpHandle::RunImpl() { ...@@ -134,7 +134,7 @@ void AllReduceOpHandle::RunImpl() {
auto &trg = *this->local_scopes_[0] auto &trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName) ->FindVar(kLocalExecScopeName)
->Get<Scope *>() ->Get<Scope *>()
->FindVar(out_var_handles[0]->name_) ->FindVar(out_var_handles[0]->name())
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
// Reduce All Tensor to trg in CPU // Reduce All Tensor to trg in CPU
...@@ -145,7 +145,7 @@ void AllReduceOpHandle::RunImpl() { ...@@ -145,7 +145,7 @@ void AllReduceOpHandle::RunImpl() {
auto &scope = auto &scope =
*local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>(); *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto &p = places_[i]; auto &p = places_[i];
auto *var = scope.FindVar(out_var_handles[i]->name_); auto *var = scope.FindVar(out_var_handles[i]->name());
auto *dev_ctx = dev_ctxes_.at(p); auto *dev_ctx = dev_ctxes_.at(p);
RunAndRecordEvent(p, [&trg, var, dev_ctx, p] { RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
......
...@@ -56,11 +56,11 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -56,11 +56,11 @@ void BroadcastOpHandle::BroadcastOneVar(
const std::vector<VarHandle *> &out_var_handles, const std::vector<VarHandle *> &out_var_handles,
const std::vector<const Scope *> &var_scopes) { const std::vector<const Scope *> &var_scopes) {
auto *in_var = auto *in_var =
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
if (UNLIKELY(!in_tensor.IsInitialized())) { if (UNLIKELY(!in_tensor.IsInitialized())) {
VLOG(3) << "in var " << in_var_handle.name_ << "not inited, return!"; VLOG(3) << "in var " << in_var_handle.name() << "not inited, return!";
return; return;
} }
...@@ -71,9 +71,9 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -71,9 +71,9 @@ void BroadcastOpHandle::BroadcastOneVar(
if (out_var_handle->IsTheSameVar(in_var_handle)) { if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue; continue;
} }
auto &out_p = out_var_handle->place_; auto &out_p = out_var_handle->place();
auto *out_var = var_scopes.at(out_var_handle->scope_idx_) auto *out_var = var_scopes.at(out_var_handle->scope_idx())
->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name());
RunAndRecordEvent(out_p, [in_tensor, out_var] { RunAndRecordEvent(out_p, [in_tensor, out_var] {
paddle::framework::TensorCopy( paddle::framework::TensorCopy(
...@@ -91,11 +91,11 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -91,11 +91,11 @@ void BroadcastOpHandle::BroadcastOneVar(
size_t numel = static_cast<size_t>(in_tensor.numel()); size_t numel = static_cast<size_t>(in_tensor.numel());
for (auto out_var_handle : out_var_handles) { for (auto out_var_handle : out_var_handles) {
Variable *out_var = var_scopes.at(out_var_handle->scope_idx_) Variable *out_var = var_scopes.at(out_var_handle->scope_idx())
->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name());
int dst_id = int dst_id =
boost::get<platform::CUDAPlace>(out_var_handle->place_).device; boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
auto &nccl_ctx = nccl_ctxs_->at(dst_id); auto &nccl_ctx = nccl_ctxs_->at(dst_id);
...@@ -106,7 +106,7 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -106,7 +106,7 @@ void BroadcastOpHandle::BroadcastOneVar(
} else { } else {
send_recv_buffer = VariableVisitor::GetMutableTensor(out_var) send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
.Resize(in_tensor.dims()) .Resize(in_tensor.dims())
.mutable_data(out_var_handle->place_); .mutable_data(out_var_handle->place());
} }
broadcast_calls.emplace_back( broadcast_calls.emplace_back(
...@@ -126,11 +126,11 @@ void BroadcastOpHandle::BroadcastOneVar( ...@@ -126,11 +126,11 @@ void BroadcastOpHandle::BroadcastOneVar(
} }
if (!out_handle->IsTheSameVar(in_var_handle)) { if (!out_handle->IsTheSameVar(in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle.scope_idx_) auto out_var = var_scopes.at(in_var_handle.scope_idx())
->FindVar(out_var_handles[0]->name_); ->FindVar(out_var_handles[0]->name());
paddle::framework::TensorCopy( paddle::framework::TensorCopy(
in_tensor, in_var_handle.place_, in_tensor, in_var_handle.place(),
*(dev_ctxes_.at(in_var_handle.place_)), *(dev_ctxes_.at(in_var_handle.place())),
&VariableVisitor::GetMutableTensor(out_var)); &VariableVisitor::GetMutableTensor(out_var));
} }
}); });
...@@ -148,7 +148,7 @@ void BroadcastOpHandle::InitOutputValue( ...@@ -148,7 +148,7 @@ void BroadcastOpHandle::InitOutputValue(
var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>()); var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
} }
auto *in_var = auto *in_var =
var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_); var_scopes.at(in_var_handle.scope_idx())->FindVar(in_var_handle.name());
Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var); Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
...@@ -158,9 +158,9 @@ void BroadcastOpHandle::InitOutputValue( ...@@ -158,9 +158,9 @@ void BroadcastOpHandle::InitOutputValue(
if (out_var_handle->IsTheSameVar(in_var_handle)) { if (out_var_handle->IsTheSameVar(in_var_handle)) {
continue; continue;
} }
auto t_out_p = out_var_handle->place_; auto t_out_p = out_var_handle->place();
auto *out_var = var_scopes.at(out_var_handle->scope_idx_) auto *out_var = var_scopes.at(out_var_handle->scope_idx())
->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name());
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
if (is_gpu_place(in_tensor.place())) { if (is_gpu_place(in_tensor.place())) {
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p), PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
......
...@@ -100,13 +100,13 @@ void DataBalanceOpHandle::RunImpl() { ...@@ -100,13 +100,13 @@ void DataBalanceOpHandle::RunImpl() {
std::vector<std::vector<LoDTensor *>> lod_tensors(data_num); std::vector<std::vector<LoDTensor *>> lod_tensors(data_num);
std::vector<int> device_sizes; std::vector<int> device_sizes;
for (int i = 0; i < static_cast<int>(in_var_handles.size()); ++i) { for (int i = 0; i < static_cast<int>(in_var_handles.size()); ++i) {
PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_, PADDLE_ENFORCE_EQ(in_var_handles[i]->name(), out_var_handles[i]->name(),
"The name of input and output should be equal."); "The name of input and output should be equal.");
int place_idx = i / data_num; int place_idx = i / data_num;
int data_idx = i % data_num; int data_idx = i % data_num;
auto *local_scope = auto *local_scope =
local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get<Scope *>(); local_scopes_[place_idx]->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name_); auto *tensor_var = local_scope->FindVar(in_var_handles[i]->name());
PADDLE_ENFORCE(tensor_var->IsType<LoDTensor>()); PADDLE_ENFORCE(tensor_var->IsType<LoDTensor>());
auto *tensor = tensor_var->GetMutable<LoDTensor>(); auto *tensor = tensor_var->GetMutable<LoDTensor>();
lod_tensors[data_idx].push_back(tensor); lod_tensors[data_idx].push_back(tensor);
......
...@@ -52,12 +52,12 @@ void FetchOpHandle::RunImpl() { ...@@ -52,12 +52,12 @@ void FetchOpHandle::RunImpl() {
for (size_t i = 0; i < inputs_.size(); ++i) { for (size_t i = 0; i < inputs_.size(); ++i) {
auto *var_handle = static_cast<VarHandle *>(inputs_[i]); auto *var_handle = static_cast<VarHandle *>(inputs_[i]);
auto &scope = scopes.at(var_handle->scope_idx_); auto &scope = scopes.at(var_handle->scope_idx());
auto *var = scope->FindVar(kLocalExecScopeName) auto *var = scope->FindVar(kLocalExecScopeName)
->Get<Scope *>() ->Get<Scope *>()
->FindVar(var_handle->name_); ->FindVar(var_handle->name());
PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope", PADDLE_ENFORCE_NOT_NULL(var, "Cannot find variable %s in execution scope",
var_handle->name_); var_handle->name());
auto &t = var->Get<framework::LoDTensor>(); auto &t = var->Get<framework::LoDTensor>();
if (platform::is_gpu_place(t.place())) { if (platform::is_gpu_place(t.place())) {
......
...@@ -29,14 +29,14 @@ void FuseVarsOpHandle::RunImpl() { ...@@ -29,14 +29,14 @@ void FuseVarsOpHandle::RunImpl() {
auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto out_var_handle = out_var_handles[0]; auto out_var_handle = out_var_handles[0];
auto out_var = scope->Var(out_var_handle->name_); auto out_var = scope->Var(out_var_handle->name());
auto out_tensor = out_var->GetMutable<LoDTensor>(); auto out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_); out_tensor->Resize({total_numel_}).mutable_data(this->place_, type_);
int64_t s = 0; int64_t s = 0;
for (size_t i = 1; i < out_var_handles.size(); ++i) { for (size_t i = 1; i < out_var_handles.size(); ++i) {
auto out_name = out_var_handles[i]->name_; auto out_name = out_var_handles[i]->name();
auto out_t = scope->Var(out_name)->GetMutable<LoDTensor>(); auto out_t = scope->Var(out_name)->GetMutable<LoDTensor>();
auto numel = this->inputs_numel_.at(out_name); auto numel = this->inputs_numel_.at(out_name);
out_t->ShareDataWith(out_tensor->Slice(s, s + numel)); out_t->ShareDataWith(out_tensor->Slice(s, s + numel));
......
...@@ -49,7 +49,7 @@ void GatherOpHandle::RunImpl() { ...@@ -49,7 +49,7 @@ void GatherOpHandle::RunImpl() {
auto in_0_handle = in_var_handles[0]; auto in_0_handle = in_var_handles[0];
auto pre_in_var = auto pre_in_var =
var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name());
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(), PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
...@@ -65,7 +65,7 @@ void GatherOpHandle::RunImpl() { ...@@ -65,7 +65,7 @@ void GatherOpHandle::RunImpl() {
// Gather the inputs // Gather the inputs
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
auto *in_var = auto *in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); var_scopes.at(in_handle->scope_idx())->FindVar(in_handle->name());
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var); VariableVisitor::EnforceShapeAndDTypeEQ(*in_var, *pre_in_var);
...@@ -77,7 +77,7 @@ void GatherOpHandle::RunImpl() { ...@@ -77,7 +77,7 @@ void GatherOpHandle::RunImpl() {
} }
// NOTE: The Places of all input tensor must be all on CPU or all on GPU. // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
platform::Place t_out_p = out_var_handle->place_; platform::Place t_out_p = out_var_handle->place();
if (platform::is_gpu_place(pre_in_value.place())) { if (platform::is_gpu_place(pre_in_value.place())) {
PADDLE_ENFORCE(platform::is_gpu_place(t_out_p), PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
"Places of input and output must be all on GPU."); "Places of input and output must be all on GPU.");
...@@ -85,8 +85,8 @@ void GatherOpHandle::RunImpl() { ...@@ -85,8 +85,8 @@ void GatherOpHandle::RunImpl() {
t_out_p = platform::CPUPlace(); t_out_p = platform::CPUPlace();
} }
auto out_var = auto out_var = var_scopes.at(out_var_handle->scope_idx())
var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name());
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
auto out_value = out_var->GetMutable<framework::SelectedRows>(); auto out_value = out_var->GetMutable<framework::SelectedRows>();
out_value->set_height(pre_in_value.height()); out_value->set_height(pre_in_value.height());
...@@ -99,8 +99,8 @@ void GatherOpHandle::RunImpl() { ...@@ -99,8 +99,8 @@ void GatherOpHandle::RunImpl() {
Tensor *out_tensor = out_value->mutable_value(); Tensor *out_tensor = out_value->mutable_value();
// copy // copy
auto dev_ctx = dev_ctxes_.at(out_var_handle->place_); auto dev_ctx = dev_ctxes_.at(out_var_handle->place());
RunAndRecordEvent(out_var_handle->place_, [in_tensors, out_tensor, &dev_ctx, RunAndRecordEvent(out_var_handle->place(), [in_tensors, out_tensor, &dev_ctx,
t_out_p] { t_out_p] {
int s = 0, e = 0; int s = 0, e = 0;
for (size_t j = 0; j < in_tensors.size(); ++j) { for (size_t j = 0; j < in_tensors.size(); ++j) {
......
...@@ -33,7 +33,7 @@ static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) { ...@@ -33,7 +33,7 @@ static ComputationOpHandle* FindNextComputationOpHandle(VarHandle* var_in) {
queue.pop(); queue.pop();
for (auto* op : var->PendingOps()) { for (auto* op : var->PendingOps()) {
auto* compute_op = dynamic_cast<ComputationOpHandle*>(op); auto* compute_op = dynamic_cast<ComputationOpHandle*>(op);
if (compute_op != nullptr && compute_op->GetPlace() == var_in->place_) { if (compute_op != nullptr && compute_op->GetPlace() == var_in->place()) {
return compute_op; return compute_op;
} }
for (auto* out_var : op->Outputs()) { for (auto* out_var : op->Outputs()) {
...@@ -64,7 +64,7 @@ std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl( ...@@ -64,7 +64,7 @@ std::unique_ptr<ir::Graph> MemoryEarlyDeletePass::ApplyImpl(
for (auto& var : vars) { for (auto& var : vars) {
auto* var_handle = dynamic_cast<VarHandle*>(var); auto* var_handle = dynamic_cast<VarHandle*>(var);
auto var_name = var->Node()->Name(); auto var_name = var->Node()->Name();
auto& var_place = var_handle->place_; auto& var_place = var_handle->place();
if (unlived_vars.count(var_name) == 0) continue; if (unlived_vars.count(var_name) == 0) continue;
if (!unlived_vars[var_name].empty()) { if (!unlived_vars[var_name].empty()) {
if (compute_op != nullptr && if (compute_op != nullptr &&
......
...@@ -52,11 +52,11 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph, ...@@ -52,11 +52,11 @@ void GraphvizSSAGraphPrinter::Print(const ir::Graph &graph,
vars[var_ptr] = cur_var_id; vars[var_ptr] = cur_var_id;
if (var_handle_ptr) { if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_ sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name()
<< "\\n" << "\\n"
<< var_handle_ptr->place_ << "\\n" << var_handle_ptr->place() << "\\n"
<< "scope: " << var_handle_ptr->scope_idx_ << "\\n" << "scope: " << var_handle_ptr->scope_idx() << "\\n"
<< "v" << var_handle_ptr->version_ << "\"]" << std::endl; << "v" << var_handle_ptr->version() << "\"]" << std::endl;
} else if (dummy_ptr) { } else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl; sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
} }
......
...@@ -60,8 +60,8 @@ void ReduceOpHandle::GatherSelectedRows( ...@@ -60,8 +60,8 @@ void ReduceOpHandle::GatherSelectedRows(
*CollectiveContext::GetInstance(); *CollectiveContext::GetInstance();
// 1. gather local selected rows, merge them // 1. gather local selected rows, merge them
std::string gathered_var_name = out_var_handle->name_ + "_gathered_tmp"; std::string gathered_var_name = out_var_handle->name() + "_gathered_tmp";
auto scope = local_scopes_.at(out_var_handle->scope_idx_); auto scope = local_scopes_.at(out_var_handle->scope_idx());
auto gathered_var_mid = scope->Var(gathered_var_name); auto gathered_var_mid = scope->Var(gathered_var_name);
auto gathered_select_rows = auto gathered_select_rows =
gathered_var_mid->GetMutable<framework::SelectedRows>(); gathered_var_mid->GetMutable<framework::SelectedRows>();
...@@ -73,7 +73,7 @@ void ReduceOpHandle::GatherSelectedRows( ...@@ -73,7 +73,7 @@ void ReduceOpHandle::GatherSelectedRows(
// merge them // merge them
auto merged_dev_ctx = dynamic_cast<DevCtx *>(dev_ctxes.at(out_place)); auto merged_dev_ctx = dynamic_cast<DevCtx *>(dev_ctxes.at(out_place));
std::string merged_var_name = std::string merged_var_name =
GetRemoteVarName(out_var_handle->name_, collective_context.trainer_id_); GetRemoteVarName(out_var_handle->name(), collective_context.trainer_id_);
auto merged_select_rows = auto merged_select_rows =
scope->Var(merged_var_name)->GetMutable<SelectedRows>(); scope->Var(merged_var_name)->GetMutable<SelectedRows>();
operators::math::scatter::MergeAdd<DevCtx, DataType> merge_func; operators::math::scatter::MergeAdd<DevCtx, DataType> merge_func;
...@@ -101,7 +101,7 @@ void ReduceOpHandle::GatherSelectedRows( ...@@ -101,7 +101,7 @@ void ReduceOpHandle::GatherSelectedRows(
operators::distributed::RemoteVar var; operators::distributed::RemoteVar var;
var.trainer_id_ = i; var.trainer_id_ = i;
var.var_name_ = GetRemoteVarName(out_var_handle->name_, i); var.var_name_ = GetRemoteVarName(out_var_handle->name(), i);
var.ep_ = collective_context.endpoints_[i]; var.ep_ = collective_context.endpoints_[i];
vars.push_back(var); vars.push_back(var);
...@@ -166,7 +166,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -166,7 +166,7 @@ void ReduceOpHandle::RunImpl() {
} }
auto pre_in_var = auto pre_in_var =
var_scopes.at(in_0_handle->scope_idx_)->FindVar(in_0_handle->name_); var_scopes.at(in_0_handle->scope_idx())->FindVar(in_0_handle->name());
PADDLE_ENFORCE_NOT_NULL(pre_in_var); PADDLE_ENFORCE_NOT_NULL(pre_in_var);
// Wait input done, this Wait is asynchronous operation // Wait input done, this Wait is asynchronous operation
...@@ -175,15 +175,15 @@ void ReduceOpHandle::RunImpl() { ...@@ -175,15 +175,15 @@ void ReduceOpHandle::RunImpl() {
// NOTE: The Places of all input tensor must be all on CPU or all on GPU. // NOTE: The Places of all input tensor must be all on CPU or all on GPU.
std::vector<platform::Place> in_places; // used to get dev_ctx std::vector<platform::Place> in_places; // used to get dev_ctx
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
in_places.emplace_back(in_handle->place_); in_places.emplace_back(in_handle->place());
auto in_var = auto in_var =
var_scopes.at(in_handle->scope_idx_)->FindVar(in_handle->name_); var_scopes.at(in_handle->scope_idx())->FindVar(in_handle->name());
PADDLE_ENFORCE_NOT_NULL(in_var); PADDLE_ENFORCE_NOT_NULL(in_var);
VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var); VariableVisitor::EnforceShapeAndDTypeEQ(*pre_in_var, *in_var);
} }
auto out_var = auto out_var = var_scopes.at(out_var_handle->scope_idx())
var_scopes.at(out_var_handle->scope_idx_)->FindVar(out_var_handle->name_); ->FindVar(out_var_handle->name());
PADDLE_ENFORCE_NOT_NULL(out_var); PADDLE_ENFORCE_NOT_NULL(out_var);
// NOTE: The tensors' Place of input and output must be all on GPU or all on // NOTE: The tensors' Place of input and output must be all on GPU or all on
...@@ -191,9 +191,9 @@ void ReduceOpHandle::RunImpl() { ...@@ -191,9 +191,9 @@ void ReduceOpHandle::RunImpl() {
auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place(); auto in_p = VariableVisitor::GetMutableTensor(pre_in_var).place();
platform::Place t_out_p; platform::Place t_out_p;
if (platform::is_gpu_place(in_p)) { if (platform::is_gpu_place(in_p)) {
PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place_), PADDLE_ENFORCE(platform::is_gpu_place(out_var_handle->place()),
"Places of input and output must be all on GPU."); "Places of input and output must be all on GPU.");
t_out_p = out_var_handle->place_; t_out_p = out_var_handle->place();
} else { } else {
t_out_p = platform::CPUPlace(); t_out_p = platform::CPUPlace();
} }
...@@ -253,7 +253,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -253,7 +253,7 @@ void ReduceOpHandle::RunImpl() {
auto &reduce_sum_trg = *this->local_scopes_[0] auto &reduce_sum_trg = *this->local_scopes_[0]
->FindVar(kLocalExecScopeName) ->FindVar(kLocalExecScopeName)
->Get<Scope *>() ->Get<Scope *>()
->FindVar(out_var_handle->name_) ->FindVar(out_var_handle->name())
->GetMutable<framework::LoDTensor>(); ->GetMutable<framework::LoDTensor>();
ReduceLoDTensor func(lod_tensors, &reduce_sum_trg); ReduceLoDTensor func(lod_tensors, &reduce_sum_trg);
VisitDataType(lod_tensors[0]->type(), func); VisitDataType(lod_tensors[0]->type(), func);
...@@ -269,9 +269,9 @@ void ReduceOpHandle::RunImpl() { ...@@ -269,9 +269,9 @@ void ReduceOpHandle::RunImpl() {
auto pre_in = pre_in_var->Get<framework::LoDTensor>(); auto pre_in = pre_in_var->Get<framework::LoDTensor>();
VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var); VariableVisitor::ShareDimsAndLoD(*pre_in_var, out_var);
VariableVisitor::GetMutableTensor(out_var).mutable_data( VariableVisitor::GetMutableTensor(out_var).mutable_data(
out_var_handle->place_, pre_in.type()); out_var_handle->place(), pre_in.type());
auto out_p = out_var_handle->place_; auto out_p = out_var_handle->place();
int root_id = boost::get<platform::CUDAPlace>(out_p).device; int root_id = boost::get<platform::CUDAPlace>(out_p).device;
std::vector<std::function<void()>> all_reduce_calls; std::vector<std::function<void()>> all_reduce_calls;
for (size_t i = 0; i < var_scopes.size(); ++i) { for (size_t i = 0; i < var_scopes.size(); ++i) {
...@@ -286,7 +286,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -286,7 +286,7 @@ void ReduceOpHandle::RunImpl() {
if (root_id == dev_id) { if (root_id == dev_id) {
recvbuffer = recvbuffer =
out_var->GetMutable<framework::LoDTensor>()->mutable_data( out_var->GetMutable<framework::LoDTensor>()->mutable_data(
out_var_handle->place_); out_var_handle->place());
} }
int type = platform::ToNCCLDataType(lod_tensor.type()); int type = platform::ToNCCLDataType(lod_tensor.type());
...@@ -320,8 +320,8 @@ std::vector<const T *> ReduceOpHandle::GetInputValues( ...@@ -320,8 +320,8 @@ std::vector<const T *> ReduceOpHandle::GetInputValues(
const std::vector<const Scope *> &var_scopes) const { const std::vector<const Scope *> &var_scopes) const {
std::vector<const T *> in_selected_rows; std::vector<const T *> in_selected_rows;
for (auto *in_handle : in_var_handles) { for (auto *in_handle : in_var_handles) {
auto &in_sr = var_scopes.at(in_handle->scope_idx_) auto &in_sr = var_scopes.at(in_handle->scope_idx())
->FindVar(in_handle->name_) ->FindVar(in_handle->name())
->Get<T>(); ->Get<T>();
in_selected_rows.emplace_back(&in_sr); in_selected_rows.emplace_back(&in_sr);
} }
......
...@@ -30,7 +30,7 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc, ...@@ -30,7 +30,7 @@ RPCOpHandle::RPCOpHandle(ir::Node *node, const framework::OpDesc &op_desc,
void RPCOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place();
if (ir::IsControlDepVar(*in->Node())) { if (ir::IsControlDepVar(*in->Node())) {
continue; continue;
} }
......
...@@ -68,7 +68,7 @@ struct ScaleLossGradFunctor { ...@@ -68,7 +68,7 @@ struct ScaleLossGradFunctor {
void ScaleLossGradOpHandle::RunImpl() { void ScaleLossGradOpHandle::RunImpl() {
// Doesn't wait any event // Doesn't wait any event
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_; std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name();
auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(); auto &local_scope = *scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>(); auto *tensor = local_scope.FindVar(var_name)->GetMutable<LoDTensor>();
......
...@@ -111,15 +111,22 @@ struct VarHandle : public VarHandleBase { ...@@ -111,15 +111,22 @@ struct VarHandle : public VarHandleBase {
// version field currently is not used, however, just store the version to // version field currently is not used, however, just store the version to
// debug easily. // debug easily.
private:
size_t version_; size_t version_;
size_t scope_idx_; size_t scope_idx_;
std::string name_; std::string name_;
platform::Place place_; platform::Place place_;
public:
bool IsTheSameVar(const VarHandle& o) const { bool IsTheSameVar(const VarHandle& o) const {
return o.generated_op_ == generated_op_ && o.name_ == name_ && return o.generated_op_ == generated_op_ && o.name_ == name_ &&
o.scope_idx_ == scope_idx_; o.scope_idx_ == scope_idx_;
} }
size_t version() const { return version_; }
size_t scope_idx() const { return scope_idx_; }
const std::string& name() const { return name_; }
const platform::Place& place() const { return place_; }
}; };
// Dummy Variable. It is used to represent dependencies between operators // Dummy Variable. It is used to represent dependencies between operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册