未验证 提交 345de9a5 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Add Construct event for new ir interpretercore (#55555)

* add interface

* add code

* add code

* add code

* add code

* fix bug

* fix bug

* add var prefix

* add code

* add code

* add code

* fix compile bug

* fix bug

* refine code

* refine code

* refine code

* refine code

* fix bug
上级 984a4cc1
...@@ -45,9 +45,9 @@ class InstructionBase { ...@@ -45,9 +45,9 @@ class InstructionBase {
OpFuncType KernelType() const; OpFuncType KernelType() const;
void SetKernelType(OpFuncType type) { type_ = type; } void SetKernelType(OpFuncType type) { type_ = type; }
int GetStreamPriority() const { return scheduling_priority_; } int GetStreamPriority() const { return stream_priority_; }
void SetStreamPriority(SchedulingPriority scheduling_priority) { void SetStreamPriority(int stream_priority) {
scheduling_priority_ = scheduling_priority; stream_priority_ = stream_priority;
} }
SchedulingPriority GetSchedulingPriority() const { SchedulingPriority GetSchedulingPriority() const {
...@@ -107,22 +107,31 @@ class InstructionBase { ...@@ -107,22 +107,31 @@ class InstructionBase {
std::map<int, int>& GetMutableInplaceBackMap() { return inplace_back_map_; } std::map<int, int>& GetMutableInplaceBackMap() { return inplace_back_map_; }
const std::map<int, int>& GetInplaceBackMap() { return inplace_back_map_; } const std::map<int, int>& GetInplaceBackMap() { return inplace_back_map_; }
const std::unordered_map<ir::Value, std::vector<int>>& Inputs() const { const std::unordered_map<::ir::Value, std::vector<int>>& Inputs() const {
return input_index_; return input_index_;
} }
std::unordered_map<ir::Value, std::vector<int>>& GetMutableInputs() { std::unordered_map<::ir::Value, std::vector<int>>& GetMutableInputs() {
return input_index_; return input_index_;
} }
void SetInputs(const std::unordered_map<ir::Value, std::vector<int>>& inputs); void SetInputs(
const std::unordered_map<::ir::Value, std::vector<int>>& inputs);
const std::unordered_map<ir::Value, std::vector<int>>& Outputs() const { const std::unordered_map<::ir::Value, std::vector<int>>& Outputs() const {
return output_index_; return output_index_;
} }
std::unordered_map<ir::Value, std::vector<int>>& GetMutableOutputs() { std::unordered_map<::ir::Value, std::vector<int>>& GetMutableOutputs() {
return output_index_; return output_index_;
} }
void SetOutputs( void SetOutputs(
const std::unordered_map<ir::Value, std::vector<int>>& outputs); const std::unordered_map<::ir::Value, std::vector<int>>& outputs);
const std::unordered_set<::ir::Value>& NoNeedBuffer() const {
return no_need_buffer_values_;
}
void SetNoNeedBuffer(
const std::unordered_set<::ir::Value>& no_need_buffer_values) {
no_need_buffer_values_ = no_need_buffer_values;
}
virtual void Run() = 0; virtual void Run() = 0;
...@@ -159,9 +168,11 @@ class InstructionBase { ...@@ -159,9 +168,11 @@ class InstructionBase {
std::map<int, int> inplace_back_map_; std::map<int, int> inplace_back_map_;
std::unordered_map<ir::Value, std::vector<int>> input_index_; std::unordered_map<::ir::Value, std::vector<int>> input_index_;
std::unordered_map<::ir::Value, std::vector<int>> output_index_;
std::unordered_map<ir::Value, std::vector<int>> output_index_; std::unordered_set<::ir::Value> no_need_buffer_values_;
}; };
} // namespace framework } // namespace framework
......
...@@ -15,12 +15,15 @@ ...@@ -15,12 +15,15 @@
#include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/interface/infermeta.h" #include "paddle/fluid/ir/interface/infermeta.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h" #include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/core/type_defs.h" #include "paddle/phi/core/type_defs.h"
...@@ -32,6 +35,77 @@ ...@@ -32,6 +35,77 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
platform::DeviceContext* ParseDeviceContext(
ir::Operation* op,
platform::DeviceContext* origin_dev_ctx,
const platform::Place& place,
const std::string& execution_stream,
const int stream_priority) {
auto op_attributes = op->attributes();
auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
interpreter::ContextManager& ctx_manager =
interpreter::ContextManager::Instance();
platform::DeviceContext* dev_ctx = nullptr;
// only gpu need update. xpu not need, because xpu memcpy op kernel is
// synchronous.
if (platform::is_gpu_place(place) || platform::is_custom_place(place)) {
VLOG(6) << "Parse DeviceContext for " << op_name
<< ", execution stream = " << execution_stream;
if (execution_stream != kDefaultStream) {
dev_ctx = ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream,
place,
stream_priority)
.get()
.get();
interpreter::SetDeviceCommContext(op, dev_ctx);
return dev_ctx;
}
if (op_name == interpreter::kMemcpyD2H) {
dev_ctx = ctx_manager.Get(std::string(kD2HStream), place, stream_priority)
.get()
.get();
interpreter::SetDeviceCommContext(op, dev_ctx);
return dev_ctx;
} else if (op_name == interpreter::kMemcpyH2D) {
dev_ctx = ctx_manager.Get(std::string(kH2DStream), place, stream_priority)
.get()
.get();
interpreter::SetDeviceCommContext(op, dev_ctx);
return dev_ctx;
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
// with use_cal_stream==false by returning a device context getting from the
// global NCCLCommContext instance. Because when use_calc_stream==false, in
// OP kernel, the NCCL communication will be launched to the stream directly
// getting from the global NCCLCommContext instance rather than the
// DeviceContext passed from executor (see CAllReduceOpCUDAKernel in
// c_allreduce_op.h). Now it is just a temporary solution for ONLY
// c_allreduce_sum which is used in ResNet50 distributed training.
if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream")
.dyn_cast<::ir::BoolAttribute>()
.data() == false) {
int ring_id =
op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data();
return platform::NCCLCommContext::Instance()
.Get(ring_id, place)
->dev_context();
}
#endif
}
if (origin_dev_ctx != nullptr) {
interpreter::SetDeviceCommContext(op, origin_dev_ctx);
}
return origin_dev_ctx;
}
OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) { OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
return OpFuncType::kCpuSync; return OpFuncType::kCpuSync;
...@@ -172,15 +246,27 @@ PhiKernelInstruction::PhiKernelInstruction( ...@@ -172,15 +246,27 @@ PhiKernelInstruction::PhiKernelInstruction(
kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend()))); phi::TransToPhiPlace(kernel_key.backend())));
VLOG(6) << "finish process kernel context"; VLOG(6) << "finish process kernel context";
SetDeviceContext(
SetDeviceContext(phi::DeviceContextPool::Instance().Get( ParseDeviceContext(op,
phi::TransToPhiPlace(kernel_key.backend()))); phi::DeviceContextPool::Instance().Get(
phi::TransToPhiPlace(kernel_key.backend())),
place,
GetExecutionStream(),
GetStreamPriority()));
VLOG(6) << "finish process device context"; VLOG(6) << "finish process device context";
Scope* inner_scope = local_scope == nullptr ? scope : local_scope; Scope* inner_scope = local_scope == nullptr ? scope : local_scope;
InitInputsOutputsIds( InitInputsOutputsIds(
op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name); op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name);
VLOG(6) << "finish process inputs outputs index"; VLOG(6) << "finish process inputs outputs index";
auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds();
std::unordered_set<::ir::Value> no_need_buffer_values;
for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id]));
}
SetNoNeedBuffer(no_need_buffer_values);
VLOG(6) << "finish process no need buffer";
} }
std::vector<int> GetValueIds( std::vector<int> GetValueIds(
......
...@@ -381,10 +381,8 @@ void DependencyBuilder::AddDownstreamOp(size_t prior_op_idx, ...@@ -381,10 +381,8 @@ void DependencyBuilder::AddDownstreamOp(size_t prior_op_idx,
VLOG(8) << prior_op_idx << "->" << posterior_op_idx; VLOG(8) << prior_op_idx << "->" << posterior_op_idx;
VLOG(8) << "Add dependency from " VLOG(8) << "Add dependency from "
<< instructions_->at(prior_op_idx).OpBase()->Type() << "(" << "prior_op_idx(" << prior_op_idx << ") to "
<< prior_op_idx << ") to " << "posterior_op_idx(" << posterior_op_idx << ")";
<< instructions_->at(posterior_op_idx).OpBase()->Type() << "("
<< posterior_op_idx << ")";
} }
void DependencyBuilder::BuildDownstreamMap() { void DependencyBuilder::BuildDownstreamMap() {
...@@ -405,22 +403,6 @@ void DependencyBuilder::BuildDownstreamMap() { ...@@ -405,22 +403,6 @@ void DependencyBuilder::BuildDownstreamMap() {
op2dependences[op_idx] = std::set<size_t>(); op2dependences[op_idx] = std::set<size_t>();
} }
auto update_var_min_rw_op =
[](const std::map<size_t, std::set<size_t>>& op2dependences,
std::map<size_t, std::list<size_t>>* var2min_rw_op,
size_t cur_op,
size_t rw_var) {
// rw_var is inputs or outputs of cur_op
// this function update the var2min_rw_op set .
if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
(*var2min_rw_op)[rw_var] = std::list<size_t>();
}
for (auto dep_op : op2dependences.at(cur_op)) {
var2min_rw_op->at(rw_var).remove(dep_op);
}
var2min_rw_op->at(rw_var).push_back(cur_op);
};
for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
remove_duplicate.clear(); remove_duplicate.clear();
// step1: update the op2dependences structure // step1: update the op2dependences structure
...@@ -485,7 +467,7 @@ void DependencyBuilder::BuildDownstreamMap() { ...@@ -485,7 +467,7 @@ void DependencyBuilder::BuildDownstreamMap() {
for (auto var : item.second) { for (auto var : item.second) {
if (remove_duplicate.count(var) == if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it. 0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); UpdateVarMinRwOp(op2dependences, &var2min_rw_op, op_idx, var);
} }
} }
} }
...@@ -546,22 +528,45 @@ void DependencyBuilder::ShrinkDownstreamMap() { ...@@ -546,22 +528,45 @@ void DependencyBuilder::ShrinkDownstreamMap() {
<< StringizeDownstreamMap(*op_downstream_map_); << StringizeDownstreamMap(*op_downstream_map_);
} }
void DependencyBuilder::UpdateVarMinRwOp(
const std::map<size_t, std::set<size_t>>& op2dependences,
std::map<size_t, std::list<size_t>>* var2min_rw_op,
size_t cur_op,
size_t rw_var) {
// rw_var is inputs or outputs of cur_op
// this function update the var2min_rw_op set .
if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
(*var2min_rw_op)[rw_var] = std::list<size_t>();
}
for (auto dep_op : op2dependences.at(cur_op)) {
var2min_rw_op->at(rw_var).remove(dep_op);
}
var2min_rw_op->at(rw_var).push_back(cur_op);
}
/// ======================== /// /// ======================== ///
/// For new ir /// /// For new ir ///
/// ======================== /// /// ======================== ///
const std::map<size_t, std::set<size_t>>& IrDependencyBuilder::Build( NewIrDependencyBuilder::NewIrDependencyBuilder() {
const std::vector<std::unique_ptr<paddle::framework::InstructionBase>>& is_build_ = false;
instructions) { op_downstream_map_ = std::make_shared<std::map<size_t, std::set<size_t>>>();
op_happens_before_ = std::make_shared<std::vector<std::vector<bool>>>();
}
const std::map<size_t, std::set<size_t>>& NewIrDependencyBuilder::Build(
std::vector<paddle::framework::InstructionBase*> instructions) {
if (is_build_) { if (is_build_) {
return op_downstream_map_; return *op_downstream_map_;
} }
instructions_ = &instructions; std::tie(op_downstream_map_, op_happens_before_) = GetDependency();
op_num_ = instructions_->size();
instructions_ = instructions;
op_num_ = instructions_.size();
ops_before_.assign(op_num_, {}); ops_before_.assign(op_num_, {});
ops_behind_.assign(op_num_, {}); ops_behind_.assign(op_num_, {});
op_happens_before_.assign(op_num_, std::vector<bool>(op_num_, false)); op_happens_before_->assign(op_num_, std::vector<bool>(op_num_, false));
BuildDownstreamMap(); BuildDownstreamMap();
VLOG(6) << "Finish BuildDownstreamMap"; VLOG(6) << "Finish BuildDownstreamMap";
...@@ -576,16 +581,16 @@ const std::map<size_t, std::set<size_t>>& IrDependencyBuilder::Build( ...@@ -576,16 +581,16 @@ const std::map<size_t, std::set<size_t>>& IrDependencyBuilder::Build(
// TODO(zhangbo): Add dependency for special op ? // TODO(zhangbo): Add dependency for special op ?
VLOG(6) << "Finish build dependency"; VLOG(6) << "Finish build dependency";
VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map_); VLOG(8) << "downstream count: " << CountDownstreamMap(*op_downstream_map_);
VLOG(8) << "downstream_map: " << std::endl VLOG(8) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(op_downstream_map_); << StringizeDownstreamMap(*op_downstream_map_);
is_build_ = true; is_build_ = true;
return op_downstream_map_; return *op_downstream_map_;
} }
void IrDependencyBuilder::BuildDownstreamMap() { void NewIrDependencyBuilder::BuildDownstreamMap() {
auto var2min_rw_op = auto var2min_rw_op =
std::map<size_t, std::list<size_t>>(); // # map from variable id to read std::map<size_t, std::list<size_t>>(); // # map from variable id to read
// write op id. // write op id.
...@@ -604,27 +609,11 @@ void IrDependencyBuilder::BuildDownstreamMap() { ...@@ -604,27 +609,11 @@ void IrDependencyBuilder::BuildDownstreamMap() {
op2dependences[op_idx] = std::set<size_t>(); op2dependences[op_idx] = std::set<size_t>();
} }
auto update_var_min_rw_op =
[](const std::map<size_t, std::set<size_t>>& op2dependences,
std::map<size_t, std::list<size_t>>* var2min_rw_op,
size_t cur_op,
size_t rw_var) {
// rw_var is inputs or outputs of cur_op
// this function update the var2min_rw_op set .
if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) {
(*var2min_rw_op)[rw_var] = std::list<size_t>();
}
for (auto dep_op : op2dependences.at(cur_op)) {
var2min_rw_op->at(rw_var).remove(dep_op);
}
var2min_rw_op->at(rw_var).push_back(cur_op);
};
for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
remove_duplicate.clear(); remove_duplicate.clear();
// step1: update the op2dependences structure // step1: update the op2dependences structure
for (auto& item : for (auto& item :
instructions_->at(op_idx)->Inputs()) { // for all inputs(read only) instructions_.at(op_idx)->Inputs()) { // for all inputs(read only)
for (auto var : item.second) { for (auto var : item.second) {
if (var2recent_write_op.count(var)) if (var2recent_write_op.count(var))
op2dependences[op_idx].insert(var2recent_write_op[var]); op2dependences[op_idx].insert(var2recent_write_op[var]);
...@@ -632,7 +621,7 @@ void IrDependencyBuilder::BuildDownstreamMap() { ...@@ -632,7 +621,7 @@ void IrDependencyBuilder::BuildDownstreamMap() {
} }
for (auto& item : for (auto& item :
instructions_->at(op_idx)->Outputs()) { // for all write vars instructions_.at(op_idx)->Outputs()) { // for all write vars
for (auto var : item.second) { for (auto var : item.second) {
if (var2min_rw_op.count(var)) { if (var2min_rw_op.count(var)) {
for (auto dep_op : var2min_rw_op[var]) { for (auto dep_op : var2min_rw_op[var]) {
...@@ -644,7 +633,7 @@ void IrDependencyBuilder::BuildDownstreamMap() { ...@@ -644,7 +633,7 @@ void IrDependencyBuilder::BuildDownstreamMap() {
// step2: update 2 var2xxxx data structure // step2: update 2 var2xxxx data structure
for (auto& item : for (auto& item :
instructions_->at(op_idx)->Outputs()) { // for all write vars instructions_.at(op_idx)->Outputs()) { // for all write vars
for (auto var : item.second) { for (auto var : item.second) {
var2recent_write_op[var] = op_idx; var2recent_write_op[var] = op_idx;
var2min_rw_op[var] = {static_cast<size_t>(op_idx)}; var2min_rw_op[var] = {static_cast<size_t>(op_idx)};
...@@ -653,11 +642,11 @@ void IrDependencyBuilder::BuildDownstreamMap() { ...@@ -653,11 +642,11 @@ void IrDependencyBuilder::BuildDownstreamMap() {
} }
for (auto& item : for (auto& item :
instructions_->at(op_idx)->Inputs()) { // for all inputs(read only) instructions_.at(op_idx)->Inputs()) { // for all inputs(read only)
for (auto var : item.second) { for (auto var : item.second) {
if (remove_duplicate.count(var) == if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it. 0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); UpdateVarMinRwOp(op2dependences, &var2min_rw_op, op_idx, var);
} }
} }
} }
...@@ -675,118 +664,6 @@ void IrDependencyBuilder::BuildDownstreamMap() { ...@@ -675,118 +664,6 @@ void IrDependencyBuilder::BuildDownstreamMap() {
} }
} }
void IrDependencyBuilder::AddDownstreamOp(size_t prior_op_idx,
size_t posterior_op_idx) {
PADDLE_ENFORCE_EQ(
OpHappensBefore(posterior_op_idx, prior_op_idx),
false,
phi::errors::Unavailable(
"Can not add dependency %d->%d because %d is run before %d",
prior_op_idx,
posterior_op_idx,
posterior_op_idx,
prior_op_idx));
std::set<size_t>& downstream_ops = op_downstream_map_[prior_op_idx];
// NOTE(Ruibiao): Here the downstream map shrinking is best-effort, therefore
// ShrinkDownstreamMap after BuildDownstreamMap is still helpful. For example,
// a->c will not be shrinked in the following case: AddDownstreamOp(a, b) ->
// AddDownstreamOp(a, c) -> AddDownstreamOp(b, c), it should be shrinked by
// ShrinkDownstreamMap.
for (size_t op_idx : downstream_ops) {
if (OpHappensBefore(op_idx, posterior_op_idx)) {
VLOG(7) << "Find dependencies " << prior_op_idx << "->" << op_idx << "->"
<< posterior_op_idx << ", skip adding " << prior_op_idx << "->"
<< posterior_op_idx;
return;
}
}
downstream_ops.insert(posterior_op_idx);
std::vector<size_t> prior_of_prior = ops_before_[prior_op_idx];
std::vector<size_t> posterior_of_posterior = ops_behind_[posterior_op_idx];
auto update_op_happen_before = [this](size_t prior_op_idx,
size_t posterior_op_idx) {
if (!op_happens_before_[prior_op_idx][posterior_op_idx]) {
op_happens_before_[prior_op_idx][posterior_op_idx] = true;
ops_before_[posterior_op_idx].push_back(prior_op_idx);
ops_behind_[prior_op_idx].push_back(posterior_op_idx);
}
};
update_op_happen_before(prior_op_idx, posterior_op_idx);
// All ops before prior-op are also before posterior-op
for (size_t op_idx : prior_of_prior) {
update_op_happen_before(op_idx, posterior_op_idx);
}
// All ops after posterior-op are also after prior-op
for (size_t op_idx : posterior_of_posterior) {
update_op_happen_before(prior_op_idx, op_idx);
}
VLOG(8) << prior_op_idx << "->" << posterior_op_idx;
VLOG(8) << "Add dependency from " << instructions_->at(prior_op_idx)->Name()
<< "(" << prior_op_idx << ") to "
<< instructions_->at(posterior_op_idx)->Name() << "("
<< posterior_op_idx << ")";
}
void IrDependencyBuilder::ShrinkDownstreamMap() {
// remove unnecessary downstream ops
// for example, a->b->c
// a: b, c
// b: c
// =>
// a: b
// b: c
// shrink, find the downstream op that has no other op in the
// downstream list happens before it
for (size_t i = 0; i < op_num_; ++i) {
if (op_downstream_map_.find(i) == op_downstream_map_.end()) {
continue;
}
std::set<size_t> minumum_nexts;
for (size_t item : op_downstream_map_.at(i)) {
bool not_after_any = true;
// find the op that is not executed after any
for (size_t other_item : op_downstream_map_.at(i)) {
if (OpHappensBefore(other_item, item)) {
VLOG(8) << "happens_before: " << other_item << "->" << item
<< ", so skip " << item;
not_after_any = false;
break;
}
}
if (not_after_any) {
VLOG(8) << "downstream op of " << i << ": " << item;
minumum_nexts.insert(item);
}
}
// NOTE(Ruibiao): op_happens_before will not be changed when shrink
// dowstream map
op_downstream_map_.at(i) = minumum_nexts;
}
VLOG(8) << "Finish shrink downstream map";
VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map_);
VLOG(8) << "downstream_map: " << std::endl
<< StringizeDownstreamMap(op_downstream_map_);
}
void IrDependencyBuilder::AddDependencyForSequentialRun() {
size_t dependence_op_idx = ULLONG_MAX;
for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
if (dependence_op_idx != ULLONG_MAX) {
AddDownstreamOp(dependence_op_idx, op_idx);
}
dependence_op_idx = op_idx;
}
}
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -57,7 +57,7 @@ class DependencyBuilder { ...@@ -57,7 +57,7 @@ class DependencyBuilder {
void ShareDependencyFrom(const DependencyBuilder& src); void ShareDependencyFrom(const DependencyBuilder& src);
private: protected:
void AddDependencyForCoalesceTensorOp(); void AddDependencyForCoalesceTensorOp();
void AddDependencyForCommunicationOp(); void AddDependencyForCommunicationOp();
void AddDependencyForRandomOp(); void AddDependencyForRandomOp();
...@@ -70,8 +70,14 @@ class DependencyBuilder { ...@@ -70,8 +70,14 @@ class DependencyBuilder {
void ShrinkDownstreamMap(); void ShrinkDownstreamMap();
void UpdateVarMinRwOp(
const std::map<size_t, std::set<size_t>>& op2dependences,
std::map<size_t, std::list<size_t>>* var2min_rw_op,
size_t cur_op,
size_t rw_var);
bool is_build_; bool is_build_;
const std::vector<Instruction>* instructions_; // not_own
size_t op_num_; size_t op_num_;
// ops_behind_ is the adjacency list about op to its posterior-ops, that is to // ops_behind_ is the adjacency list about op to its posterior-ops, that is to
...@@ -89,64 +95,27 @@ class DependencyBuilder { ...@@ -89,64 +95,27 @@ class DependencyBuilder {
// op_happens_before_ is a matrix form of ops_before_ and ops_behind_, it is // op_happens_before_ is a matrix form of ops_before_ and ops_behind_, it is
// used to speed up the query. // used to speed up the query.
std::shared_ptr<std::vector<std::vector<bool>>> op_happens_before_; std::shared_ptr<std::vector<std::vector<bool>>> op_happens_before_;
private:
const std::vector<Instruction>* instructions_; // not_own
}; };
// /// ======================== /// /// ======================== ///
// /// For new ir /// /// For new ir ///
// /// ======================== /// /// ======================== ///
class IrDependencyBuilder { class NewIrDependencyBuilder : public DependencyBuilder {
public: public:
IrDependencyBuilder() : is_build_(false), instructions_(nullptr) {} NewIrDependencyBuilder();
// build op dependencies and return the mapping from op to its downstream-op // build op dependencies and return the mapping from op to its downstream-op
// set // set
const std::map<size_t, std::set<size_t>>& Build( const std::map<size_t, std::set<size_t>>& Build(
const std::vector<std::unique_ptr<paddle::framework::InstructionBase>>& std::vector<paddle::framework::InstructionBase*> instructions);
instructions);
const std::map<size_t, std::set<size_t>>& OpDownstreamMap() const;
bool OpHappensBefore(size_t prior_op_idx, size_t posterior_op_idx) const {
PADDLE_ENFORCE_GE(
op_happens_before_.size(),
0,
phi::errors::Unavailable("op_happen_before is not yet built"));
return op_happens_before_.at(prior_op_idx).at(posterior_op_idx);
}
private:
void AddDependencyForCoalesceTensorOp();
void AddDependencyForCommunicationOp();
void AddDependencyForRandomOp();
void AddDependencyForReadOp();
void AddDependencyForSequentialRun();
void AddDownstreamOp(size_t prior_op_idx, size_t posterior_op_idx);
void BuildDownstreamMap(); void BuildDownstreamMap();
void ShrinkDownstreamMap(); private:
std::vector<paddle::framework::InstructionBase*> instructions_; // not_owned
bool is_build_;
const std::vector<std::unique_ptr<paddle::framework::InstructionBase>>*
instructions_; // not_own
size_t op_num_;
// ops_behind_ is the adjacency list about op to its posterior-ops, that is to
// say, op_behind_[i] == {a, b, c} means op[a], op[b] and op[c] depend on
// op[i] directly or indirectly. ops_before_ is the revered adjacency list of
// ops_behind_.
std::vector<std::vector<size_t>> ops_before_;
std::vector<std::vector<size_t>> ops_behind_;
// op_downstream_map_ is the mapping from op to its downstream-op set, that is
// to say, op_downstream_map_[i] == {a, b, c} means op[a], op[b] and op[c]
// depend on op[i] directly.
std::map<size_t, std::set<size_t>> op_downstream_map_;
// op_happens_before_ is a matrix form of ops_before_ and ops_behind_, it is
// used to speed up the query.
std::vector<std::vector<bool>> op_happens_before_;
}; };
} // namespace interpreter } // namespace interpreter
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/distributed/auto_parallel/dist_attr.h" #include "paddle/fluid/distributed/auto_parallel/dist_attr.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h"
#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h" #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h" #include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
...@@ -156,6 +157,18 @@ bool IsCpuOp(const Instruction& instr) { ...@@ -156,6 +157,18 @@ bool IsCpuOp(const Instruction& instr) {
return platform::is_cpu_place(instr.DeviceContext().GetPlace()); return platform::is_cpu_place(instr.DeviceContext().GetPlace());
} }
bool IsCpuOp(Instruction* instr) {
return platform::is_cpu_place(instr->DeviceContext().GetPlace());
}
bool IsCpuOp(const paddle::framework::InstructionBase& instr) {
return platform::is_cpu_place(instr.DeviceContext().GetPlace());
}
bool IsCpuOp(paddle::framework::InstructionBase* instr) {
return platform::is_cpu_place(instr->DeviceContext().GetPlace());
}
bool IsGradOp(const std::string& op_name) { bool IsGradOp(const std::string& op_name) {
return paddle::string::ends_with(op_name, "_grad"); return paddle::string::ends_with(op_name, "_grad");
} }
...@@ -173,6 +186,14 @@ bool IsMemcpyH2D(const Instruction& instr) { ...@@ -173,6 +186,14 @@ bool IsMemcpyH2D(const Instruction& instr) {
return instr.OpBase()->Type() == kMemcpyH2D; return instr.OpBase()->Type() == kMemcpyH2D;
} }
bool IsMemcpyH2D(Instruction* instr) {
return instr->OpBase()->Type() == kMemcpyH2D;
}
bool IsMemcpyH2D(paddle::framework::InstructionBase* instr) {
return instr->Name() == "pd.memcpy_h2d";
}
bool IsMemcpyOp(const Instruction& instr) { bool IsMemcpyOp(const Instruction& instr) {
return IsMemcpyD2H(instr) || IsMemcpyH2D(instr); return IsMemcpyD2H(instr) || IsMemcpyH2D(instr);
} }
...@@ -1127,6 +1148,29 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base, ...@@ -1127,6 +1148,29 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base,
} }
} }
void SetDeviceCommContext(::ir::Operation* op,
platform::DeviceContext* dev_ctx) {
auto op_attributes = op->attributes();
if (op_attributes.count("ring_id") != 0) {
int ring_id =
op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data();
const auto& comm_context_manager =
phi::distributed::CommContextManager::GetInstance();
if (comm_context_manager.Has(ring_id)) {
auto comm_context = comm_context_manager.Get(ring_id);
if (!dev_ctx->GetCommContext()) {
dev_ctx->SetCommContext(comm_context);
}
} else {
VLOG(3) << "op: "
<< op_attributes.at("op_name")
.dyn_cast<::ir::StrAttribute>()
.AsString()
<< ", ring_id: " << ring_id << ", get comm_context failed!";
}
}
}
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -43,6 +43,7 @@ using AtomicVectorSizeT = std::vector<std::atomic<size_t>>; ...@@ -43,6 +43,7 @@ using AtomicVectorSizeT = std::vector<std::atomic<size_t>>;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InstructionBase;
namespace interpreter { namespace interpreter {
class AsyncWorkQueue { class AsyncWorkQueue {
public: public:
...@@ -71,12 +72,22 @@ bool IsCommunicationOp(const Instruction& instr); ...@@ -71,12 +72,22 @@ bool IsCommunicationOp(const Instruction& instr);
bool IsCpuOp(const Instruction& instr); bool IsCpuOp(const Instruction& instr);
bool IsCpuOp(Instruction* instr);
bool IsCpuOp(const paddle::framework::InstructionBase& instr);
bool IsCpuOp(const paddle::framework::InstructionBase* instr);
bool IsGradOp(const std::string& op_name); bool IsGradOp(const std::string& op_name);
bool IsMemcpyD2H(const Instruction& instr); bool IsMemcpyD2H(const Instruction& instr);
bool IsMemcpyH2D(const Instruction& instr); bool IsMemcpyH2D(const Instruction& instr);
bool IsMemcpyH2D(Instruction* instr);
bool IsMemcpyH2D(paddle::framework::InstructionBase* instr);
bool IsMemcpyOp(const Instruction& instr); bool IsMemcpyOp(const Instruction& instr);
bool IsSupportedHeterPlace(const phi::Place& place); bool IsSupportedHeterPlace(const phi::Place& place);
...@@ -110,6 +121,9 @@ void LogDeviceMemoryStats(const platform::Place& place); ...@@ -110,6 +121,9 @@ void LogDeviceMemoryStats(const platform::Place& place);
void SetDeviceCommContext(framework::OperatorBase* operator_base, void SetDeviceCommContext(framework::OperatorBase* operator_base,
platform::DeviceContext* dev_ctx); platform::DeviceContext* dev_ctx);
void SetDeviceCommContext(::ir::Operation* op,
platform::DeviceContext* dev_ctx);
} // namespace interpreter } // namespace interpreter
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -28,6 +28,43 @@ namespace interpreter { ...@@ -28,6 +28,43 @@ namespace interpreter {
enum DownstreamRunType { kDirectRun, kEventRun }; enum DownstreamRunType { kDirectRun, kEventRun };
class ContextManager {
public:
using DeviceContextMap =
std::map<Place,
std::shared_future<std::unique_ptr<platform::DeviceContext>>>;
static ContextManager& Instance() {
static ContextManager* ctx_manager = new ContextManager;
return *ctx_manager;
}
std::shared_future<std::unique_ptr<platform::DeviceContext>> Get(
const std::string& type,
const platform::Place& place,
int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get dev_ctx for " << type << " - " << place;
DeviceContextMap& ctxs = ctx_pool_[type];
if (ctxs.find(place) == ctxs.end()) {
platform::EmplaceDeviceContexts(
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true,
stream_priority);
}
return ctxs[place];
}
private:
ContextManager() {}
DISABLE_COPY_AND_ASSIGN(ContextManager);
std::mutex ctx_mtx_;
std::unordered_map<std::string, DeviceContextMap> ctx_pool_;
};
class StreamAnalyzer { class StreamAnalyzer {
public: public:
using DeviceContext = platform::DeviceContext; using DeviceContext = platform::DeviceContext;
...@@ -54,35 +91,75 @@ class StreamAnalyzer { ...@@ -54,35 +91,75 @@ class StreamAnalyzer {
GetEventInfo() const; GetEventInfo() const;
private: private:
bool HasDataDependency(const Instruction& cur_instr, bool HasDataDependency(Instruction* cur_instr, Instruction* next_instr) const;
const Instruction& next_instr) const;
void AnalyseAllEventInfo( void AnalyseAllEventInfo(
const std::vector<Instruction>& instructions, const std::vector<Instruction*>& instructions,
const std::vector<std::vector<std::vector<size_t>>>& run_type_info, const std::vector<std::vector<std::vector<size_t>>>& run_type_info,
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>* std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>*
event_info) const; event_info) const;
void AnalyseAllRunType( void AnalyseAllRunType(
const std::vector<Instruction>& instructions, const std::vector<Instruction*>& instructions,
const std::map<size_t, std::set<size_t>>& downstream_map, const std::map<size_t, std::set<size_t>>& downstream_map,
std::vector<std::vector<std::vector<size_t>>>* run_type_info) const; std::vector<std::vector<std::vector<size_t>>>* run_type_info) const;
void AnalyseEventInfoForTwoInstructions(
const std::vector<Instruction>& instructions,
const std::vector<std::vector<std::vector<size_t>>>& run_type_info,
const size_t cur_instr_id,
const size_t next_instr_id,
std::set<size_t>* waiter_instr_ids,
std::set<size_t>* visited_next_instr_id) const;
void ShrinkEventInfo( void ShrinkEventInfo(
const DependencyBuilder& dependency_builder, const DependencyBuilder& dependency_builder,
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>* std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>*
event_info_map) const; event_info_map) const;
DownstreamRunType AnalyseRunTypeForTwoInstructions( const Place place_;
const Instruction& cur_instr, const Instruction& next_instr) const; bool is_event_info_build_{false};
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
event_info_;
};
/// ======================== ///
/// For new ir ///
/// ======================== ///
class NewIrStreamAnalyzer {
public:
using DeviceContext = platform::DeviceContext;
using Place = platform::Place;
explicit NewIrStreamAnalyzer(const Place& place) : place_(place) {
event_info_ = std::make_shared<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>();
}
~NewIrStreamAnalyzer() {}
void ConstructEvents(
const std::vector<std::unique_ptr<paddle::framework::InstructionBase>>&
instructions);
platform::DeviceType GetWaiterType(
const paddle::framework::InstructionBase* instr) const;
void ShareEventInfoFrom(const StreamAnalyzer& src);
std::shared_ptr<
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>>
GetEventInfo() const;
private:
void AnalyseAllRunType(
const std::vector<paddle::framework::InstructionBase*>& instructions,
const std::map<size_t, std::set<size_t>>& downstream_map,
std::vector<std::vector<std::vector<size_t>>>* run_type_info) const;
void AnalyseAllEventInfo(
const std::vector<paddle::framework::InstructionBase*>& instructions,
const std::vector<std::vector<std::vector<size_t>>>& run_type_info,
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>*
event_info) const;
void ShrinkEventInfo(
const NewIrDependencyBuilder& dependency_builder,
std::map<const DeviceContext*, std::map<size_t, std::set<size_t>>>*
event_info_map) const;
const Place place_; const Place place_;
bool is_event_info_build_{false}; bool is_event_info_build_{false};
......
...@@ -51,7 +51,8 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place, ...@@ -51,7 +51,8 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place,
execution_config_(execution_config), execution_config_(execution_config),
var_scope_(scope), var_scope_(scope),
scope_(scope), scope_(scope),
ir_program_(std::move(ir_prog)) { ir_program_(std::move(ir_prog)),
ir_stream_analyzer_(place) {
VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_;
static_build_ = FLAGS_new_executor_static_build && static_build_ = FLAGS_new_executor_static_build &&
!FLAGS_new_executor_use_cuda_graph && !FLAGS_new_executor_use_cuda_graph &&
...@@ -97,7 +98,6 @@ NewIRInterpreter::~NewIRInterpreter() { ...@@ -97,7 +98,6 @@ NewIRInterpreter::~NewIRInterpreter() {
gc_.reset(nullptr); gc_.reset(nullptr);
async_work_queue_.reset(); async_work_queue_.reset();
VLOG(4) << "~NewIRInterpreter(): " << this << " on " << place_; VLOG(4) << "~NewIRInterpreter(): " << this << " on " << place_;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working // this is needed to have mkl-dnn unit tests working
...@@ -197,9 +197,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names, ...@@ -197,9 +197,11 @@ FetchList NewIRInterpreter::Run(const std::vector<std::string>& feed_names,
VLOG(4) << DebugValueInfo(); VLOG(4) << DebugValueInfo();
// NOTE(zhangbo): Iterative version, gradually replacing BuildOpFuncList() // NOTE(zhangbo): Iterative version, gradually replacing BuildOpFuncList()
// and Convert() // and Convert() by:
// BuildInstruction(); // [1] BuildInstruction();
// BuildInstructionDependences(); // [2] BuildInstructionDependences();
// [3] ir_stream_analyzer_.ConstructEvents(&vec_instruction_base_);
// [4] GC();
std::vector<paddle::framework::OpFuncNode> op_func_nodes; std::vector<paddle::framework::OpFuncNode> op_func_nodes;
interpreter::BuildOpFuncList(place_, interpreter::BuildOpFuncList(place_,
...@@ -260,8 +262,35 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names, ...@@ -260,8 +262,35 @@ FetchList NewIRInterpreter::BetaRun(const std::vector<std::string>& feed_names,
&var_name_2_id_, &var_name_2_id_,
&variable_list_); &variable_list_);
VLOG(4) << DebugValueInfo(); VLOG(4) << DebugValueInfo();
BuildInstruction(); BuildInstruction();
BuildInstructionDependences(); BuildInstructionDependences();
ir_stream_analyzer_.ConstructEvents(vec_instruction_base_);
// add event for the input var of jit program, since there are async copied
// from gpu_pinned place to gpu place on compute stream.
for (size_t i = 0; i < dependecy_count_.size(); ++i) {
if (dependecy_count_[i] == 0) {
InstructionBase* inst = vec_instruction_base_[i].get();
if (inst->Name() == "pd.memcpy_d2h" && platform::is_gpu_place(place_)) {
for (auto& item : inst->Inputs()) {
for (auto var_id : item.second) {
auto name = GetNameById(var_id);
if (JitInputVars().count(name)) {
auto device_event = std::make_shared<platform::DeviceEvent>(
place_, platform::GenerateDeviceEventFlag());
VLOG(4) << "Add input event for input: " << name << " of "
<< inst->Name();
inst->AddEventToWait(
i, device_event, ir_stream_analyzer_.GetWaiterType(inst));
}
}
}
}
}
}
for (size_t instr_id = 0; instr_id < vec_instruction_base_.size(); for (size_t instr_id = 0; instr_id < vec_instruction_base_.size();
++instr_id) { ++instr_id) {
vec_instruction_base_[instr_id]->Run(); vec_instruction_base_[instr_id]->Run();
...@@ -345,6 +374,21 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { ...@@ -345,6 +374,21 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) {
const Scope* NewIRInterpreter::local_scope() const { return local_scope_; } const Scope* NewIRInterpreter::local_scope() const { return local_scope_; }
std::string NewIRInterpreter::GetNameById(int id) const {
// NOTE(zhiqiu): do not use vec_meta_info_[id].vardesc_->Name() since
// vec_meta_info_[id] may be nullptr,
// typically when the target variable is not existed in the original program
// desc, but created by interpretercore.
// For example, created and used by d2h_copy or h2d_copy operator.
auto it = std::find_if(var_name_2_id_.begin(),
var_name_2_id_.end(),
[id](const auto& pair) { return pair.second == id; });
if (it != var_name_2_id_.end()) {
return it->first;
}
return "";
}
void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) {
async_work_queue_ = reinterpret_cast<NewIRInterpreter*>(src)->GetWorkQueue(); async_work_queue_ = reinterpret_cast<NewIRInterpreter*>(src)->GetWorkQueue();
VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src
...@@ -1581,13 +1625,13 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace() { ...@@ -1581,13 +1625,13 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace() {
/// ======================== /// /// ======================== ///
void NewIRInterpreter::BuildInstruction() { void NewIRInterpreter::BuildInstruction() {
VLOG(0) << "Build Instructions for new ir ... "; VLOG(6) << "Build Instructions for new ir ... ";
vec_instruction_base_.clear(); vec_instruction_base_.clear();
size_t op_idx = 0; size_t op_idx = 0;
for (auto it = ir_program_->block()->begin(); for (auto it = ir_program_->block()->begin();
it != ir_program_->block()->end(); it != ir_program_->block()->end();
++it) { ++it) {
VLOG(0) << "Build Instruction for op: " << op_idx; VLOG(6) << "Build Instruction for op: " << op_idx;
if ((*it)->dialect()->name() == "pd_kernel") { if ((*it)->dialect()->name() == "pd_kernel") {
auto op_name = (*it) auto op_name = (*it)
->attributes() ->attributes()
...@@ -1635,7 +1679,12 @@ void NewIRInterpreter::BuildInstructionDependences() { ...@@ -1635,7 +1679,12 @@ void NewIRInterpreter::BuildInstructionDependences() {
// instr, and set the dependecy_count_ // instr, and set the dependecy_count_
size_t instr_num = vec_instruction_base_.size(); size_t instr_num = vec_instruction_base_.size();
dependecy_count_ = std::vector<size_t>(instr_num, 0); dependecy_count_ = std::vector<size_t>(instr_num, 0);
auto downstream_map = ir_dependency_builder_.Build(vec_instruction_base_);
std::vector<paddle::framework::InstructionBase*> instructions_ptr;
for (auto& instr : vec_instruction_base_) {
instructions_ptr.push_back(instr.get());
}
auto downstream_map = ir_dependency_builder_.Build(instructions_ptr);
for (size_t instr_id = 0; instr_id < instr_num; ++instr_id) { for (size_t instr_id = 0; instr_id < instr_num; ++instr_id) {
InstructionBase* cur_instr = vec_instruction_base_[instr_id].get(); InstructionBase* cur_instr = vec_instruction_base_[instr_id].get();
......
...@@ -84,6 +84,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -84,6 +84,8 @@ class NewIRInterpreter : public InterpreterBaseImpl {
hookfuncs_ = hookfuncs; hookfuncs_ = hookfuncs;
} }
std::string GetNameById(int id) const;
private: private:
// build graph // build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes); void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
...@@ -216,7 +218,9 @@ class NewIRInterpreter : public InterpreterBaseImpl { ...@@ -216,7 +218,9 @@ class NewIRInterpreter : public InterpreterBaseImpl {
std::vector<Variable*> variable_list_; std::vector<Variable*> variable_list_;
interpreter::IrDependencyBuilder ir_dependency_builder_; interpreter::NewIrDependencyBuilder ir_dependency_builder_;
interpreter::NewIrStreamAnalyzer ir_stream_analyzer_;
}; };
} // namespace framework } // namespace framework
......
...@@ -92,6 +92,10 @@ const std::map<std::string, int>& OpYamlInfoParser::OutputName2Id() const { ...@@ -92,6 +92,10 @@ const std::map<std::string, int>& OpYamlInfoParser::OutputName2Id() const {
return output_name2id_; return output_name2id_;
} }
const std::vector<int>& OpYamlInfoParser::NoNeedBufferIds() const {
return no_need_buffer_ids_;
}
bool OpYamlInfoParser::HasInplace(const std::string& out_name) const { bool OpYamlInfoParser::HasInplace(const std::string& out_name) const {
auto& inplace_info = std::get<3>(op_info_tuple_).inplace; auto& inplace_info = std::get<3>(op_info_tuple_).inplace;
for (size_t i = 0; i < inplace_info.size(); i++) { for (size_t i = 0; i < inplace_info.size(); i++) {
...@@ -117,14 +121,16 @@ const std::string& OpYamlInfoParser::InplaceName( ...@@ -117,14 +121,16 @@ const std::string& OpYamlInfoParser::InplaceName(
void OpYamlInfoParser::parse() { void OpYamlInfoParser::parse() {
auto input_info = std::get<0>(op_info_tuple_); auto input_info = std::get<0>(op_info_tuple_);
int input_start_index = 0;
for (size_t i = 0; i < input_info.size(); ++i) { for (size_t i = 0; i < input_info.size(); ++i) {
input_name2id_[input_info[i].name] = input_start_index++; input_name2id_[input_info[i].name] = i;
input_name_list_.push_back(input_info[i].name); input_name_list_.push_back(input_info[i].name);
input_info_[input_info[i].name] = input_info[i]; input_info_[input_info[i].name] = input_info[i];
if (!input_info[i].is_mutable_attribute) { if (!input_info[i].is_mutable_attribute) {
input_tensor_number_++; input_tensor_number_++;
} }
if (input_info[i].no_need_buffer) {
no_need_buffer_ids_.push_back(i);
}
} }
auto attribute_info = std::get<1>(op_info_tuple_); auto attribute_info = std::get<1>(op_info_tuple_);
...@@ -133,10 +139,9 @@ void OpYamlInfoParser::parse() { ...@@ -133,10 +139,9 @@ void OpYamlInfoParser::parse() {
attr_info_[attribute_info[i].name] = attribute_info[i]; attr_info_[attribute_info[i].name] = attribute_info[i];
} }
int output_start_index = 0;
auto output_info = std::get<2>(op_info_tuple_); auto output_info = std::get<2>(op_info_tuple_);
for (size_t i = 0; i < output_info.size(); ++i) { for (size_t i = 0; i < output_info.size(); ++i) {
output_name2id_[output_info[i].name] = output_start_index++; output_name2id_[output_info[i].name] = i;
output_name_list_.push_back(output_info[i].name); output_name_list_.push_back(output_info[i].name);
output_info_[output_info[i].name] = output_info[i]; output_info_[output_info[i].name] = output_info[i];
} }
......
...@@ -37,6 +37,8 @@ class OpYamlInfoParser { ...@@ -37,6 +37,8 @@ class OpYamlInfoParser {
const std::map<std::string, int>& InputName2Id() const; const std::map<std::string, int>& InputName2Id() const;
const std::map<std::string, int>& OutputName2Id() const; const std::map<std::string, int>& OutputName2Id() const;
const std::vector<int>& NoNeedBufferIds() const;
const std::vector<std::string>& InputNames() const { const std::vector<std::string>& InputNames() const {
return input_name_list_; return input_name_list_;
} }
...@@ -65,6 +67,9 @@ class OpYamlInfoParser { ...@@ -65,6 +67,9 @@ class OpYamlInfoParser {
std::map<std::string, OpInputInfo> input_info_; std::map<std::string, OpInputInfo> input_info_;
int input_tensor_number_{0}; int input_tensor_number_{0};
// no_need_buffer_ids
std::vector<int> no_need_buffer_ids_;
// attribute info // attribute info
std::vector<std::string> attribute_name_list_; std::vector<std::string> attribute_name_list_;
std::map<std::string, OpAttributeInfo> attr_info_; std::map<std::string, OpAttributeInfo> attr_info_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册