diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_base.h b/paddle/fluid/framework/new_executor/instruction/instruction_base.h index a31b65c1039d633e45a9dd6895e55f6f72150ba0..7452990a1d9076ce83140288fe70ad2b9310f2a2 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_base.h +++ b/paddle/fluid/framework/new_executor/instruction/instruction_base.h @@ -45,9 +45,9 @@ class InstructionBase { OpFuncType KernelType() const; void SetKernelType(OpFuncType type) { type_ = type; } - int GetStreamPriority() const { return scheduling_priority_; } - void SetStreamPriority(SchedulingPriority scheduling_priority) { - scheduling_priority_ = scheduling_priority; + int GetStreamPriority() const { return stream_priority_; } + void SetStreamPriority(int stream_priority) { + stream_priority_ = stream_priority; } SchedulingPriority GetSchedulingPriority() const { @@ -107,22 +107,31 @@ class InstructionBase { std::map& GetMutableInplaceBackMap() { return inplace_back_map_; } const std::map& GetInplaceBackMap() { return inplace_back_map_; } - const std::unordered_map>& Inputs() const { + const std::unordered_map<::ir::Value, std::vector>& Inputs() const { return input_index_; } - std::unordered_map>& GetMutableInputs() { + std::unordered_map<::ir::Value, std::vector>& GetMutableInputs() { return input_index_; } - void SetInputs(const std::unordered_map>& inputs); + void SetInputs( + const std::unordered_map<::ir::Value, std::vector>& inputs); - const std::unordered_map>& Outputs() const { + const std::unordered_map<::ir::Value, std::vector>& Outputs() const { return output_index_; } - std::unordered_map>& GetMutableOutputs() { + std::unordered_map<::ir::Value, std::vector>& GetMutableOutputs() { return output_index_; } void SetOutputs( - const std::unordered_map>& outputs); + const std::unordered_map<::ir::Value, std::vector>& outputs); + + const std::unordered_set<::ir::Value>& NoNeedBuffer() const { + return no_need_buffer_values_; + } + void SetNoNeedBuffer( + const std::unordered_set<::ir::Value>& no_need_buffer_values) { + no_need_buffer_values_ = no_need_buffer_values; + } virtual void Run() = 0; @@ -159,9 +168,11 @@ class InstructionBase { std::map inplace_back_map_; - std::unordered_map> input_index_; + std::unordered_map<::ir::Value, std::vector> input_index_; + + std::unordered_map<::ir::Value, std::vector> output_index_; - std::unordered_map> output_index_; + std::unordered_set<::ir::Value> no_need_buffer_values_; }; } // namespace framework diff --git a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc index 50a8161cd4332ab044a7a160ae5066e0668f86e0..39e791aca3f8ac077c522fc861a2c6e5111003f4 100644 --- a/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.cc @@ -15,12 +15,15 @@ #include "paddle/fluid/framework/new_executor/instruction/phi_kernel_instruction.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" +#include "paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/interface/infermeta.h" #include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info_parser.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h" +#include "paddle/fluid/platform/collective_helper.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/core/type_defs.h" @@ -32,6 +35,77 @@ namespace paddle { namespace framework { +platform::DeviceContext* ParseDeviceContext( + ir::Operation* op, + platform::DeviceContext* origin_dev_ctx, + const platform::Place& place, + const std::string& execution_stream, + const int stream_priority) { + auto op_attributes = op->attributes(); + auto op_name = + op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString(); + interpreter::ContextManager& ctx_manager = + interpreter::ContextManager::Instance(); + + platform::DeviceContext* dev_ctx = nullptr; + + // only gpu need update. xpu not need, because xpu memcpy op kernel is + // synchronous. + if (platform::is_gpu_place(place) || platform::is_custom_place(place)) { + VLOG(6) << "Parse DeviceContext for " << op_name + << ", execution stream = " << execution_stream; + if (execution_stream != kDefaultStream) { + dev_ctx = ctx_manager + .Get(std::string(kCustomStream) + "-" + execution_stream, + place, + stream_priority) + .get() + .get(); + interpreter::SetDeviceCommContext(op, dev_ctx); + return dev_ctx; + } + + if (op_name == interpreter::kMemcpyD2H) { + dev_ctx = ctx_manager.Get(std::string(kD2HStream), place, stream_priority) + .get() + .get(); + interpreter::SetDeviceCommContext(op, dev_ctx); + return dev_ctx; + } else if (op_name == interpreter::kMemcpyH2D) { + dev_ctx = ctx_manager.Get(std::string(kH2DStream), place, stream_priority) + .get() + .get(); + interpreter::SetDeviceCommContext(op, dev_ctx); + return dev_ctx; + } + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + // NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum + // with use_cal_stream==false by returning a device context getting from the + // global NCCLCommContext instance. Because when use_calc_stream==false, in + // OP kernel, the NCCL communication will be launched to the stream directly + // getting from the global NCCLCommContext instance rather than the + // DeviceContext passed from executor (see CAllReduceOpCUDAKernel in + // c_allreduce_op.h). Now it is just a temporary solution for ONLY + // c_allreduce_sum which is used in ResNet50 distributed training. + if (op_name == "c_allreduce_sum" && op_attributes.at("use_calc_stream") + .dyn_cast<::ir::BoolAttribute>() + .data() == false) { + int ring_id = + op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data(); + return platform::NCCLCommContext::Instance() + .Get(ring_id, place) + ->dev_context(); + } +#endif + } + + if (origin_dev_ctx != nullptr) { + interpreter::SetDeviceCommContext(op, origin_dev_ctx); + } + return origin_dev_ctx; +} + OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) { if (platform::is_cpu_place(place)) { return OpFuncType::kCpuSync; @@ -172,15 +246,27 @@ PhiKernelInstruction::PhiKernelInstruction( kernel_context_.SetDeviceContext(phi::DeviceContextPool::Instance().Get( phi::TransToPhiPlace(kernel_key.backend()))); VLOG(6) << "finish process kernel context"; - - SetDeviceContext(phi::DeviceContextPool::Instance().Get( - phi::TransToPhiPlace(kernel_key.backend()))); + SetDeviceContext( + ParseDeviceContext(op, + phi::DeviceContextPool::Instance().Get( + phi::TransToPhiPlace(kernel_key.backend())), + place, + GetExecutionStream(), + GetStreamPriority())); VLOG(6) << "finish process device context"; Scope* inner_scope = local_scope == nullptr ? scope : local_scope; InitInputsOutputsIds( op, inner_scope, value_2_var_name, var_name_2_id, variable_2_var_name); VLOG(6) << "finish process inputs outputs index"; + + auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds(); + std::unordered_set<::ir::Value> no_need_buffer_values; + for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { + no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id])); + } + SetNoNeedBuffer(no_need_buffer_values); + VLOG(6) << "finish process no need buffer"; } std::vector GetValueIds( diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc index b4d6a62ed22a3a4c87ba61a337fe69e4cc539ede..639885b80e53449e61c8b486422b54b473a14ef5 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.cc @@ -381,10 +381,8 @@ void DependencyBuilder::AddDownstreamOp(size_t prior_op_idx, VLOG(8) << prior_op_idx << "->" << posterior_op_idx; VLOG(8) << "Add dependency from " - << instructions_->at(prior_op_idx).OpBase()->Type() << "(" - << prior_op_idx << ") to " - << instructions_->at(posterior_op_idx).OpBase()->Type() << "(" - << posterior_op_idx << ")"; + << "prior_op_idx(" << prior_op_idx << ") to " + << "posterior_op_idx(" << posterior_op_idx << ")"; } void DependencyBuilder::BuildDownstreamMap() { @@ -405,22 +403,6 @@ void DependencyBuilder::BuildDownstreamMap() { op2dependences[op_idx] = std::set(); } - auto update_var_min_rw_op = - [](const std::map>& op2dependences, - std::map>* var2min_rw_op, - size_t cur_op, - size_t rw_var) { - // rw_var is inputs or outputs of cur_op - // this function update the var2min_rw_op set . - if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) { - (*var2min_rw_op)[rw_var] = std::list(); - } - for (auto dep_op : op2dependences.at(cur_op)) { - var2min_rw_op->at(rw_var).remove(dep_op); - } - var2min_rw_op->at(rw_var).push_back(cur_op); - }; - for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { remove_duplicate.clear(); // step1: update the op2dependences structure @@ -485,7 +467,7 @@ void DependencyBuilder::BuildDownstreamMap() { for (auto var : item.second) { if (remove_duplicate.count(var) == 0) { // var in input list and in output list, so remove it. - update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); + UpdateVarMinRwOp(op2dependences, &var2min_rw_op, op_idx, var); } } } @@ -546,22 +528,45 @@ void DependencyBuilder::ShrinkDownstreamMap() { << StringizeDownstreamMap(*op_downstream_map_); } +void DependencyBuilder::UpdateVarMinRwOp( + const std::map>& op2dependences, + std::map>* var2min_rw_op, + size_t cur_op, + size_t rw_var) { + // rw_var is inputs or outputs of cur_op + // this function update the var2min_rw_op set . + if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) { + (*var2min_rw_op)[rw_var] = std::list(); + } + for (auto dep_op : op2dependences.at(cur_op)) { + var2min_rw_op->at(rw_var).remove(dep_op); + } + var2min_rw_op->at(rw_var).push_back(cur_op); +} + /// ======================== /// /// For new ir /// /// ======================== /// -const std::map>& IrDependencyBuilder::Build( - const std::vector>& - instructions) { +NewIrDependencyBuilder::NewIrDependencyBuilder() { + is_build_ = false; + op_downstream_map_ = std::make_shared>>(); + op_happens_before_ = std::make_shared>>(); +} + +const std::map>& NewIrDependencyBuilder::Build( + std::vector instructions) { if (is_build_) { - return op_downstream_map_; + return *op_downstream_map_; } - instructions_ = &instructions; - op_num_ = instructions_->size(); + std::tie(op_downstream_map_, op_happens_before_) = GetDependency(); + + instructions_ = instructions; + op_num_ = instructions_.size(); ops_before_.assign(op_num_, {}); ops_behind_.assign(op_num_, {}); - op_happens_before_.assign(op_num_, std::vector(op_num_, false)); + op_happens_before_->assign(op_num_, std::vector(op_num_, false)); BuildDownstreamMap(); VLOG(6) << "Finish BuildDownstreamMap"; @@ -576,16 +581,16 @@ const std::map>& IrDependencyBuilder::Build( // TODO(zhangbo): Add dependency for special op ? VLOG(6) << "Finish build dependency"; - VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map_); + VLOG(8) << "downstream count: " << CountDownstreamMap(*op_downstream_map_); VLOG(8) << "downstream_map: " << std::endl - << StringizeDownstreamMap(op_downstream_map_); + << StringizeDownstreamMap(*op_downstream_map_); is_build_ = true; - return op_downstream_map_; + return *op_downstream_map_; } -void IrDependencyBuilder::BuildDownstreamMap() { +void NewIrDependencyBuilder::BuildDownstreamMap() { auto var2min_rw_op = std::map>(); // # map from variable id to read // write op id. @@ -604,27 +609,11 @@ void IrDependencyBuilder::BuildDownstreamMap() { op2dependences[op_idx] = std::set(); } - auto update_var_min_rw_op = - [](const std::map>& op2dependences, - std::map>* var2min_rw_op, - size_t cur_op, - size_t rw_var) { - // rw_var is inputs or outputs of cur_op - // this function update the var2min_rw_op set . - if (var2min_rw_op->find(rw_var) == var2min_rw_op->end()) { - (*var2min_rw_op)[rw_var] = std::list(); - } - for (auto dep_op : op2dependences.at(cur_op)) { - var2min_rw_op->at(rw_var).remove(dep_op); - } - var2min_rw_op->at(rw_var).push_back(cur_op); - }; - for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { remove_duplicate.clear(); // step1: update the op2dependences structure for (auto& item : - instructions_->at(op_idx)->Inputs()) { // for all inputs(read only) + instructions_.at(op_idx)->Inputs()) { // for all inputs(read only) for (auto var : item.second) { if (var2recent_write_op.count(var)) op2dependences[op_idx].insert(var2recent_write_op[var]); @@ -632,7 +621,7 @@ void IrDependencyBuilder::BuildDownstreamMap() { } for (auto& item : - instructions_->at(op_idx)->Outputs()) { // for all write vars + instructions_.at(op_idx)->Outputs()) { // for all write vars for (auto var : item.second) { if (var2min_rw_op.count(var)) { for (auto dep_op : var2min_rw_op[var]) { @@ -644,7 +633,7 @@ void IrDependencyBuilder::BuildDownstreamMap() { // step2: update 2 var2xxxx data structure for (auto& item : - instructions_->at(op_idx)->Outputs()) { // for all write vars + instructions_.at(op_idx)->Outputs()) { // for all write vars for (auto var : item.second) { var2recent_write_op[var] = op_idx; var2min_rw_op[var] = {static_cast(op_idx)}; @@ -653,11 +642,11 @@ void IrDependencyBuilder::BuildDownstreamMap() { } for (auto& item : - instructions_->at(op_idx)->Inputs()) { // for all inputs(read only) + instructions_.at(op_idx)->Inputs()) { // for all inputs(read only) for (auto var : item.second) { if (remove_duplicate.count(var) == 0) { // var in input list and in output list, so remove it. - update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var); + UpdateVarMinRwOp(op2dependences, &var2min_rw_op, op_idx, var); } } } @@ -675,118 +664,6 @@ void IrDependencyBuilder::BuildDownstreamMap() { } } -void IrDependencyBuilder::AddDownstreamOp(size_t prior_op_idx, - size_t posterior_op_idx) { - PADDLE_ENFORCE_EQ( - OpHappensBefore(posterior_op_idx, prior_op_idx), - false, - phi::errors::Unavailable( - "Can not add dependency %d->%d because %d is run before %d", - prior_op_idx, - posterior_op_idx, - posterior_op_idx, - prior_op_idx)); - - std::set& downstream_ops = op_downstream_map_[prior_op_idx]; - // NOTE(Ruibiao): Here the downstream map shrinking is best-effort, therefore - // ShrinkDownstreamMap after BuildDownstreamMap is still helpful. For example, - // a->c will not be shrinked in the following case: AddDownstreamOp(a, b) -> - // AddDownstreamOp(a, c) -> AddDownstreamOp(b, c), it should be shrinked by - // ShrinkDownstreamMap. - for (size_t op_idx : downstream_ops) { - if (OpHappensBefore(op_idx, posterior_op_idx)) { - VLOG(7) << "Find dependencies " << prior_op_idx << "->" << op_idx << "->" - << posterior_op_idx << ", skip adding " << prior_op_idx << "->" - << posterior_op_idx; - return; - } - } - downstream_ops.insert(posterior_op_idx); - - std::vector prior_of_prior = ops_before_[prior_op_idx]; - std::vector posterior_of_posterior = ops_behind_[posterior_op_idx]; - - auto update_op_happen_before = [this](size_t prior_op_idx, - size_t posterior_op_idx) { - if (!op_happens_before_[prior_op_idx][posterior_op_idx]) { - op_happens_before_[prior_op_idx][posterior_op_idx] = true; - ops_before_[posterior_op_idx].push_back(prior_op_idx); - ops_behind_[prior_op_idx].push_back(posterior_op_idx); - } - }; - - update_op_happen_before(prior_op_idx, posterior_op_idx); - - // All ops before prior-op are also before posterior-op - for (size_t op_idx : prior_of_prior) { - update_op_happen_before(op_idx, posterior_op_idx); - } - - // All ops after posterior-op are also after prior-op - for (size_t op_idx : posterior_of_posterior) { - update_op_happen_before(prior_op_idx, op_idx); - } - - VLOG(8) << prior_op_idx << "->" << posterior_op_idx; - VLOG(8) << "Add dependency from " << instructions_->at(prior_op_idx)->Name() - << "(" << prior_op_idx << ") to " - << instructions_->at(posterior_op_idx)->Name() << "(" - << posterior_op_idx << ")"; -} - -void IrDependencyBuilder::ShrinkDownstreamMap() { - // remove unnecessary downstream ops - // for example, a->b->c - // a: b, c - // b: c - // => - // a: b - // b: c - - // shrink, find the downstream op that has no other op in the - // downstream list happens before it - for (size_t i = 0; i < op_num_; ++i) { - if (op_downstream_map_.find(i) == op_downstream_map_.end()) { - continue; - } - - std::set minumum_nexts; - for (size_t item : op_downstream_map_.at(i)) { - bool not_after_any = true; - // find the op that is not executed after any - for (size_t other_item : op_downstream_map_.at(i)) { - if (OpHappensBefore(other_item, item)) { - VLOG(8) << "happens_before: " << other_item << "->" << item - << ", so skip " << item; - not_after_any = false; - break; - } - } - if (not_after_any) { - VLOG(8) << "downstream op of " << i << ": " << item; - minumum_nexts.insert(item); - } - } - // NOTE(Ruibiao): op_happens_before will not be changed when shrink - // dowstream map - op_downstream_map_.at(i) = minumum_nexts; - } - VLOG(8) << "Finish shrink downstream map"; - VLOG(8) << "downstream count: " << CountDownstreamMap(op_downstream_map_); - VLOG(8) << "downstream_map: " << std::endl - << StringizeDownstreamMap(op_downstream_map_); -} - -void IrDependencyBuilder::AddDependencyForSequentialRun() { - size_t dependence_op_idx = ULLONG_MAX; - for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) { - if (dependence_op_idx != ULLONG_MAX) { - AddDownstreamOp(dependence_op_idx, op_idx); - } - dependence_op_idx = op_idx; - } -} - } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h index e28808044940b48bbe8e3e5819163b0fd942efaf..2593b11a2e48a51c266cc0b9468a7efd3f69f862 100644 --- a/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h +++ b/paddle/fluid/framework/new_executor/interpreter/dependency_builder.h @@ -57,7 +57,7 @@ class DependencyBuilder { void ShareDependencyFrom(const DependencyBuilder& src); - private: + protected: void AddDependencyForCoalesceTensorOp(); void AddDependencyForCommunicationOp(); void AddDependencyForRandomOp(); @@ -70,8 +70,14 @@ class DependencyBuilder { void ShrinkDownstreamMap(); + void UpdateVarMinRwOp( + const std::map>& op2dependences, + std::map>* var2min_rw_op, + size_t cur_op, + size_t rw_var); + bool is_build_; - const std::vector* instructions_; // not_own + size_t op_num_; // ops_behind_ is the adjacency list about op to its posterior-ops, that is to @@ -89,64 +95,27 @@ class DependencyBuilder { // op_happens_before_ is a matrix form of ops_before_ and ops_behind_, it is // used to speed up the query. std::shared_ptr>> op_happens_before_; + + private: + const std::vector* instructions_; // not_own }; -// /// ======================== /// -// /// For new ir /// -// /// ======================== /// -class IrDependencyBuilder { +/// ======================== /// +/// For new ir /// +/// ======================== /// +class NewIrDependencyBuilder : public DependencyBuilder { public: - IrDependencyBuilder() : is_build_(false), instructions_(nullptr) {} + NewIrDependencyBuilder(); // build op dependencies and return the mapping from op to its downstream-op // set const std::map>& Build( - const std::vector>& - instructions); - - const std::map>& OpDownstreamMap() const; - - bool OpHappensBefore(size_t prior_op_idx, size_t posterior_op_idx) const { - PADDLE_ENFORCE_GE( - op_happens_before_.size(), - 0, - phi::errors::Unavailable("op_happen_before is not yet built")); - return op_happens_before_.at(prior_op_idx).at(posterior_op_idx); - } - - private: - void AddDependencyForCoalesceTensorOp(); - void AddDependencyForCommunicationOp(); - void AddDependencyForRandomOp(); - void AddDependencyForReadOp(); - void AddDependencyForSequentialRun(); - - void AddDownstreamOp(size_t prior_op_idx, size_t posterior_op_idx); + std::vector instructions); void BuildDownstreamMap(); - void ShrinkDownstreamMap(); - - bool is_build_; - const std::vector>* - instructions_; // not_own - size_t op_num_; - - // ops_behind_ is the adjacency list about op to its posterior-ops, that is to - // say, op_behind_[i] == {a, b, c} means op[a], op[b] and op[c] depend on - // op[i] directly or indirectly. ops_before_ is the revered adjacency list of - // ops_behind_. - std::vector> ops_before_; - std::vector> ops_behind_; - - // op_downstream_map_ is the mapping from op to its downstream-op set, that is - // to say, op_downstream_map_[i] == {a, b, c} means op[a], op[b] and op[c] - // depend on op[i] directly. - std::map> op_downstream_map_; - - // op_happens_before_ is a matrix form of ops_before_ and ops_behind_, it is - // used to speed up the query. - std::vector> op_happens_before_; + private: + std::vector instructions_; // not_owned }; } // namespace interpreter diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 8ebaea588062a7de4a361f8433ad75baa68a31d6..66c09760ffff61b44bc0dfeb27656b2a6155b04d 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/distributed/auto_parallel/dist_attr.h" #include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" #include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h" #include "paddle/fluid/framework/new_executor/interpreter/execution_config.h" #include "paddle/fluid/framework/new_executor/interpreter/static_build.h" @@ -156,6 +157,18 @@ bool IsCpuOp(const Instruction& instr) { return platform::is_cpu_place(instr.DeviceContext().GetPlace()); } +bool IsCpuOp(Instruction* instr) { + return platform::is_cpu_place(instr->DeviceContext().GetPlace()); +} + +bool IsCpuOp(const paddle::framework::InstructionBase& instr) { + return platform::is_cpu_place(instr.DeviceContext().GetPlace()); +} + +bool IsCpuOp(paddle::framework::InstructionBase* instr) { + return platform::is_cpu_place(instr->DeviceContext().GetPlace()); +} + bool IsGradOp(const std::string& op_name) { return paddle::string::ends_with(op_name, "_grad"); } @@ -173,6 +186,14 @@ bool IsMemcpyH2D(const Instruction& instr) { return instr.OpBase()->Type() == kMemcpyH2D; } +bool IsMemcpyH2D(Instruction* instr) { + return instr->OpBase()->Type() == kMemcpyH2D; +} + +bool IsMemcpyH2D(paddle::framework::InstructionBase* instr) { + return instr->Name() == "pd.memcpy_h2d"; +} + bool IsMemcpyOp(const Instruction& instr) { return IsMemcpyD2H(instr) || IsMemcpyH2D(instr); } @@ -1127,6 +1148,29 @@ void SetDeviceCommContext(framework::OperatorBase* operator_base, } } +void SetDeviceCommContext(::ir::Operation* op, + platform::DeviceContext* dev_ctx) { + auto op_attributes = op->attributes(); + if (op_attributes.count("ring_id") != 0) { + int ring_id = + op_attributes.at("ring_id").dyn_cast<::ir::Int32Attribute>().data(); + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + if (comm_context_manager.Has(ring_id)) { + auto comm_context = comm_context_manager.Get(ring_id); + if (!dev_ctx->GetCommContext()) { + dev_ctx->SetCommContext(comm_context); + } + } else { + VLOG(3) << "op: " + << op_attributes.at("op_name") + .dyn_cast<::ir::StrAttribute>() + .AsString() + << ", ring_id: " << ring_id << ", get comm_context failed!"; + } + } +} + } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index eb87c8bcb4ce336ff6549d077fd9974db4a4e935..b37e46d5206e661924645b0c363e2410c5350f74 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -43,6 +43,7 @@ using AtomicVectorSizeT = std::vector>; namespace paddle { namespace framework { +class InstructionBase; namespace interpreter { class AsyncWorkQueue { public: @@ -71,12 +72,22 @@ bool IsCommunicationOp(const Instruction& instr); bool IsCpuOp(const Instruction& instr); +bool IsCpuOp(Instruction* instr); + +bool IsCpuOp(const paddle::framework::InstructionBase& instr); + +bool IsCpuOp(const paddle::framework::InstructionBase* instr); + bool IsGradOp(const std::string& op_name); bool IsMemcpyD2H(const Instruction& instr); bool IsMemcpyH2D(const Instruction& instr); +bool IsMemcpyH2D(Instruction* instr); + +bool IsMemcpyH2D(paddle::framework::InstructionBase* instr); + bool IsMemcpyOp(const Instruction& instr); bool IsSupportedHeterPlace(const phi::Place& place); @@ -110,6 +121,9 @@ void LogDeviceMemoryStats(const platform::Place& place); void SetDeviceCommContext(framework::OperatorBase* operator_base, platform::DeviceContext* dev_ctx); + +void SetDeviceCommContext(::ir::Operation* op, + platform::DeviceContext* dev_ctx); } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index 43cad874add7479332607d79c9bd14d07f60ff56..a83f02a861731855078fff807ae8a8a9ca6dcc24 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -17,6 +17,7 @@ #include #include +#include "paddle/fluid/framework/new_executor/instruction/instruction_base.h" #include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device_context.h" @@ -28,42 +29,6 @@ namespace interpreter { using DeviceContext = platform::DeviceContext; using DeviceEvent = platform::DeviceEvent; -class ContextManager { - public: - using DeviceContextMap = - std::map>>; - - static ContextManager& Instance() { - static ContextManager* ctx_manager = new ContextManager; - return *ctx_manager; - } - - std::shared_future> Get( - const std::string& type, - const platform::Place& place, - int stream_priority) { - std::lock_guard lk(ctx_mtx_); - VLOG(6) << "Get dev_ctx for " << type << " - " << place; - - DeviceContextMap& ctxs = ctx_pool_[type]; - if (ctxs.find(place) == ctxs.end()) { - platform::EmplaceDeviceContexts( - &ctxs, - {place}, - /*disable_setting_default_stream_for_allocator=*/true, - stream_priority); - } - return ctxs[place]; - } - - private: - ContextManager() {} - DISABLE_COPY_AND_ASSIGN(ContextManager); - - std::mutex ctx_mtx_; - std::unordered_map ctx_pool_; -}; - inline std::string RunTypeToString(DownstreamRunType run_type) { if (run_type == DownstreamRunType::kDirectRun) { return "DirectRun"; @@ -79,6 +44,11 @@ void StreamAnalyzer::ConstructEvents(std::vector* instructions) { cross_step_merged_instructions.emplace_back(instr); } + std::vector cross_step_merged_instructions_ptr; + for (Instruction& instr : cross_step_merged_instructions) { + cross_step_merged_instructions_ptr.emplace_back(&instr); + } + DependencyBuilder dependency_builder; dependency_builder.Build(cross_step_merged_instructions); @@ -91,10 +61,10 @@ void StreamAnalyzer::ConstructEvents(std::vector* instructions) { /*number_of_run_type = */ 2)); // instr_id -> run_type -> // next_instr_id AnalyseAllRunType( - cross_step_merged_instructions, downstream_map, &run_type_info); + cross_step_merged_instructions_ptr, downstream_map, &run_type_info); AnalyseAllEventInfo( - cross_step_merged_instructions, run_type_info, event_info_.get()); + cross_step_merged_instructions_ptr, run_type_info, event_info_.get()); ShrinkEventInfo(dependency_builder, event_info_.get()); is_event_info_build_ = true; } @@ -210,38 +180,41 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext( return op_func_node.dev_ctx_; } -bool StreamAnalyzer::HasDataDependency(const Instruction& cur_instr, - const Instruction& next_instr) const { - auto no_need_buffer_ins = - [](const Instruction& instr) -> const std::unordered_set { - auto* op = instr.OpBase(); - auto& inferer = op->Info().NoNeedBufferVarsInferer(); - if (inferer) { - return inferer(op->Inputs(), op->Outputs(), op->Attrs()); - } - return std::unordered_set(); - }; +const std::unordered_set no_need_buffer_ins(Instruction* instr) { + auto* op = instr->OpBase(); + auto& inferer = op->Info().NoNeedBufferVarsInferer(); + if (inferer) { + return inferer(op->Inputs(), op->Outputs(), op->Attrs()); + } + return std::unordered_set(); +} + +const std::unordered_set no_need_buffer_ins( + const paddle::framework::InstructionBase* instr) { + return instr->NoNeedBuffer(); +} +template +bool has_data_dependency(T1* cur_instr, T1* next_instr) { // cur_instr->var->next_instr std::unordered_set cur_var_ids; - for (auto& item : cur_instr.Outputs()) { + for (auto& item : cur_instr->Outputs()) { cur_var_ids.insert(item.second.begin(), item.second.end()); } - const std::unordered_set next_instr_no_need_buffer_ins = + const std::unordered_set next_instr_no_need_buffer_ins = no_need_buffer_ins(next_instr); - for (auto& item : next_instr.Inputs()) { + for (auto& item : next_instr->Inputs()) { if (next_instr_no_need_buffer_ins.find(item.first) != next_instr_no_need_buffer_ins.end()) { continue; } for (auto next_var_id : item.second) { if (cur_var_ids.find(next_var_id) != cur_var_ids.end()) { - VLOG(6) << "Found data dependency from " << cur_instr.OpBase()->Type() - << "(" << cur_instr.Id() << ") to " - << next_instr.OpBase()->Type() << "(" << next_instr.Id() - << ") at variable " << item.first << "(" << next_var_id << ")"; + VLOG(6) << "Found data dependency from " + << "cur_instr(" << cur_instr->Id() << ") to " + << "next_instr(" << next_instr->Id() << ")"; return true; } } @@ -249,22 +222,21 @@ bool StreamAnalyzer::HasDataDependency(const Instruction& cur_instr, // cur_instr->var && next_instr->var // var->cur_instr && next_instr->var - const std::unordered_set cur_instr_no_need_buffer_ins = + const std::unordered_set cur_instr_no_need_buffer_ins = no_need_buffer_ins(cur_instr); - for (auto& item : cur_instr.Inputs()) { + for (auto& item : cur_instr->Inputs()) { if (cur_instr_no_need_buffer_ins.find(item.first) == cur_instr_no_need_buffer_ins.end()) { cur_var_ids.insert(item.second.begin(), item.second.end()); } } - for (auto& item : next_instr.Outputs()) { + for (auto& item : next_instr->Outputs()) { for (auto next_var_id : item.second) { if (cur_var_ids.find(next_var_id) != cur_var_ids.end()) { - VLOG(6) << "Found data dependency from " << cur_instr.OpBase()->Type() - << "(" << cur_instr.Id() << ") to " - << next_instr.OpBase()->Type() << "(" << next_instr.Id() - << ") at variable " << item.first << "(" << next_var_id << ")"; + VLOG(6) << "Found data dependency from " + << "cur_instr(" << cur_instr->Id() << ") to " + << "next_instr(" << next_instr->Id() << ")"; return true; } } @@ -273,64 +245,80 @@ bool StreamAnalyzer::HasDataDependency(const Instruction& cur_instr, return false; } -void StreamAnalyzer::AnalyseAllEventInfo( - const std::vector& instructions, - const std::vector>>& run_type_info, - std::map>>* - event_info) const { - for (size_t cur_instr_id = 0; cur_instr_id < instructions.size(); - ++cur_instr_id) { - const std::vector& next_instr_ids = - run_type_info[cur_instr_id][DownstreamRunType::kEventRun]; - std::set waiter_instr_ids; - std::set visited_next_instr_id; +template +DownstreamRunType analyse_run_type_for_two_instructions(T* cur_instr, + T* next_instr, + const Place& place) { + // xpu&ipu memcpy kerenl is synchronous. + if (platform::is_ipu_place(place) || platform::is_xpu_place(place)) { + return DownstreamRunType::kDirectRun; + } - for (size_t next_instr_id : next_instr_ids) { - AnalyseEventInfoForTwoInstructions(instructions, - run_type_info, - cur_instr_id, - next_instr_id, - &waiter_instr_ids, - &visited_next_instr_id); + // npu d2h kernel is asynchronous. + if (platform::is_custom_place(place)) { + if (platform::is_cpu_place(cur_instr->DeviceContext().GetPlace()) || + interpreter::IsMemcpyH2D(next_instr)) { + return DownstreamRunType::kDirectRun; } + } - for (size_t waiter_instr_id : waiter_instr_ids) { - (*event_info)[&(instructions[cur_instr_id].DeviceContext())] - [waiter_instr_id] - .insert(cur_instr_id); - } + if (cur_instr->KernelType() == OpFuncType::kGpuAsync && + (&cur_instr->DeviceContext() != &next_instr->DeviceContext())) { + return DownstreamRunType::kEventRun; } + + return DownstreamRunType::kDirectRun; } -void StreamAnalyzer::AnalyseAllRunType( - const std::vector& instructions, +template +void analyse_all_run_type( + const std::vector& instructions, const std::map>& downstream_map, - std::vector>>* run_type_info) const { + const Place& place, + std::vector>>* run_type_info) { for (auto& item : downstream_map) { size_t cur_instr_id = item.first; - const Instruction& cur_instr = instructions[item.first]; + T* cur_instr = instructions[item.first]; for (size_t next_instr_id : item.second) { - const Instruction& next_instr = instructions[next_instr_id]; - DownstreamRunType run_type = - AnalyseRunTypeForTwoInstructions(cur_instr, next_instr); + T* next_instr = instructions[next_instr_id]; + DownstreamRunType run_type = analyse_run_type_for_two_instructions( + cur_instr, next_instr, place); (*run_type_info)[cur_instr_id][run_type].push_back(next_instr_id); - VLOG(6) << RunTypeToString(run_type) << ": " << cur_instr.OpBase()->Type() - << "(" << cur_instr_id << ") -> " << next_instr.OpBase()->Type() - << "(" << next_instr_id << ")"; + VLOG(6) << RunTypeToString(run_type) << ": " + << "cur_instr_id(" << cur_instr_id << ") -> " + << "next_instr_id(" << next_instr_id << ")"; } } } +void StreamAnalyzer::AnalyseAllRunType( + const std::vector& instructions, + const std::map>& downstream_map, + std::vector>>* run_type_info) const { + analyse_all_run_type( + instructions, downstream_map, place_, run_type_info); +} + // The caller should guarantee cur_instr and next_instr is kEventRun -void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( - const std::vector& instructions, +template +void analyse_event_info_for_two_instructions( + const std::vector& instructions, + const std::vector>>& run_type_info, + const size_t cur_instr_id, + const size_t next_instr_id, + std::set* waiter_instr_ids, + std::set* visited_next_instr_id); + +template <> +void analyse_event_info_for_two_instructions( + const std::vector& instructions, const std::vector>>& run_type_info, const size_t cur_instr_id, const size_t next_instr_id, std::set* waiter_instr_ids, - std::set* visited_next_instr_id) const { + std::set* visited_next_instr_id) { if (visited_next_instr_id->find(next_instr_id) != visited_next_instr_id->end()) { return; @@ -355,10 +343,11 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( // There is actually a data dependency between op1 and op2 that var0 and // fused_var share the same tensor. However, as the dependency is implicit, we // can only add event for it with the help of depend_op. - if (HasDataDependency(instructions[cur_instr_id], - instructions[next_instr_id]) || + + if (has_data_dependency( + instructions[cur_instr_id], instructions[next_instr_id]) || !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || - instructions[next_instr_id].OpBase()->Type() == "depend") { + instructions[next_instr_id]->OpBase()->Type() == "depend") { waiter_instr_ids->insert(next_instr_id); return; } @@ -372,19 +361,119 @@ void StreamAnalyzer::AnalyseEventInfoForTwoInstructions( // between cur_instr and next_instr. for (size_t instr_id : run_type_info[next_instr_id][DownstreamRunType::kDirectRun]) { - AnalyseEventInfoForTwoInstructions(instructions, - run_type_info, - cur_instr_id, - instr_id, - waiter_instr_ids, - visited_next_instr_id); + analyse_event_info_for_two_instructions(instructions, + run_type_info, + cur_instr_id, + instr_id, + waiter_instr_ids, + visited_next_instr_id); } } -void StreamAnalyzer::ShrinkEventInfo( - const DependencyBuilder& dependency_builder, +template <> +void analyse_event_info_for_two_instructions< + paddle::framework::InstructionBase>( + const std::vector& instructions, + const std::vector>>& run_type_info, + const size_t cur_instr_id, + const size_t next_instr_id, + std::set* waiter_instr_ids, + std::set* visited_next_instr_id) { + if (visited_next_instr_id->find(next_instr_id) != + visited_next_instr_id->end()) { + return; + } + visited_next_instr_id->insert(next_instr_id); + + // NOTE(Ruibiao): Though depend_op as next_instr is no_need_buffer, we should + // also wait event for it. Because depend_op is used to build dependencies for + // fused vars in some scenarios. In those cases, we do not know which vars may + // lead a implicit data dependency. For example, + // ### + // ### fused_var = fuse_op(var0, ...) + // ### var1 = op1(fused_var) + // ### var0 = depend_op(var0, fused_var) + // ### var2 = op2(var0) + // ### + // If op1 are cross-stream with depend_op and op2, then we have: + // ### + // ### event_run : op1 -> depend_op + // ### direct_run : depend_op -> op2 + // ### + // There is actually a data dependency between op1 and op2 that var0 and + // fused_var share the same tensor. However, as the dependency is implicit, we + // can only add event for it with the help of depend_op. + + if (has_data_dependency( + instructions[cur_instr_id], instructions[next_instr_id]) || + !run_type_info[next_instr_id][DownstreamRunType::kEventRun].empty() || + instructions[next_instr_id]->Name() == "pd.depend") { + waiter_instr_ids->insert(next_instr_id); + return; + } + + // NOTE(Ruibiao): If no data dependency from cur_instr to next_instr, and + // simultaneously next_instr has no event_run downstream instr, we try to + // recursively add events between cur_instr and next_instr's + // direct-run-instrs. This can delay the event wait and achieve better + // scheduling performance in some scenarios. However, when next_instr has too + // many direct-run-instrs, it may perform worse than add event directly + // between cur_instr and next_instr. + for (size_t instr_id : + run_type_info[next_instr_id][DownstreamRunType::kDirectRun]) { + analyse_event_info_for_two_instructions( + instructions, + run_type_info, + cur_instr_id, + instr_id, + waiter_instr_ids, + visited_next_instr_id); + } +} + +template +void analyse_all_event_info( + const std::vector& instructions, + const std::vector>>& run_type_info, + std::map>>* + event_info) { + for (size_t cur_instr_id = 0; cur_instr_id < instructions.size(); + ++cur_instr_id) { + const std::vector& next_instr_ids = + run_type_info[cur_instr_id][DownstreamRunType::kEventRun]; + std::set waiter_instr_ids; + std::set visited_next_instr_id; + + for (size_t next_instr_id : next_instr_ids) { + analyse_event_info_for_two_instructions(instructions, + run_type_info, + cur_instr_id, + next_instr_id, + &waiter_instr_ids, + &visited_next_instr_id); + } + + for (size_t waiter_instr_id : waiter_instr_ids) { + (*event_info)[&(instructions[cur_instr_id]->DeviceContext())] + [waiter_instr_id] + .insert(cur_instr_id); + } + } +} + +void StreamAnalyzer::AnalyseAllEventInfo( + const std::vector& instructions, + const std::vector>>& run_type_info, std::map>>* event_info) const { + analyse_all_event_info(instructions, run_type_info, event_info); +} + +template +void shrink_event_info( + const T& dependency_builder, + std::map>>* + event_info) { for (auto& item : *event_info) { // shrink redundant recorders, waiter instrs should only wait for the last // recorder instrs in each stream @@ -444,6 +533,13 @@ void StreamAnalyzer::ShrinkEventInfo( } } +void StreamAnalyzer::ShrinkEventInfo( + const DependencyBuilder& dependency_builder, + std::map>>* + event_info) const { + shrink_event_info(dependency_builder, event_info); +} + platform::DeviceType StreamAnalyzer::GetWaiterType( const Instruction& instr) const { if (instr.KernelType() == OpFuncType::kCpuSync) { @@ -458,39 +554,147 @@ platform::DeviceType StreamAnalyzer::GetWaiterType( } } -DownstreamRunType StreamAnalyzer::AnalyseRunTypeForTwoInstructions( - const Instruction& cur_instr, const Instruction& next_instr) const { - // xpu&ipu memcpy kerenl is synchronous. - if (platform::is_ipu_place(place_) || platform::is_xpu_place(place_)) { - return DownstreamRunType::kDirectRun; +std::shared_ptr< + std::map>>> +StreamAnalyzer::GetEventInfo() const { + return event_info_; +} + +void StreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) { + event_info_ = src.GetEventInfo(); + is_event_info_build_ = true; +} + +/// ======================== /// +/// For new ir /// +/// ======================== /// +void NewIrStreamAnalyzer::ConstructEvents( + const std::vector>& + instructions) { + if (!is_event_info_build_) { + std::vector + cross_step_merged_instructions_ptr; + for (auto& instr : instructions) { + cross_step_merged_instructions_ptr.emplace_back(instr.get()); + } + for (auto& instr : instructions) { + cross_step_merged_instructions_ptr.emplace_back(instr.get()); + } + + NewIrDependencyBuilder dependency_builder; + dependency_builder.Build(cross_step_merged_instructions_ptr); + const std::map>& downstream_map = + dependency_builder.OpDownstreamMap(); + + const size_t instr_num = cross_step_merged_instructions_ptr.size(); + std::vector>> run_type_info( + instr_num, + std::vector>( + /*number_of_run_type = */ 2)); // instr_id -> run_type -> + // next_instr_id + AnalyseAllRunType( + cross_step_merged_instructions_ptr, downstream_map, &run_type_info); + + AnalyseAllEventInfo( + cross_step_merged_instructions_ptr, run_type_info, event_info_.get()); + + ShrinkEventInfo(dependency_builder, event_info_.get()); + + is_event_info_build_ = true; } - // npu d2h kernel is asynchronous. - if (platform::is_custom_place(place_)) { - if (interpreter::IsCpuOp(cur_instr) || - interpreter::IsMemcpyH2D(next_instr)) { - return DownstreamRunType::kDirectRun; + // Construct events + std::map> instr2event; + for (auto& context_item : *event_info_) { + for (auto& waiter_item : context_item.second) { + size_t waiter_instr_id = waiter_item.first; + std::set& recorder_instr_ids = waiter_item.second; + + if (waiter_instr_id >= instructions.size()) { + waiter_instr_id -= instructions.size(); + } + + for (size_t recorder_instr_id : recorder_instr_ids) { + // Redundant record + if (recorder_instr_id >= instructions.size()) { + continue; + } + + paddle::framework::InstructionBase* recorder_instr = + instructions.at(recorder_instr_id).get(); + paddle::framework::InstructionBase* waiter_instr = + instructions.at(waiter_instr_id).get(); + platform::DeviceType waiter_type = GetWaiterType(waiter_instr); + + if (instr2event.find(recorder_instr_id) == instr2event.end()) { + std::shared_ptr device_event = + std::make_shared( + recorder_instr->DeviceContext().GetPlace(), + platform::GenerateDeviceEventFlag()); + recorder_instr->AddEventToRecord(device_event, + platform::kCUDA /*unused*/); + instr2event.emplace(recorder_instr_id, device_event); + } + + waiter_instr->AddEventToWait( + recorder_instr_id, instr2event.at(recorder_instr_id), waiter_type); + VLOG(6) << "Add event: " << recorder_instr->Name() << "(" + << recorder_instr_id << ") -> " << waiter_instr->Name() << "(" + << waiter_instr_id << "), waiter type = " << waiter_type; + } } } +} - if (cur_instr.KernelType() == OpFuncType::kGpuAsync && - (&cur_instr.DeviceContext() != &next_instr.DeviceContext())) { - return DownstreamRunType::kEventRun; +void NewIrStreamAnalyzer::AnalyseAllRunType( + const std::vector& instructions, + const std::map>& downstream_map, + std::vector>>* run_type_info) const { + analyse_all_run_type( + instructions, downstream_map, place_, run_type_info); +} + +void NewIrStreamAnalyzer::AnalyseAllEventInfo( + const std::vector& instructions, + const std::vector>>& run_type_info, + std::map>>* + event_info) const { + analyse_all_event_info( + instructions, run_type_info, event_info); +} + +void NewIrStreamAnalyzer::ShrinkEventInfo( + const NewIrDependencyBuilder& dependency_builder, + std::map>>* + event_info_map) const { + shrink_event_info(dependency_builder, event_info_map); +} + +platform::DeviceType NewIrStreamAnalyzer::GetWaiterType( + const paddle::framework::InstructionBase* instr) const { + if (instr->KernelType() == OpFuncType::kCpuSync) { + return platform::kCPU; + } else { + if (platform::is_xpu_place(place_)) { + return platform::kXPU; + } else if (platform::is_custom_place(place_)) { + return platform::kCUSTOM_DEVICE; + } + return platform::kCUDA; } +} - return DownstreamRunType::kDirectRun; +void NewIrStreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) { + event_info_ = src.GetEventInfo(); + is_event_info_build_ = true; } std::shared_ptr< std::map>>> -StreamAnalyzer::GetEventInfo() const { +NewIrStreamAnalyzer::GetEventInfo() const { return event_info_; } -void StreamAnalyzer::ShareEventInfoFrom(const StreamAnalyzer& src) { - event_info_ = src.GetEventInfo(); - is_event_info_build_ = true; -} } // namespace interpreter } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h index cef36f4a82f71b8cc99a71f6edbcb13a5b650f45..8c7d2d5b6ddbca816bec124417cdfdf865f23a5f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.h @@ -28,6 +28,43 @@ namespace interpreter { enum DownstreamRunType { kDirectRun, kEventRun }; +class ContextManager { + public: + using DeviceContextMap = + std::map>>; + + static ContextManager& Instance() { + static ContextManager* ctx_manager = new ContextManager; + return *ctx_manager; + } + + std::shared_future> Get( + const std::string& type, + const platform::Place& place, + int stream_priority) { + std::lock_guard lk(ctx_mtx_); + VLOG(6) << "Get dev_ctx for " << type << " - " << place; + + DeviceContextMap& ctxs = ctx_pool_[type]; + if (ctxs.find(place) == ctxs.end()) { + platform::EmplaceDeviceContexts( + &ctxs, + {place}, + /*disable_setting_default_stream_for_allocator=*/true, + stream_priority); + } + return ctxs[place]; + } + + private: + ContextManager() {} + DISABLE_COPY_AND_ASSIGN(ContextManager); + + std::mutex ctx_mtx_; + std::unordered_map ctx_pool_; +}; + class StreamAnalyzer { public: using DeviceContext = platform::DeviceContext; @@ -54,35 +91,75 @@ class StreamAnalyzer { GetEventInfo() const; private: - bool HasDataDependency(const Instruction& cur_instr, - const Instruction& next_instr) const; + bool HasDataDependency(Instruction* cur_instr, Instruction* next_instr) const; void AnalyseAllEventInfo( - const std::vector& instructions, + const std::vector& instructions, const std::vector>>& run_type_info, std::map>>* event_info) const; void AnalyseAllRunType( - const std::vector& instructions, + const std::vector& instructions, const std::map>& downstream_map, std::vector>>* run_type_info) const; - void AnalyseEventInfoForTwoInstructions( - const std::vector& instructions, - const std::vector>>& run_type_info, - const size_t cur_instr_id, - const size_t next_instr_id, - std::set* waiter_instr_ids, - std::set* visited_next_instr_id) const; - void ShrinkEventInfo( const DependencyBuilder& dependency_builder, std::map>>* event_info_map) const; - DownstreamRunType AnalyseRunTypeForTwoInstructions( - const Instruction& cur_instr, const Instruction& next_instr) const; + const Place place_; + bool is_event_info_build_{false}; + std::shared_ptr< + std::map>>> + event_info_; +}; + +/// ======================== /// +/// For new ir /// +/// ======================== /// +class NewIrStreamAnalyzer { + public: + using DeviceContext = platform::DeviceContext; + using Place = platform::Place; + + explicit NewIrStreamAnalyzer(const Place& place) : place_(place) { + event_info_ = std::make_shared< + std::map>>>(); + } + + ~NewIrStreamAnalyzer() {} + + void ConstructEvents( + const std::vector>& + instructions); + + platform::DeviceType GetWaiterType( + const paddle::framework::InstructionBase* instr) const; + + void ShareEventInfoFrom(const StreamAnalyzer& src); + + std::shared_ptr< + std::map>>> + GetEventInfo() const; + + private: + void AnalyseAllRunType( + const std::vector& instructions, + const std::map>& downstream_map, + std::vector>>* run_type_info) const; + + void AnalyseAllEventInfo( + const std::vector& instructions, + const std::vector>>& run_type_info, + std::map>>* + event_info) const; + + void ShrinkEventInfo( + const NewIrDependencyBuilder& dependency_builder, + std::map>>* + event_info_map) const; const Place place_; bool is_event_info_build_{false}; diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc index 3cdc815a562ae4cc1cd69713a4d1e13e3550acd3..b9060cb16e0d88e04baa2ee92217c06e73d91646 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.cc @@ -51,7 +51,8 @@ NewIRInterpreter::NewIRInterpreter(const platform::Place& place, execution_config_(execution_config), var_scope_(scope), scope_(scope), - ir_program_(std::move(ir_prog)) { + ir_program_(std::move(ir_prog)), + ir_stream_analyzer_(place) { VLOG(4) << "NewIRInterpreter(): " << this << " on " << place_; static_build_ = FLAGS_new_executor_static_build && !FLAGS_new_executor_use_cuda_graph && @@ -97,7 +98,6 @@ NewIRInterpreter::~NewIRInterpreter() { gc_.reset(nullptr); async_work_queue_.reset(); VLOG(4) << "~NewIRInterpreter(): " << this << " on " << place_; - #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // this is needed to have mkl-dnn unit tests working @@ -197,9 +197,11 @@ FetchList NewIRInterpreter::Run(const std::vector& feed_names, VLOG(4) << DebugValueInfo(); // NOTE(zhangbo): Iterative version, gradually replacing BuildOpFuncList() - // and Convert() - // BuildInstruction(); - // BuildInstructionDependences(); + // and Convert() by: + // [1] BuildInstruction(); + // [2] BuildInstructionDependences(); + // [3] ir_stream_analyzer_.ConstructEvents(&vec_instruction_base_); + // [4] GC(); std::vector op_func_nodes; interpreter::BuildOpFuncList(place_, @@ -260,8 +262,35 @@ FetchList NewIRInterpreter::BetaRun(const std::vector& feed_names, &var_name_2_id_, &variable_list_); VLOG(4) << DebugValueInfo(); + BuildInstruction(); + BuildInstructionDependences(); + + ir_stream_analyzer_.ConstructEvents(vec_instruction_base_); + // add event for the input var of jit program, since there are async copied + // from gpu_pinned place to gpu place on compute stream. + for (size_t i = 0; i < dependecy_count_.size(); ++i) { + if (dependecy_count_[i] == 0) { + InstructionBase* inst = vec_instruction_base_[i].get(); + if (inst->Name() == "pd.memcpy_d2h" && platform::is_gpu_place(place_)) { + for (auto& item : inst->Inputs()) { + for (auto var_id : item.second) { + auto name = GetNameById(var_id); + if (JitInputVars().count(name)) { + auto device_event = std::make_shared( + place_, platform::GenerateDeviceEventFlag()); + VLOG(4) << "Add input event for input: " << name << " of " + << inst->Name(); + inst->AddEventToWait( + i, device_event, ir_stream_analyzer_.GetWaiterType(inst)); + } + } + } + } + } + } + for (size_t instr_id = 0; instr_id < vec_instruction_base_.size(); ++instr_id) { vec_instruction_base_[instr_id]->Run(); @@ -345,6 +374,21 @@ void NewIRInterpreter::reset_scope(Scope* new_scope) { const Scope* NewIRInterpreter::local_scope() const { return local_scope_; } +std::string NewIRInterpreter::GetNameById(int id) const { + // NOTE(zhiqiu): do not use vec_meta_info_[id].vardesc_->Name() since + // vec_meta_info_[id] may be nullptr, + // typically when the target variable is not existed in the original program + // desc, but created by interpretercore. + // For example, created and used by d2h_copy or h2d_copy operator. + auto it = std::find_if(var_name_2_id_.begin(), + var_name_2_id_.end(), + [id](const auto& pair) { return pair.second == id; }); + if (it != var_name_2_id_.end()) { + return it->first; + } + return ""; +} + void NewIRInterpreter::ShareWorkQueueFrom(InterpreterBaseImpl* src) { async_work_queue_ = reinterpret_cast(src)->GetWorkQueue(); VLOG(8) << "Share AsyncWorkQueue from InterpreterCore(" << src @@ -1581,13 +1625,13 @@ void NewIRInterpreter::AnalyseExecuteOrderForTrace() { /// ======================== /// void NewIRInterpreter::BuildInstruction() { - VLOG(0) << "Build Instructions for new ir ... "; + VLOG(6) << "Build Instructions for new ir ... "; vec_instruction_base_.clear(); size_t op_idx = 0; for (auto it = ir_program_->block()->begin(); it != ir_program_->block()->end(); ++it) { - VLOG(0) << "Build Instruction for op: " << op_idx; + VLOG(6) << "Build Instruction for op: " << op_idx; if ((*it)->dialect()->name() == "pd_kernel") { auto op_name = (*it) ->attributes() @@ -1635,7 +1679,12 @@ void NewIRInterpreter::BuildInstructionDependences() { // instr, and set the dependecy_count_ size_t instr_num = vec_instruction_base_.size(); dependecy_count_ = std::vector(instr_num, 0); - auto downstream_map = ir_dependency_builder_.Build(vec_instruction_base_); + + std::vector instructions_ptr; + for (auto& instr : vec_instruction_base_) { + instructions_ptr.push_back(instr.get()); + } + auto downstream_map = ir_dependency_builder_.Build(instructions_ptr); for (size_t instr_id = 0; instr_id < instr_num; ++instr_id) { InstructionBase* cur_instr = vec_instruction_base_[instr_id].get(); diff --git a/paddle/fluid/framework/new_executor/new_ir_interpreter.h b/paddle/fluid/framework/new_executor/new_ir_interpreter.h index 14c8d1778c288eeec39a0aee253be74f4159781c..8011811c44f2fc1f4adbf1fb8b8139dcff58fcfb 100644 --- a/paddle/fluid/framework/new_executor/new_ir_interpreter.h +++ b/paddle/fluid/framework/new_executor/new_ir_interpreter.h @@ -84,6 +84,8 @@ class NewIRInterpreter : public InterpreterBaseImpl { hookfuncs_ = hookfuncs; } + std::string GetNameById(int id) const; + private: // build graph void Convert(std::vector* op_func_nodes); @@ -216,7 +218,9 @@ class NewIRInterpreter : public InterpreterBaseImpl { std::vector variable_list_; - interpreter::IrDependencyBuilder ir_dependency_builder_; + interpreter::NewIrDependencyBuilder ir_dependency_builder_; + + interpreter::NewIrStreamAnalyzer ir_stream_analyzer_; }; } // namespace framework diff --git a/paddle/fluid/ir/interface/op_yaml_info_parser.cc b/paddle/fluid/ir/interface/op_yaml_info_parser.cc index 2cb84e5a610aaabf2179aff77d721b5a515540d8..20266aac6aba4f8f75a8e00a537d57f629d943bf 100644 --- a/paddle/fluid/ir/interface/op_yaml_info_parser.cc +++ b/paddle/fluid/ir/interface/op_yaml_info_parser.cc @@ -92,6 +92,10 @@ const std::map& OpYamlInfoParser::OutputName2Id() const { return output_name2id_; } +const std::vector& OpYamlInfoParser::NoNeedBufferIds() const { + return no_need_buffer_ids_; +} + bool OpYamlInfoParser::HasInplace(const std::string& out_name) const { auto& inplace_info = std::get<3>(op_info_tuple_).inplace; for (size_t i = 0; i < inplace_info.size(); i++) { @@ -117,14 +121,16 @@ const std::string& OpYamlInfoParser::InplaceName( void OpYamlInfoParser::parse() { auto input_info = std::get<0>(op_info_tuple_); - int input_start_index = 0; for (size_t i = 0; i < input_info.size(); ++i) { - input_name2id_[input_info[i].name] = input_start_index++; + input_name2id_[input_info[i].name] = i; input_name_list_.push_back(input_info[i].name); input_info_[input_info[i].name] = input_info[i]; if (!input_info[i].is_mutable_attribute) { input_tensor_number_++; } + if (input_info[i].no_need_buffer) { + no_need_buffer_ids_.push_back(i); + } } auto attribute_info = std::get<1>(op_info_tuple_); @@ -133,10 +139,9 @@ void OpYamlInfoParser::parse() { attr_info_[attribute_info[i].name] = attribute_info[i]; } - int output_start_index = 0; auto output_info = std::get<2>(op_info_tuple_); for (size_t i = 0; i < output_info.size(); ++i) { - output_name2id_[output_info[i].name] = output_start_index++; + output_name2id_[output_info[i].name] = i; output_name_list_.push_back(output_info[i].name); output_info_[output_info[i].name] = output_info[i]; } diff --git a/paddle/fluid/ir/interface/op_yaml_info_parser.h b/paddle/fluid/ir/interface/op_yaml_info_parser.h index 6b600a6d70e8121e34b680d903a828e7bf66a41a..356decadcf677fffe0c5383967822c619d46286a 100644 --- a/paddle/fluid/ir/interface/op_yaml_info_parser.h +++ b/paddle/fluid/ir/interface/op_yaml_info_parser.h @@ -37,6 +37,8 @@ class OpYamlInfoParser { const std::map& InputName2Id() const; const std::map& OutputName2Id() const; + const std::vector& NoNeedBufferIds() const; + const std::vector& InputNames() const { return input_name_list_; } @@ -65,6 +67,9 @@ class OpYamlInfoParser { std::map input_info_; int input_tensor_number_{0}; + // no_need_buffer_ids + std::vector no_need_buffer_ids_; + // attribute info std::vector attribute_name_list_; std::map attr_info_;