diff --git a/paddle/fluid/framework/new_executor/event_manager.cc b/paddle/fluid/framework/new_executor/event_manager.cc
index a45f65d264c3ad7ab9cb0b7e59d1ab2ed5c64ce9..cc6fd6e3ed0f94cc75936eb78f2e7ec8fbe48e16 100644
--- a/paddle/fluid/framework/new_executor/event_manager.cc
+++ b/paddle/fluid/framework/new_executor/event_manager.cc
@@ -41,6 +41,16 @@ void RecordEvent(const Instruction& instruction, const platform::Place& place) {
   }
 }
 
+void RecordEvent(const Instruction& instruction) {
+  // If InterpreterCore is on CPUPlace, do nothing.
+  if (platform::is_cpu_place(instruction.DeviceContext().GetPlace())) return;
+
+  for (auto& event : instruction.OutputEvents()) {
+    VLOG(3) << "Record event for out_var_id: " << event.var_id_;
+    event.event_->Record(&instruction.DeviceContext());
+  }
+}
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/event_manager.h b/paddle/fluid/framework/new_executor/event_manager.h
index a949ae144017d4e795b505bb3551c8bbd70e37c4..283661198a48d70f4ccd2b1e7fee4d7c98632e7f 100644
--- a/paddle/fluid/framework/new_executor/event_manager.h
+++ b/paddle/fluid/framework/new_executor/event_manager.h
@@ -20,6 +20,8 @@ namespace framework {
 namespace interpreter {
 void RecordEvent(const Instruction& instruction, const platform::Place& place);
 
+void RecordEvent(const Instruction& instruction);
+
 void WaitEvent(const Instruction& instruction, const platform::Place& place);
 
 }  // namespace interpreter
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index abeadb5aa14915f51b35eb71ca1a52e7b59cceac..a28d4c845d796ebcdff3484ea50f3af7e5d0a7f1 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -77,20 +77,22 @@ paddle::framework::FetchList InterpreterCore::Run(
   return *(fetch_var->GetMutable<framework::FetchList>());
 }
 
-void InterpreterCore::Convert() {
+void InterpreterCore::Convert(
+    std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
   auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
   auto var_nums = global_scope_->VarSize();
   input_var2op_info_.resize(var_nums);
+  auto nodes = *op_func_nodes;
 
-  auto op_nums = vec_func_list_.size();
+  auto op_nums = nodes.size();
   vec_instruction_.reserve(op_nums);
 
   dependecy_count_.resize(op_nums);
 
   for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
-    auto& op_func_node = vec_func_list_[op_idx];
+    auto& op_func_node = nodes[op_idx];
    auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
 
-    vec_instruction_.emplace_back(op_idx, op_func_node, *dev_ctx_);
+    vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
     auto& instr = vec_instruction_.back();
 
     OpInOutInfo info;
@@ -104,7 +106,7 @@ void InterpreterCore::Convert() {
       input_var2op_info_.at(id).push_back(op_idx);  // var can be gc-ed
       if (!info.IsBuilt()) {
-        info.Build(op_func_node.operator_base_);
+        info.Build(op_func_node.operator_base_.get());
       }
       auto* var_desc = global_scope_->VarDesc(id);
       if (var_desc) {
@@ -522,11 +524,12 @@ void InterpreterCore::Prepare(
   if (!is_build_) {
     paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
     FeedInput();
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     paddle::framework::interpreter::build_op_func_list(
-        place_, block_, &vec_func_list_, global_scope_);
+        place_, block_, &op_func_nodes, global_scope_);
     is_build_ = true;
     // convert vec func_list to graph
-    Convert();
+    Convert(&op_func_nodes);
   }
   // NOTE: Because feed_tensor will be GC after
   // paddle::framework::build_op_func_list, so we should
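The Convert() change above is the heart of this diff: Prepare() now builds the OpFuncNode list as a local vector, and Convert() moves each node into the Instruction that owns it, so no member copy (vec_func_list_) has to outlive the build. A minimal sketch of that move-ownership pattern, with hypothetical Node/Holder types standing in for OpFuncNode/Instruction (not Paddle code):

#include <utility>
#include <vector>

struct Node { std::vector<int> payload; };

struct Holder {  // stands in for Instruction
  Node node;
  explicit Holder(Node&& n) : node(std::move(n)) {}
};

std::vector<Holder> Convert(std::vector<Node>* nodes) {
  std::vector<Holder> holders;
  holders.reserve(nodes->size());
  for (auto& n : *nodes) {
    holders.emplace_back(std::move(n));  // leaves n in a moved-from state
  }
  return holders;  // *nodes may be discarded by the caller afterwards
}

Note that the diff itself first copies (auto nodes = *op_func_nodes;) and then moves out of the copy; the sketch moves straight from the input vector.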
diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h
index 915ae782e2210f9c291154d2b468614528e40b81..0925d715574fc702f4bc8b961e07fd4be52b0b72 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.h
+++ b/paddle/fluid/framework/new_executor/interpretercore.h
@@ -54,7 +54,7 @@ class InterpreterCore {
       const std::vector<framework::LoDTensor>& feed_tensors);
 
  private:
-  void Convert();
+  void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
 
   void BuildAndCacheInstructionCtx(Instruction* instr_node);
 
@@ -84,7 +84,6 @@ class InterpreterCore {
   const BlockDesc& block_;       // not owned
   VariableScope* global_scope_;  // not owned
-  std::vector<OpFuncNode> vec_func_list_;
 
   std::vector<Instruction> vec_instruction_;  // deconstruct before OpFuncNode
 
   std::vector<size_t> dependecy_count_;
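Dropping vec_func_list_ is safe only because of the companion change in new_executor_defs.h further down: Instruction now stores its OpFuncNode by value rather than as a const reference into this member vector. A self-contained sketch of the lifetime hazard that removes (hypothetical types, not Paddle code):

#include <vector>

struct Node { int x = 0; };

struct RefHolder {
  const Node& node;  // non-owning: dangles if the backing vector dies
};

struct ValueHolder {
  Node node;  // owning: independent of the source vector's lifetime
};

ValueHolder MakeSafeHolder() {
  std::vector<Node> temp = {{42}};
  // RefHolder{temp[0]} would dangle once `temp` is destroyed here;
  // copying (or moving) the node out is safe.
  return ValueHolder{temp[0]};
}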
diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc
index 068b554a57f1976391e194f829f3eaa615906ddd..011b1b6dece8e02380b58a829e72742e3ef28541 100644
--- a/paddle/fluid/framework/new_executor/interpretercore_util.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -64,11 +64,12 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
 
 std::unordered_map<const paddle::framework::OperatorBase*,
                    std::vector<std::string>>
-get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
+get_unused_vars(const BlockDesc& block,
+                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
   std::unordered_map<std::string, size_t> var_op_idx_map;
 
   for (size_t i = 0; i < ops.size(); ++i) {
-    auto* op = ops[i];
+    const auto& op = ops[i];
 
     OpInOutInfo info;
     for (auto& name_pair : op->Inputs()) {
@@ -79,7 +80,7 @@ get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
 
         // var can be gc-ed
         if (!info.IsBuilt()) {
-          info.Build(op);
+          info.Build(op.get());
         }
 
         if (info.IsInArgBufferNeeded(name)) {
@@ -107,7 +108,8 @@ get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
   for (auto& name_op_idx_pair : var_op_idx_map) {
     auto& name = name_op_idx_pair.first;
     size_t op_idx = name_op_idx_pair.second;
-    result[ops[op_idx]].emplace_back(name);
+
+    result[ops[op_idx].get()].emplace_back(name);
   }
   return result;
 }
@@ -150,8 +152,8 @@ void build_variable_scope(const framework::BlockDesc& block,
   }
 }
 
-std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
-  std::vector<OperatorBase*> ops;
+void create_all_ops(const framework::BlockDesc& block,
+                    std::vector<std::shared_ptr<OperatorBase>>* ops) {
   for (auto& op : block.AllOps()) {
     VLOG(3) << "CreateOp from : " << op->Type();
 
@@ -166,9 +168,8 @@ std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
     }
     auto op_base =
         info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
-    ops.push_back(op_base);
+    ops->emplace_back(std::shared_ptr<OperatorBase>(op_base));
   }
-  return ops;
 }
 
 std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
@@ -234,7 +235,8 @@ void apply_device_guard(const OperatorBase* op_base,
 }
 
 void deal_operator_base(const platform::Place& place,
-                        const VariableScope* var_scope, OperatorBase* op_base,
+                        const VariableScope* var_scope,
+                        std::shared_ptr<OperatorBase> op_base,
                         OpFuncNode* op_func_node) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
@@ -321,8 +323,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
           var_name, kernel_type_for_var.place_, new_var_name,
           expected_kernel_key.place_);
   auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type);
-  auto copy_op =
-      copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map);
+  auto copy_op = std::shared_ptr<OperatorBase>(
+      copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map));
+
   OpFuncNode copy_op_func_node;
   copy_op_func_node.input_index = copy_ins_name2id;
   copy_op_func_node.output_index = copy_out_name2id;
@@ -330,10 +333,10 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
   RuntimeContext copy_runtime_context({}, {});
   copy_runtime_context.inputs.swap(copy_ins_value_map);
   copy_runtime_context.outputs.swap(copy_outs_value_map);
-  InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op,
+  InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op.get(),
                                                         copy_runtime_context);
-  static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape(
-      &copy_infer_shape_ctx);
+  static_cast<const framework::OperatorWithKernel*>(copy_op.get())
+      ->InferShape(&copy_infer_shape_ctx);
   auto kernels_iter = all_op_kernels.find(memcpy_op_type);
   PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
@@ -347,7 +350,7 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
   auto copy_exec_ctx =
       ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context);
   auto copy_expected_kernel_key =
-      dynamic_cast<const framework::OperatorWithKernel*>(copy_op)
+      dynamic_cast<const framework::OperatorWithKernel*>(copy_op.get())
           ->GetExpectedKernelType(copy_exec_ctx);
   auto kernel_iter = kernels.find(copy_expected_kernel_key);
   copy_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
@@ -362,19 +365,20 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
   return std::make_pair(new_var_name, copy_op_func_node);
 }
 
-std::vector<OpFuncNode> apply_data_transform(
-    const OpKernelType& expected_kernel_key, const platform::Place& place,
-    VariableValueMap* ins_map_temp, VariableScope* var_scope,
-    OpFuncNode* op_func_node) {
-  auto& op_base = op_func_node->operator_base_;
+void apply_data_transform(const OpKernelType& expected_kernel_key,
+                          const platform::Place& place,
+                          VariableValueMap* ins_map_temp,
+                          VariableScope* var_scope, OpFuncNode* op_func_node,
+                          std::vector<OpFuncNode>* copy_func_nodes) {
+  auto op_base = op_func_node->operator_base_.get();
   PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
                                        "op_base is null, please pass a valid "
                                        "op_base in apply_data_transform."));
-  auto inputs_names = op_base->Inputs();
+
+  VariableNameMap new_ins(op_base->Inputs());
 
   std::unordered_set<int> no_data_transform_index;
   // record the no need transform variable index.
-  std::vector<OpFuncNode> copy_func_nodes;  // return all the copy opfuncnode.
 
   for (auto& var_name_item : *ins_map_temp) {
     for (size_t i = 0; i < var_name_item.second.size(); ++i) {
@@ -382,7 +386,7 @@ std::vector<OpFuncNode> apply_data_transform(
       if (!(var->IsType<LoDTensor>() || var->IsType<SelectedRows>())) {
         continue;
       }
-      auto& var_name = inputs_names[var_name_item.first].at(i);
+      auto& var_name = new_ins[var_name_item.first].at(i);
       auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
       if (!tensor_in->IsInitialized()) {
         continue;
       }
@@ -404,8 +408,9 @@ std::vector<OpFuncNode> apply_data_transform(
             var_name_item.first, *op_func_node, var, var_scope);
         op_func_node->input_index[var_name_item.first][i] =
             var_scope->VarId(new_var_name);
-        copy_func_nodes.push_back(copy_op_func_node);
+        copy_func_nodes->emplace_back(copy_op_func_node);
         var_name_item.second[i] = var_scope->Var(new_var_name);
+        new_ins[var_name_item.first][i] = new_var_name;
       } else if (need_dtype_transform_for_var(kernel_type_for_var,
                                               expected_kernel_key)) {
         // TODO(@xiongkun) add dtype judgement here
@@ -421,8 +426,14 @@ std::vector<OpFuncNode> apply_data_transform(
       }
     }
   }
+
+  // NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
+  // with the instruction.
+  // Hot fix; this is not a good design here.
+  op_func_node->operator_base_ =
+      std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
+          op_base->Type(), new_ins, op_base->Outputs(), op_base->Attrs()));
   op_func_node->no_data_transform_index = std::move(no_data_transform_index);
-  return copy_func_nodes;
 }
 
@@ -430,16 +441,16 @@ void build_op_func_list(const platform::Place& place,
                         std::vector<OpFuncNode>* vec_func_list,
                         VariableScope* var_scope) {
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
-
+  std::vector<std::shared_ptr<OperatorBase>>
+      ops;  // its elements will be moved to vec_func_list
   // Step 1: create all ops for current block.
-  auto ops = create_all_ops(block);
+  create_all_ops(block, &ops);
   auto unused_var_map = get_unused_vars(block, ops);
 
-  size_t ops_index = 0;
-  for (auto& op : block.AllOps()) {
+  for (size_t i = 0; i < ops.size(); ++i) {
+    auto op = ops[i].get();
     VLOG(6) << "Build OpFuncNode from : " << op->Type();
 
-    auto op_base = ops[ops_index++];
     auto inputs_names = op->Inputs();
     auto outputs_names = op->Outputs();
 
@@ -466,20 +477,18 @@ void build_op_func_list(const platform::Place& place,
     op_func_node.input_index = ins_name2id;
     op_func_node.output_index = outs_name2id;
 
-    if (dynamic_cast<const framework::OperatorWithKernel*>(op_base) ==
-        nullptr) {
+    if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
       // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
-      deal_operator_base(place, var_scope, op_base, &op_func_node);
+      deal_operator_base(place, var_scope, ops[i], &op_func_node);
     } else {
       // construct RuntimeContext and analysis KernelType
       RuntimeContext runtime_context({}, {});
       runtime_context.inputs.swap(ins_map);
       runtime_context.outputs.swap(outs_map);
-      InterpretercoreInferShapeContext infer_shape_ctx(*op_base,
-                                                       runtime_context);
+      InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
       // TODO(Aurelius84): In case of control flow ops, they are NOT inherited
       // from OperatorWithKernel.
-      static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
+      static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
           &infer_shape_ctx);
       auto kernels_iter = all_op_kernels.find(op->Type());
       PADDLE_ENFORCE_NE(
@@ -495,13 +504,13 @@ void build_op_func_list(const platform::Place& place,
       auto* dev_ctx = pool.Get(place);
       Scope scope;
       auto expected_kernel_key =
-          dynamic_cast<const framework::OperatorWithKernel*>(op_base)
+          dynamic_cast<const framework::OperatorWithKernel*>(op)
              ->GetExpectedKernelType(
-                  ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
+                  ExecutionContext(*op, scope, *dev_ctx, runtime_context));
 
       // consider device_guard()
       apply_device_guard(
-          op_base, place,
+          op, place,
          &expected_kernel_key);  // change device by the device_guard()
       VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
 
@@ -510,14 +519,14 @@ void build_op_func_list(const platform::Place& place,
       std::vector<OpFuncNode> copy_op_to_insert;
       // NOTE(xiongkun03): assign op_base here to reduce parameter number of
      // apply_data_transform.
-      op_func_node.operator_base_ = op_base;
-      copy_op_to_insert = apply_data_transform(
-          expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node);
+      op_func_node.operator_base_ = ops[i];
+      apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope,
+                           &op_func_node, &copy_op_to_insert);
       for (auto& item : copy_op_to_insert) {
         vec_func_list->push_back(item);
       }
 
       // step 4. Run op kernel
-      VLOG(3) << op_base->Type()
+      VLOG(3) << op->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
 
       if (platform::is_gpu_place(expected_kernel_key.place_)) {
@@ -533,8 +542,7 @@ void build_op_func_list(const platform::Place& place,
       }
       op_func_node.dev_ctx_ = dev_ctx;
 
-      auto exec_ctx =
-          ExecutionContext(*op_base, scope, *dev_ctx, runtime_context);
+      auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
 
       auto kernel_iter = kernels.find(expected_kernel_key);
       PADDLE_ENFORCE_NE(
@@ -547,9 +555,9 @@ void build_op_func_list(const platform::Place& place,
       op_func_node.kernel_func_(exec_ctx);
     }
 
-    vec_func_list->push_back(op_func_node);
+    vec_func_list->emplace_back(op_func_node);
     // gc---------------------------------------------------------------------------
-    auto iter = unused_var_map.find(op_base);
+    auto iter = unused_var_map.find(op);
     if (iter == unused_var_map.end()) {
       continue;
     }
@@ -586,7 +594,7 @@ void build_op_func_list(const platform::Place& place,
 
     delete garbages;  // free mem
 
-    VLOG(3) << "run " << op_base->Type() << " done.";
+    VLOG(3) << "run " << op->Type() << " done.";
   }
 }
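create_all_ops() and get_unused_vars() above combine shared ownership with raw-pointer lookup: the vector of std::shared_ptr<OperatorBase> keeps the ops alive, while the unused-var map keys on op.get(), valid exactly as long as the owning vector. A compilable sketch of the pattern with a hypothetical Op type (not Paddle code):

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Op { std::string type; };

int main() {
  std::vector<std::shared_ptr<Op>> ops;  // owns the ops
  ops.push_back(std::make_shared<Op>(Op{"matmul"}));
  ops.push_back(std::make_shared<Op>(Op{"relu"}));

  // Non-owning index keyed by raw pointer, as in get_unused_vars().
  std::unordered_map<const Op*, std::vector<std::string>> unused_vars;
  unused_vars[ops[0].get()].emplace_back("tmp_0");

  // Later lookups reuse the same raw pointer.
  return unused_vars.count(ops[0].get()) == 1 ? 0 : 1;
}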
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc
index bd8072885e2386d54c6a05b520cfd22031ac07af..7d40102cbe7647cb2e7716f2ff3027bbeb7b22a5 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.cc
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc
@@ -629,5 +629,109 @@ void VariableScopeListener::onCreateScope(Scope* Scope) {}
 void VariableScopeListener::onDeleteScope(Scope* Scope) {}
 void VariableScopeListener::onClear() {}
 
+Instruction::Instruction(size_t id, OpFuncNode&& op_func_node,
+                         const platform::DeviceContext& dev_ctx)
+    : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
+  PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
+                               "Required id >= 0, but received id = %d", id));
+}
+
+size_t Instruction::Id() const { return id_; }
+
+const std::map<std::string, std::vector<int>>& Instruction::Inputs() const {
+  return op_func_node_.input_index;
+}
+
+const std::map<std::string, std::vector<int>>& Instruction::Outputs() const {
+  return op_func_node_.output_index;
+}
+
+const std::unordered_set<int>& Instruction::NoDataTransformVars() const {
+  return op_func_node_.no_data_transform_index;
+}
+
+OpKernelComputeFunc Instruction::KernelFunc() const {
+  return op_func_node_.kernel_func_;
+}
+
+OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
+
+OperatorBase* Instruction::OpBase() const {
+  auto op_base = op_func_node_.operator_base_;
+  PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
+                                       "op_base shall not be nullptr."));
+  return op_base.get();
+}
+
+NextInstruction& Instruction::NextInstructions() { return next_instruction_; }
+
+const NextInstruction& Instruction::NextInstructions() const {
+  return next_instruction_;
+}
+
+void Instruction::AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
+
+const std::vector<size_t>& Instruction::GCCheckVars() const {
+  return gc_check_var_list_;
+}
+
+void Instruction::ResetContext(const VariableValueMap& in_vars,
+                               const VariableValueMap& out_vars) {
+  runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
+  infershape_ctx_.reset(
+      new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
+  // NOTE: Because execution_ctx_ is constructed by `scope&`, we fake an
+  // empty scope here to avoid an illegal local reference.
+  static framework::Scope scope_;
+  execution_ctx_.reset(
+      new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
+}
+
+std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
+  return runtime_ctx_;
+}
+
+std::shared_ptr<InterpretercoreInferShapeContext>
+Instruction::InnerInferShapeContext() const {
+  return infershape_ctx_;
+}
+
+std::shared_ptr<ExecutionContext> Instruction::InnerExecutionContext() const {
+  return execution_ctx_;
+}
+
+const platform::DeviceContext& Instruction::DeviceContext() const {
+  return dev_ctx_;
+}
+
+const std::vector<std::pair<Variable*, Variable*>>& Instruction::InplaceInfo()
+    const {
+  return vec_inplace_in_to_out_;
+}
+
+void Instruction::AddInplace(Variable* in, Variable* out) {
+  vec_inplace_in_to_out_.emplace_back(in, out);
+}
+
+const std::vector<EventInter>& Instruction::InputEvents() const {
+  return intput_events_;
+}
+
+const std::vector<EventInter>& Instruction::OutputEvents() const {
+  return output_events_;
+}
+
+void Instruction::AddInputEvent(size_t var_id,
+                                std::shared_ptr<platform::DeviceEvent> event,
+                                platform::DeviceType waiter_type) {
+  intput_events_.emplace_back(var_id, event, waiter_type);
+}
+
+void Instruction::AddOutputEvent(size_t var_id,
+                                 std::shared_ptr<platform::DeviceEvent> event,
+                                 platform::DeviceType waiter_type) {
+  output_events_.emplace_back(var_id, event, waiter_type);
+}
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h
index 2479abb8926abe2becf0b0c529fb16effe242bb2..68ea48fd328032de330d28bc958c542af0656cda 100644
--- a/paddle/fluid/framework/new_executor/new_executor_defs.h
+++ b/paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -271,7 +271,8 @@ enum class OpFuncType {
 class RuntimeInferShapeContext;
 
 struct OpFuncNode {
-  OperatorBase* operator_base_;
+  // TODO(zhiqiu): Better to make it unique_ptr
+  std::shared_ptr<OperatorBase> operator_base_;
   std::map<std::string, std::vector<int>> input_index;
   std::map<std::string, std::vector<int>> output_index;
   std::unordered_set<int> no_data_transform_index;
@@ -283,100 +284,62 @@ struct OpFuncNode {
 
 class Instruction {
  public:
-  Instruction(size_t id, const OpFuncNode& op_func_node,
-              const platform::DeviceContext& dev_ctx)
-      : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
-    PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
-                                 "Required id >= 0, but received id = %d", id));
-  }
+  Instruction(size_t id, OpFuncNode&& op_func_node,
+              const platform::DeviceContext& dev_ctx);
 
-  size_t Id() const { return id_; }
+  size_t Id() const;
 
-  const std::map<std::string, std::vector<int>>& Inputs() const {
-    return op_func_node_.input_index;
-  }
+  const std::map<std::string, std::vector<int>>& Inputs() const;
 
-  const std::map<std::string, std::vector<int>>& Outputs() const {
-    return op_func_node_.output_index;
-  }
+  const std::map<std::string, std::vector<int>>& Outputs() const;
 
-  const std::unordered_set<int>& NoDataTransformVars() const {
-    return op_func_node_.no_data_transform_index;
-  }
+  const std::unordered_set<int>& NoDataTransformVars() const;
 
-  OpKernelComputeFunc KernelFunc() const { return op_func_node_.kernel_func_; }
+  OpKernelComputeFunc KernelFunc() const;
 
-  OpFuncType KernelType() const { return op_func_node_.type_; }
+  OpFuncType KernelType() const;
 
-  OperatorBase* OpBase() const {
-    auto* op_base = op_func_node_.operator_base_;
-    PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
-                                         "op_base shall not be nullptr."));
-    return op_base;
-  }
+  OperatorBase* OpBase() const;
 
-  NextInstruction& NextInstructions() { return next_instruction_; }
+  NextInstruction& NextInstructions();
 
-  const NextInstruction& NextInstructions() const { return next_instruction_; }
+  const NextInstruction& NextInstructions() const;
 
-  void AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
+  void AddGCCheckVar(size_t id);
 
-  const std::vector<size_t>& GCCheckVars() const { return gc_check_var_list_; }
+  const std::vector<size_t>& GCCheckVars() const;
 
   void ResetContext(const VariableValueMap& in_vars,
-                    const VariableValueMap& out_vars) {
-    runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
-    infershape_ctx_.reset(
-        new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
-    // NOTE: Because execution_ctx_ is constructed by `scope&`, we fake an
-    // empty scope here to avoid an illegal local reference.
-    static framework::Scope scope_;
-    execution_ctx_.reset(
-        new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
-  }
+                    const VariableValueMap& out_vars);
 
-  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const {
-    return runtime_ctx_;
-  }
+  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
 
   std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
-      const {
-    return infershape_ctx_;
-  }
+      const;
 
-  std::shared_ptr<ExecutionContext> InnerExecutionContext() const {
-    return execution_ctx_;
-  }
+  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
 
-  const platform::DeviceContext& DeviceContext() const { return dev_ctx_; }
+  const platform::DeviceContext& DeviceContext() const;
 
-  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const {
-    return vec_inplace_in_to_out_;
-  }
+  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
 
-  void AddInplace(Variable* in, Variable* out) {
-    vec_inplace_in_to_out_.emplace_back(in, out);
-  }
+  void AddInplace(Variable* in, Variable* out);
 
-  const std::vector<EventInter>& InputEvents() const { return intput_events_; }
+  const std::vector<EventInter>& InputEvents() const;
 
-  const std::vector<EventInter>& OutputEvents() const { return output_events_; }
+  const std::vector<EventInter>& OutputEvents() const;
 
   void AddInputEvent(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
-                     platform::DeviceType waiter_type) {
-    intput_events_.emplace_back(var_id, event, waiter_type);
-  }
+                     platform::DeviceType waiter_type);
 
   void AddOutputEvent(size_t var_id,
                       std::shared_ptr<platform::DeviceEvent> event,
-                      platform::DeviceType waiter_type) {
-    output_events_.emplace_back(var_id, event, waiter_type);
-  }
+                      platform::DeviceType waiter_type);
 
  private:
   size_t id_;
-  const OpFuncNode& op_func_node_;  // not owned
+  OpFuncNode op_func_node_;
   const platform::DeviceContext& dev_ctx_;  // not owned
 
   std::shared_ptr<RuntimeContext> runtime_ctx_;
@@ -403,6 +366,11 @@ static bool IsMemcpyH2D(const Instruction& instr) {
 static bool IsMemcpyD2H(const Instruction& instr) {
   return instr.OpBase()->Type() == kMemcpyD2H;
 }
+
+static bool IsCpuOp(const Instruction& instr) {
+  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
+}
+
 }  // namespace interpreter
 }  // namespace framework
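The event plumbing declared here (AddInputEvent/AddOutputEvent, used by RecordEvent in event_manager.cc above) is the classic record-on-producer, wait-on-consumer idiom. For orientation, a rough equivalent in the plain CUDA runtime API, rather than Paddle's platform::DeviceEvent wrapper (a sketch under that analogy, not code from this PR):

#include <cuda_runtime.h>

// The producer records an event on its stream after the writing kernel;
// the consumer makes its stream wait on that event before reading.
void CrossStreamSync(cudaStream_t producer, cudaStream_t consumer) {
  cudaEvent_t done;
  cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
  // ... launch the producing kernel on `producer` ...
  cudaEventRecord(done, producer);         // cf. Instruction's OutputEvents
  cudaStreamWaitEvent(consumer, done, 0);  // cf. WaitEvent on the input side
  // ... launch the consuming kernel on `consumer` ...
  cudaEventDestroy(done);
}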
diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.cc b/paddle/fluid/framework/new_executor/stream_analyzer.cc
index 23b61dd3d5ee7b7f84e12b8ad32b56604b2944e5..3ac30456a2d983c44dcb6d5067ccffece0b460aa 100644
--- a/paddle/fluid/framework/new_executor/stream_analyzer.cc
+++ b/paddle/fluid/framework/new_executor/stream_analyzer.cc
@@ -27,33 +27,60 @@ namespace framework {
  *
  * For example: matmul(gpu) -> out_var -> memcpy_d2h
  * out_var should be associated with an event.
+ *
+ * NOTE(zhiqiu): There are two special cases in which no event is needed:
+ * 1. the variable is marked as NoDataTransformVar
+ * 2. the variable is marked as NoNeedBufferVar
 */
-std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
+std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
     const Instruction& cur_instr, const Instruction& next_instr) {
   std::unordered_set<size_t> unique_var_ids;
   for (auto& item : cur_instr.Outputs()) {
     unique_var_ids.insert(item.second.begin(), item.second.end());
   }
 
-  std::vector<size_t> new_event_var_ids;
+  auto is_no_need_buffer = [&next_instr](std::string name) {
+    auto* op = next_instr.OpBase();
+    auto& inferer = op->Info().NoNeedBufferVarsInferer();
+    if (inferer) {
+      auto no_need_buffer_ins =
+          inferer(op->Inputs(), op->Outputs(), op->Attrs());
+      return no_need_buffer_ins.count(name) != 0;
+    }
+    return false;
+  };
+
+  std::vector<size_t> need_event_var_ids;
   for (auto& item : next_instr.Inputs()) {
     for (auto var_id : item.second) {
-      if (unique_var_ids.count(var_id) > 0 &&
-          next_instr.NoDataTransformVars().count(var_id) == 0) {
-        new_event_var_ids.push_back(var_id);
+      if (unique_var_ids.count(var_id) > 0) {
+        if (next_instr.NoDataTransformVars().count(var_id)) {
+          VLOG(4) << "Skip inserting event at variable " << item.first
+                  << " of operator " << next_instr.OpBase()->Type()
+                  << " since it is NoDataTransform";
+          continue;
+        }
+        if (is_no_need_buffer(item.first)) {
+          VLOG(4) << "Skip inserting event at variable " << item.first
+                  << " of operator " << next_instr.OpBase()->Type()
+                  << " since it is NoNeedBufferVar";
+          continue;
+        }
+
+        need_event_var_ids.push_back(var_id);
       }
     }
   }
-  return new_event_var_ids;
+  return need_event_var_ids;
 }
 
-void StreamAnalyzer::AssociateInputWithEvents(
+void StreamAnalyzer::ConstructEventForVar(
     const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
-    platform::DeviceType waiter_type) {
+    platform::DeviceType waiter_type, const platform::Place& place) {
   for (auto var_id : new_event_var_id) {
     if (var_id2event_.count(var_id) == 0) {
       auto device_event = std::make_shared<platform::DeviceEvent>(
-          place_, platform::GenerateDeviceEventFlag());
+          place, platform::GenerateDeviceEventFlag());
       var_id2event_.emplace(var_id, std::move(device_event));
     }
     // Add events for next_instr.inputs
@@ -69,21 +96,27 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
   std::vector<size_t> event_var_ids;
   for (auto next_op_id : downstream_ops) {
     auto& next_instr = instructions->at(next_op_id);
 
     if (IsDirectRun(cur_instr, next_instr)) {
+      VLOG(4) << "DirectRun: " << cur_instr.OpBase()->Type() << "->"
+              << next_instr.OpBase()->Type();
       next_instruction.AddDirectRun(next_op_id);
     } else {
       // Always insert events between different streams
-      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
-      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
-                           new_event_var_ids.end());
+      auto need_event_var_ids = GetNeedEventVarIds(cur_instr, next_instr);
+      event_var_ids.insert(event_var_ids.end(), need_event_var_ids.begin(),
+                           need_event_var_ids.end());
 
       auto waiter_type = GetWaiterType(next_instr);
-      AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
+      ConstructEventForVar(need_event_var_ids, &next_instr, waiter_type,
+                           cur_instr.DeviceContext().GetPlace());
 
       if (waiter_type == platform::kCPU) {  // GPU -> CPU
+        VLOG(4) << "SyncRun: " << cur_instr.OpBase()->Type() << "->"
+                << next_instr.OpBase()->Type();
         next_instruction.AddSyncRun(next_op_id);
       } else {  // GPU -> GPU(different stream)
+        VLOG(4) << "EventRun: " << cur_instr.OpBase()->Type() << "->"
+                << next_instr.OpBase()->Type();
         next_instruction.ADDEventRun(next_op_id);
       }
     }
@@ -116,12 +149,15 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
  * NOTE(dev): The following cases are considered as directly run:
  *
  * 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
- * 2. D2H -> CPU
- * 3. CPU -> H2D
+ * 2. CPU -> any (CPU op -> VAR -> GPU op is possible when the var needs
+ *    no buffer or no data transform)
+ * 3. D2H -> CPU
+ * 4. CPU -> H2D
 */
 bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
                                  const Instruction& next_instr) {
   return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
+          interpreter::IsCpuOp(cur_instr) ||
          interpreter::IsMemcpyD2H(cur_instr) ||
          interpreter::IsMemcpyH2D(next_instr));
 }
diff --git a/paddle/fluid/framework/new_executor/stream_analyzer.h b/paddle/fluid/framework/new_executor/stream_analyzer.h
index df74c9b933712ff557a91ef4ccc42227e9bbbf64..2a276c6f5097a0c8bc21779ae477208a449330f6 100644
--- a/paddle/fluid/framework/new_executor/stream_analyzer.h
+++ b/paddle/fluid/framework/new_executor/stream_analyzer.h
@@ -35,12 +35,13 @@ class StreamAnalyzer {
   platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);
 
  private:
-  std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
-                                       const Instruction& next_instr);
+  std::vector<size_t> GetNeedEventVarIds(const Instruction& cur_instr,
+                                         const Instruction& next_instr);
 
-  void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
-                                Instruction* next_instr,
-                                platform::DeviceType waiter_type);
+  void ConstructEventForVar(const std::vector<size_t>& new_event_var_id,
+                            Instruction* next_instr,
+                            platform::DeviceType waiter_type,
+                            const platform::Place& place);
 
   bool IsDirectRun(Instruction& cur_instr,  // NOLINT
                    const Instruction& next_instr);
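Summarizing the stream_analyzer changes: an edge between two instructions runs directly when any of the four comment cases holds; otherwise an event is constructed for each needed output variable, and the consumer either blocks the CPU (kCPU waiter) or waits on the event from its own stream. A standalone restatement with hypothetical types (not Paddle code):

enum class RunMode { kDirect, kSync, kEvent };

struct EdgeView {
  bool same_dev_ctx;        // case 1: identical DeviceContext
  bool cur_is_cpu_op;       // case 2: CPU -> any (new in this diff)
  bool cur_is_memcpy_d2h;   // case 3: D2H -> CPU
  bool next_is_memcpy_h2d;  // case 4: CPU -> H2D
  bool next_waits_on_cpu;   // waiter_type == platform::kCPU
};

RunMode Decide(const EdgeView& e) {
  if (e.same_dev_ctx || e.cur_is_cpu_op || e.cur_is_memcpy_d2h ||
      e.next_is_memcpy_h2d) {
    return RunMode::kDirect;                     // no event needed
  }
  return e.next_waits_on_cpu ? RunMode::kSync    // GPU -> CPU
                             : RunMode::kEvent;  // GPU -> GPU, diff stream
}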