未验证 提交 584b4b24 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] fix stream analysis (#37161)

* fix revord_event

* refine class Instruction

* refine Instruction and InterpreterCore

* make instruction and operator_base consistent

* support NoNeedBufferVar in stream_analyzer

* fix place of event

* add vlog before continue
上级 9c591703
......@@ -41,6 +41,16 @@ void RecordEvent(const Instruction& instruction, const platform::Place& place) {
}
}
void RecordEvent(const Instruction& instruction) {
// If InterpreterCore in on CPUPlace, do nothing.
if (platform::is_cpu_place(instruction.DeviceContext().GetPlace())) return;
for (auto& event : instruction.OutputEvents()) {
VLOG(3) << "Record event in out_var_id: " << event.var_id_;
event.event_->Record(&instruction.DeviceContext());
}
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -20,6 +20,8 @@ namespace framework {
namespace interpreter {
void RecordEvent(const Instruction& instruction, const platform::Place& place);
void RecordEvent(const Instruction& instruction);
void WaitEvent(const Instruction& instruction, const platform::Place& place);
} // namespace interpreter
......
......@@ -77,20 +77,22 @@ paddle::framework::FetchList InterpreterCore::Run(
return *(fetch_var->GetMutable<framework::FetchList>());
}
void InterpreterCore::Convert() {
void InterpreterCore::Convert(
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
auto var_nums = global_scope_->VarSize();
input_var2op_info_.resize(var_nums);
auto nodes = *op_func_nodes;
auto op_nums = vec_func_list_.size();
auto op_nums = nodes.size();
vec_instruction_.reserve(op_nums);
dependecy_count_.resize(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = vec_func_list_[op_idx];
auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, op_func_node, *dev_ctx_);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
auto& instr = vec_instruction_.back();
OpInOutInfo info;
......@@ -104,7 +106,7 @@ void InterpreterCore::Convert() {
input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed
if (!info.IsBuilt()) {
info.Build(op_func_node.operator_base_);
info.Build(op_func_node.operator_base_.get());
}
auto* var_desc = global_scope_->VarDesc(id);
if (var_desc) {
......@@ -522,11 +524,12 @@ void InterpreterCore::Prepare(
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
place_, block_, &vec_func_list_, global_scope_);
place_, block_, &op_func_nodes, global_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert();
Convert(&op_func_nodes);
}
// NOTE: Because feed_tensor will be GC after
// paddle::framework::build_op_func_list, so we should
......
......@@ -54,7 +54,7 @@ class InterpreterCore {
const std::vector<framework::LoDTensor>& feed_tensors);
private:
void Convert();
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
void BuildAndCacheInstructionCtx(Instruction* instr_node);
......@@ -84,7 +84,6 @@ class InterpreterCore {
const BlockDesc& block_; // not owned
VariableScope* global_scope_; // not owned
std::vector<paddle::framework::OpFuncNode> vec_func_list_;
std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
std::vector<size_t> dependecy_count_;
......
......@@ -64,11 +64,12 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
std::unordered_map<const paddle::framework::OperatorBase*,
std::vector<std::string>>
get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
get_unused_vars(const BlockDesc& block,
const std::vector<std::shared_ptr<OperatorBase>>& ops) {
std::unordered_map<std::string, size_t> var_op_idx_map;
for (size_t i = 0; i < ops.size(); ++i) {
auto* op = ops[i];
const auto& op = ops[i];
OpInOutInfo info;
for (auto& name_pair : op->Inputs()) {
......@@ -79,7 +80,7 @@ get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
// var can be gc-ed
if (!info.IsBuilt()) {
info.Build(op);
info.Build(op.get());
}
if (info.IsInArgBufferNeeded(name)) {
......@@ -107,7 +108,8 @@ get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
for (auto& name_op_idx_pair : var_op_idx_map) {
auto& name = name_op_idx_pair.first;
size_t op_idx = name_op_idx_pair.second;
result[ops[op_idx]].emplace_back(name);
result[ops[op_idx].get()].emplace_back(name);
}
return result;
}
......@@ -150,8 +152,8 @@ void build_variable_scope(const framework::BlockDesc& block,
}
}
std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
std::vector<OperatorBase*> ops;
void create_all_ops(const framework::BlockDesc& block,
std::vector<std::shared_ptr<OperatorBase>>* ops) {
for (auto& op : block.AllOps()) {
VLOG(3) << "CreateOp from : " << op->Type();
......@@ -166,9 +168,8 @@ std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
}
auto op_base =
info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
ops.push_back(op_base);
ops->emplace_back(std::shared_ptr<OperatorBase>(op_base));
}
return ops;
}
std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
......@@ -234,7 +235,8 @@ void apply_device_guard(const OperatorBase* op_base,
}
void deal_operator_base(const platform::Place& place,
const VariableScope* var_scope, OperatorBase* op_base,
const VariableScope* var_scope,
std::shared_ptr<OperatorBase> op_base,
OpFuncNode* op_func_node) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
......@@ -321,8 +323,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
var_name, kernel_type_for_var.place_, new_var_name,
expected_kernel_key.place_);
auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type);
auto copy_op =
copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map);
auto copy_op = std::shared_ptr<OperatorBase>(
copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map));
OpFuncNode copy_op_func_node;
copy_op_func_node.input_index = copy_ins_name2id;
copy_op_func_node.output_index = copy_out_name2id;
......@@ -330,10 +333,10 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
RuntimeContext copy_runtime_context({}, {});
copy_runtime_context.inputs.swap(copy_ins_value_map);
copy_runtime_context.outputs.swap(copy_outs_value_map);
InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op,
InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op.get(),
copy_runtime_context);
static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape(
&copy_infer_shape_ctx);
static_cast<const framework::OperatorWithKernel*>(copy_op.get())
->InferShape(&copy_infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(memcpy_op_type);
PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
......@@ -347,7 +350,7 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
auto copy_exec_ctx =
ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context);
auto copy_expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(copy_op)
dynamic_cast<const framework::OperatorWithKernel*>(copy_op.get())
->GetExpectedKernelType(copy_exec_ctx);
auto kernel_iter = kernels.find(copy_expected_kernel_key);
copy_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
......@@ -362,19 +365,20 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
return std::make_pair(new_var_name, copy_op_func_node);
}
std::vector<OpFuncNode> apply_data_transform(
const OpKernelType& expected_kernel_key, const platform::Place& place,
VariableValueMap* ins_map_temp, VariableScope* var_scope,
OpFuncNode* op_func_node) {
auto& op_base = op_func_node->operator_base_;
void apply_data_transform(const OpKernelType& expected_kernel_key,
const platform::Place& place,
VariableValueMap* ins_map_temp,
VariableScope* var_scope, OpFuncNode* op_func_node,
std::vector<OpFuncNode>* copy_func_nodes) {
auto op_base = op_func_node->operator_base_.get();
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base is null, please pass a valid "
"op_base in apply_data_transform."));
auto inputs_names = op_base->Inputs();
VariableNameMap new_ins(op_base->Inputs());
std::unordered_set<int>
no_data_transform_index; // record the no need transform variable index.
std::vector<OpFuncNode> copy_func_nodes; // return all the copy opfuncnode.
for (auto& var_name_item : *ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
......@@ -382,7 +386,7 @@ std::vector<OpFuncNode> apply_data_transform(
if (!(var->IsType<LoDTensor>() || var->IsType<SelectedRows>())) {
continue;
}
auto& var_name = inputs_names[var_name_item.first].at(i);
auto& var_name = new_ins[var_name_item.first].at(i);
auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
if (!tensor_in->IsInitialized()) {
continue;
......@@ -404,8 +408,9 @@ std::vector<OpFuncNode> apply_data_transform(
var_name_item.first, *op_func_node, var, var_scope);
op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
copy_func_nodes.push_back(copy_op_func_node);
copy_func_nodes->emplace_back(copy_op_func_node);
var_name_item.second[i] = var_scope->Var(new_var_name);
new_ins[var_name_item.first][i] = new_var_name;
} else if (need_dtype_transform_for_var(kernel_type_for_var,
expected_kernel_key)) {
// TODO(@xiongkun) add dtype judgement here
......@@ -421,8 +426,14 @@ std::vector<OpFuncNode> apply_data_transform(
}
}
}
// NOTE(zhiqiu): UPDATE the corresponding OeratorBase to make it consistent
// with instruction
// hot fix, it is not good design here
op_func_node->operator_base_ =
std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
op_base->Type(), new_ins, op_base->Outputs(), op_base->Attrs()));
op_func_node->no_data_transform_index = std::move(no_data_transform_index);
return copy_func_nodes;
}
void build_op_func_list(const platform::Place& place,
......@@ -430,16 +441,16 @@ void build_op_func_list(const platform::Place& place,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<std::shared_ptr<OperatorBase>>
ops; // its elements will be moved to vec_func_list
// Step 1: create all ops for current block.
auto ops = create_all_ops(block);
create_all_ops(block, &ops);
auto unused_var_map = get_unused_vars(block, ops);
size_t ops_index = 0;
for (auto& op : block.AllOps()) {
for (size_t i = 0; i < ops.size(); ++i) {
auto op = ops[i].get();
VLOG(6) << "Build OpFuncNode from : " << op->Type();
auto op_base = ops[ops_index++];
auto inputs_names = op->Inputs();
auto outputs_names = op->Outputs();
......@@ -466,20 +477,18 @@ void build_op_func_list(const platform::Place& place,
op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id;
if (dynamic_cast<const framework::OperatorWithKernel*>(op_base) ==
nullptr) {
if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
// op is not a operatorwithkernel, so direcly run OperatorBase::Run()
deal_operator_base(place, var_scope, op_base, &op_func_node);
deal_operator_base(place, var_scope, ops[i], &op_func_node);
} else {
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
InterpretercoreInferShapeContext infer_shape_ctx(*op_base,
runtime_context);
InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT inheritted
// from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE(
......@@ -495,13 +504,13 @@ void build_op_func_list(const platform::Place& place,
auto* dev_ctx = pool.Get(place);
Scope scope;
auto expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(op_base)
dynamic_cast<const framework::OperatorWithKernel*>(op)
->GetExpectedKernelType(
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
ExecutionContext(*op, scope, *dev_ctx, runtime_context));
// consider device_guard()
apply_device_guard(
op_base, place,
op, place,
&expected_kernel_key); // change device by the device_guard()
VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
......@@ -510,14 +519,14 @@ void build_op_func_list(const platform::Place& place,
std::vector<OpFuncNode> copy_op_to_insert;
// NOTE(xiongkun03): assign op_base here to reduce parameter number of
// apply_data_transform.
op_func_node.operator_base_ = op_base;
copy_op_to_insert = apply_data_transform(
expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node);
op_func_node.operator_base_ = ops[i];
apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope,
&op_func_node, &copy_op_to_insert);
for (auto& item : copy_op_to_insert) {
vec_func_list->push_back(item);
}
// step 4. Run op kernel
VLOG(3) << op_base->Type()
VLOG(3) << op->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
if (platform::is_gpu_place(expected_kernel_key.place_)) {
......@@ -533,8 +542,7 @@ void build_op_func_list(const platform::Place& place,
}
op_func_node.dev_ctx_ = dev_ctx;
auto exec_ctx =
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context);
auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
auto kernel_iter = kernels.find(expected_kernel_key);
PADDLE_ENFORCE_NE(
......@@ -547,9 +555,9 @@ void build_op_func_list(const platform::Place& place,
op_func_node.kernel_func_(exec_ctx);
}
vec_func_list->push_back(op_func_node);
vec_func_list->emplace_back(op_func_node);
// gc---------------------------------------------------------------------------
auto iter = unused_var_map.find(op_base);
auto iter = unused_var_map.find(op);
if (iter == unused_var_map.end()) {
continue;
}
......@@ -586,7 +594,7 @@ void build_op_func_list(const platform::Place& place,
delete garbages; // free mem
VLOG(3) << "run " << op_base->Type() << " done.";
VLOG(3) << "run " << op->Type() << " done.";
}
}
......
......@@ -629,5 +629,109 @@ void VariableScopeListener::onCreateScope(Scope* Scope) {}
void VariableScopeListener::onDeleteScope(Scope* Scope) {}
void VariableScopeListener::onClear() {}
Instruction::Instruction(size_t id, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx)
: id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
"Required id >= 0, but received id = %d", id));
}
size_t Instruction::Id() const { return id_; }
const std::map<std::string, std::vector<int>>& Instruction::Inputs() const {
return op_func_node_.input_index;
}
const std::map<std::string, std::vector<int>>& Instruction::Outputs() const {
return op_func_node_.output_index;
}
const std::unordered_set<int>& Instruction::NoDataTransformVars() const {
return op_func_node_.no_data_transform_index;
}
OpKernelComputeFunc Instruction::KernelFunc() const {
return op_func_node_.kernel_func_;
}
OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
OperatorBase* Instruction::OpBase() const {
auto op_base = op_func_node_.operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base shall not be nullptr."));
return op_base.get();
}
NextInstruction& Instruction::NextInstructions() { return next_instruction_; }
const NextInstruction& Instruction::NextInstructions() const {
return next_instruction_;
}
void Instruction::AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
const std::vector<size_t>& Instruction::GCCheckVars() const {
return gc_check_var_list_;
}
void Instruction::ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
// NOTE: Because execution_ctx_ is constructed by `scope&`, so we fake an
// empty here to avoid illegal local reference.
static framework::Scope scope_;
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
}
std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
return runtime_ctx_;
}
std::shared_ptr<InterpretercoreInferShapeContext>
Instruction::InnerInferShapeContext() const {
return infershape_ctx_;
}
std::shared_ptr<ExecutionContext> Instruction::InnerExecutionContext() const {
return execution_ctx_;
}
const platform::DeviceContext& Instruction::DeviceContext() const {
return dev_ctx_;
}
const std::vector<std::pair<Variable*, Variable*>>& Instruction::InplaceInfo()
const {
return vec_inplace_in_to_out_;
}
void Instruction::AddInplace(Variable* in, Variable* out) {
vec_inplace_in_to_out_.emplace_back(in, out);
}
const std::vector<EventInter>& Instruction::InputEvents() const {
return intput_events_;
}
const std::vector<EventInter>& Instruction::OutputEvents() const {
return output_events_;
}
void Instruction::AddInputEvent(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
intput_events_.emplace_back(var_id, event, waiter_type);
}
void Instruction::AddOutputEvent(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
output_events_.emplace_back(var_id, event, waiter_type);
}
} // namespace framework
} // namespace paddle
......@@ -271,7 +271,8 @@ enum class OpFuncType {
class RuntimeInferShapeContext;
struct OpFuncNode {
OperatorBase* operator_base_;
// TODO(zhiqiu): Better make it unique_ptr
std::shared_ptr<OperatorBase> operator_base_;
std::map<std::string, std::vector<int>> input_index;
std::map<std::string, std::vector<int>> output_index;
std::unordered_set<int> no_data_transform_index;
......@@ -283,100 +284,62 @@ struct OpFuncNode {
class Instruction {
public:
Instruction(size_t id, const OpFuncNode& op_func_node,
const platform::DeviceContext& dev_ctx)
: id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
"Required id >= 0, but received id = %d", id));
}
Instruction(size_t id, OpFuncNode&& op_func_node,
const platform::DeviceContext& dev_ctx);
size_t Id() const { return id_; }
size_t Id() const;
const std::map<std::string, std::vector<int>>& Inputs() const {
return op_func_node_.input_index;
}
const std::map<std::string, std::vector<int>>& Inputs() const;
const std::map<std::string, std::vector<int>>& Outputs() const {
return op_func_node_.output_index;
}
const std::map<std::string, std::vector<int>>& Outputs() const;
const std::unordered_set<int>& NoDataTransformVars() const {
return op_func_node_.no_data_transform_index;
}
const std::unordered_set<int>& NoDataTransformVars() const;
OpKernelComputeFunc KernelFunc() const { return op_func_node_.kernel_func_; }
OpKernelComputeFunc KernelFunc() const;
OpFuncType KernelType() const { return op_func_node_.type_; }
OpFuncType KernelType() const;
OperatorBase* OpBase() const {
auto* op_base = op_func_node_.operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base shall not be nullptr."));
return op_base;
}
OperatorBase* OpBase() const;
NextInstruction& NextInstructions() { return next_instruction_; }
NextInstruction& NextInstructions();
const NextInstruction& NextInstructions() const { return next_instruction_; }
const NextInstruction& NextInstructions() const;
void AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
void AddGCCheckVar(size_t id);
const std::vector<size_t>& GCCheckVars() const { return gc_check_var_list_; }
const std::vector<size_t>& GCCheckVars() const;
void ResetContext(const VariableValueMap& in_vars,
const VariableValueMap& out_vars) {
runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
infershape_ctx_.reset(
new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
// NOTE: Because execution_ctx_ is constructed by `scope&`, so we fake an
// empty here to avoid illegal local reference.
static framework::Scope scope_;
execution_ctx_.reset(
new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
}
const VariableValueMap& out_vars);
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const {
return runtime_ctx_;
}
std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
const {
return infershape_ctx_;
}
const;
std::shared_ptr<ExecutionContext> InnerExecutionContext() const {
return execution_ctx_;
}
std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
const platform::DeviceContext& DeviceContext() const { return dev_ctx_; }
const platform::DeviceContext& DeviceContext() const;
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const {
return vec_inplace_in_to_out_;
}
const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
void AddInplace(Variable* in, Variable* out) {
vec_inplace_in_to_out_.emplace_back(in, out);
}
void AddInplace(Variable* in, Variable* out);
const std::vector<EventInter>& InputEvents() const { return intput_events_; }
const std::vector<EventInter>& InputEvents() const;
const std::vector<EventInter>& OutputEvents() const { return output_events_; }
const std::vector<EventInter>& OutputEvents() const;
void AddInputEvent(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
intput_events_.emplace_back(var_id, event, waiter_type);
}
platform::DeviceType waiter_type);
void AddOutputEvent(size_t var_id,
std::shared_ptr<platform::DeviceEvent> event,
platform::DeviceType waiter_type) {
output_events_.emplace_back(var_id, event, waiter_type);
}
platform::DeviceType waiter_type);
private:
size_t id_;
const OpFuncNode& op_func_node_; // not owned
OpFuncNode op_func_node_;
const platform::DeviceContext& dev_ctx_; // not owned
std::shared_ptr<RuntimeContext> runtime_ctx_;
......@@ -403,6 +366,11 @@ static bool IsMemcpyH2D(const Instruction& instr) {
static bool IsMemcpyD2H(const Instruction& instr) {
return instr.OpBase()->Type() == kMemcpyD2H;
}
static bool IsCpuOp(const Instruction& instr) {
return platform::is_cpu_place(instr.DeviceContext().GetPlace());
}
} // namespace interpreter
} // namespace framework
......
......@@ -27,33 +27,60 @@ namespace framework {
*
* For example: matmul(gpu) -> out_var -> memcpy_d2h
* out_var should be associated with an event.
*
* NOTE(zhiqiu): There are two special case that no event is needed:
* 1. the variable is marked as NoDataTransformVar
* 2. the variable is marked as NoNeedDataBuffer
*/
std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
const Instruction& cur_instr, const Instruction& next_instr) {
std::unordered_set<size_t> unique_var_ids;
for (auto& item : cur_instr.Outputs()) {
unique_var_ids.insert(item.second.begin(), item.second.end());
}
std::vector<size_t> new_event_var_ids;
auto is_no_need_buffer = [&next_instr](std::string name) {
auto* op = next_instr.OpBase();
auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (inferer) {
auto no_need_buffer_ins =
inferer(op->Inputs(), op->Outputs(), op->Attrs());
return no_need_buffer_ins.count(name) != 0;
}
return false;
};
std::vector<size_t> need_event_var_ids;
for (auto& item : next_instr.Inputs()) {
for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0 &&
next_instr.NoDataTransformVars().count(var_id) == 0) {
new_event_var_ids.push_back(var_id);
if (unique_var_ids.count(var_id) > 0) {
if (next_instr.NoDataTransformVars().count(var_id)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
<< " since it is NoDataTransform";
continue;
}
if (is_no_need_buffer(item.first)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
<< " since it is NoNeedBufferVar";
continue;
}
need_event_var_ids.push_back(var_id);
}
}
}
return new_event_var_ids;
return need_event_var_ids;
}
void StreamAnalyzer::AssociateInputWithEvents(
void StreamAnalyzer::ConstructEventForVar(
const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
platform::DeviceType waiter_type) {
platform::DeviceType waiter_type, const platform::Place& place) {
for (auto var_id : new_event_var_id) {
if (var_id2event_.count(var_id) == 0) {
auto device_event = std::make_shared<platform::DeviceEvent>(
place_, platform::GenerateDeviceEventFlag());
place, platform::GenerateDeviceEventFlag());
var_id2event_.emplace(var_id, std::move(device_event));
}
// Add events for next_instr.inputs
......@@ -69,21 +96,27 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
std::vector<size_t> event_var_ids;
for (auto next_op_id : downstream_ops) {
auto& next_instr = instructions->at(next_op_id);
if (IsDirectRun(cur_instr, next_instr)) {
VLOG(4) << "DirectRun: " << cur_instr.OpBase()->Type() << "->"
<< next_instr.OpBase()->Type();
next_instruction.AddDirectRun(next_op_id);
} else {
// Always insert events between different stream
auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
new_event_var_ids.end());
auto need_event_var_ids = GetNeedEventVarIds(cur_instr, next_instr);
event_var_ids.insert(event_var_ids.end(), need_event_var_ids.begin(),
need_event_var_ids.end());
auto waiter_type = GetWaiterType(next_instr);
AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
ConstructEventForVar(need_event_var_ids, &next_instr, waiter_type,
cur_instr.DeviceContext().GetPlace());
if (waiter_type == platform::kCPU) { // GPU -> CPU
VLOG(4) << "SyncRun: " << cur_instr.OpBase()->Type() << "->"
<< next_instr.OpBase()->Type();
next_instruction.AddSyncRun(next_op_id);
} else { // GPU -> GPU(different stream)
VLOG(4) << "EventRun: " << cur_instr.OpBase()->Type() << "->"
<< next_instr.OpBase()->Type();
next_instruction.ADDEventRun(next_op_id);
}
}
......@@ -116,12 +149,15 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
* NOTE(dev): The following cases are considered as directly run:
*
* 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
* 2. D2H -> CPU
* 3. CPU -> H2D
* 2. CPU -> any (it is possible: CPU op->VAR->GPU op, when var is no need
* buffer or no need data transform)
* 3. D2H -> CPU
* 4. CPU -> H2D
*/
bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
const Instruction& next_instr) {
return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
interpreter::IsCpuOp(cur_instr) ||
interpreter::IsMemcpyD2H(cur_instr) ||
interpreter::IsMemcpyH2D(next_instr));
}
......
......@@ -35,12 +35,13 @@ class StreamAnalyzer {
platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);
private:
std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
const Instruction& next_instr);
std::vector<size_t> GetNeedEventVarIds(const Instruction& cur_instr,
const Instruction& next_instr);
void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
Instruction* next_instr,
platform::DeviceType waiter_type);
void ConstructEventForVar(const std::vector<size_t>& new_event_var_id,
Instruction* next_instr,
platform::DeviceType waiter_type,
const platform::Place& place);
bool IsDirectRun(Instruction& cur_instr, // NOLINT
const Instruction& next_instr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册