Unverified commit 584b4b24, authored by Leo Chen, committed by GitHub

[new-exec] fix stream analysis (#37161)

* fix record_event

* refine class Instruction

* refine Instruction and InterpreterCore

* make instruction and operator_base consistent

* support NoNeedBufferVar in stream_analyzer

* fix place of event

* add vlog before continue
Parent 9c591703
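Note for readers: the pattern this commit tunes is event-based cross-stream synchronization. The producing instruction records an event on each output variable, and a consuming instruction on another stream waits on that event before reading the variable. Below is a minimal standalone model of the record/wait protocol using plain C++ threads; the Event class and the producer/consumer roles are illustrative stand-ins, not Paddle's platform::DeviceEvent API.

#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

class Event {
 public:
  void Record() {  // producer side: output is ready
    std::lock_guard<std::mutex> guard(mu_);
    recorded_ = true;
    cv_.notify_all();
  }
  void Wait() {  // consumer side: block until the producer recorded
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return recorded_; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool recorded_ = false;
};

int main() {
  Event out_var_event;
  int out_var = 0;
  std::thread producer([&] {  // plays the role of matmul(gpu)
    out_var = 42;
    out_var_event.Record();  // RecordEvent on the output variable
  });
  std::thread consumer([&] {  // plays the role of memcpy_d2h
    out_var_event.Wait();  // WaitEvent on the input variable
    std::cout << "consumed " << out_var << "\n";
  });
  producer.join();
  consumer.join();
  return 0;
}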
@@ -41,6 +41,16 @@ void RecordEvent(const Instruction& instruction, const platform::Place& place) {
   }
 }
 
+void RecordEvent(const Instruction& instruction) {
+  // If InterpreterCore is on CPUPlace, do nothing.
+  if (platform::is_cpu_place(instruction.DeviceContext().GetPlace())) return;
+
+  for (auto& event : instruction.OutputEvents()) {
+    VLOG(3) << "Record event in out_var_id: " << event.var_id_;
+    event.event_->Record(&instruction.DeviceContext());
+  }
+}
+
 }  // namespace interpreter
 }  // namespace framework
 }  // namespace paddle
@@ -20,6 +20,8 @@ namespace framework {
 namespace interpreter {
 
 void RecordEvent(const Instruction& instruction, const platform::Place& place);
 
+void RecordEvent(const Instruction& instruction);
+
 void WaitEvent(const Instruction& instruction, const platform::Place& place);
 
 }  // namespace interpreter
......
@@ -77,20 +77,22 @@ paddle::framework::FetchList InterpreterCore::Run(
   return *(fetch_var->GetMutable<framework::FetchList>());
 }
 
-void InterpreterCore::Convert() {
+void InterpreterCore::Convert(
+    std::vector<paddle::framework::OpFuncNode>* op_func_nodes) {
   auto& vec_meta_info = global_scope_->MutableVecMetaInfo();
   auto var_nums = global_scope_->VarSize();
   input_var2op_info_.resize(var_nums);
+  auto nodes = *op_func_nodes;
 
-  auto op_nums = vec_func_list_.size();
+  auto op_nums = nodes.size();
   vec_instruction_.reserve(op_nums);
   dependecy_count_.resize(op_nums);
 
   for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
-    auto& op_func_node = vec_func_list_[op_idx];
+    auto& op_func_node = nodes[op_idx];
     auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
-    vec_instruction_.emplace_back(op_idx, op_func_node, *dev_ctx_);
+    vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
     auto& instr = vec_instruction_.back();
 
     OpInOutInfo info;
@@ -104,7 +106,7 @@ void InterpreterCore::Convert() {
       input_var2op_info_.at(id).push_back(op_idx);
       // var can be gc-ed
       if (!info.IsBuilt()) {
-        info.Build(op_func_node.operator_base_);
+        info.Build(op_func_node.operator_base_.get());
       }
       auto* var_desc = global_scope_->VarDesc(id);
       if (var_desc) {
@@ -522,11 +524,12 @@ void InterpreterCore::Prepare(
   if (!is_build_) {
     paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
     FeedInput();
+    std::vector<paddle::framework::OpFuncNode> op_func_nodes;
     paddle::framework::interpreter::build_op_func_list(
-        place_, block_, &vec_func_list_, global_scope_);
+        place_, block_, &op_func_nodes, global_scope_);
     is_build_ = true;
     // convert vec func_list to graph
-    Convert();
+    Convert(&op_func_nodes);
   }
   // NOTE: Because feed_tensor will be GC after
   // paddle::framework::build_op_func_list, so we should
......
@@ -54,7 +54,7 @@ class InterpreterCore {
       const std::vector<framework::LoDTensor>& feed_tensors);
 
  private:
-  void Convert();
+  void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
 
   void BuildAndCacheInstructionCtx(Instruction* instr_node);
@@ -84,7 +84,6 @@ class InterpreterCore {
   const BlockDesc& block_;       // not owned
   VariableScope* global_scope_;  // not owned
 
-  std::vector<paddle::framework::OpFuncNode> vec_func_list_;
   std::vector<Instruction> vec_instruction_;  // deconstruct before OpFuncNode
 
   std::vector<size_t> dependecy_count_;
......
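A note on the ownership change above: OpFuncNode values used to live in the member vec_func_list_ while each Instruction held a reference into it; now Prepare() builds the nodes in a local vector and Convert() moves each node into its Instruction, so nothing dangles once the local vector dies. A minimal sketch of this move-into-owner pattern, with illustrative types (not Paddle's):

#include <iostream>
#include <string>
#include <utility>
#include <vector>

struct OpFuncNode {
  std::string op_type;
};

class Instruction {
 public:
  Instruction(size_t id, OpFuncNode&& node)
      : id_(id), node_(std::move(node)) {}  // take ownership by move
  size_t Id() const { return id_; }
  const std::string& Type() const { return node_.op_type; }

 private:
  size_t id_;
  OpFuncNode node_;  // owned by value, no reference into another vector
};

int main() {
  std::vector<OpFuncNode> nodes{{"matmul"}, {"memcpy_d2h"}, {"fetch"}};
  std::vector<Instruction> instructions;
  instructions.reserve(nodes.size());
  for (size_t i = 0; i < nodes.size(); ++i) {
    // After the move, nodes[i] is valid but unspecified; the instruction
    // owns the payload from here on, so nodes may safely go out of scope.
    instructions.emplace_back(i, std::move(nodes[i]));
  }
  for (const auto& instr : instructions) std::cout << instr.Type() << "\n";
  return 0;
}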
@@ -64,11 +64,12 @@ bool var_can_be_deleted(const std::string& name, const BlockDesc& block) {
 
 std::unordered_map<const paddle::framework::OperatorBase*,
                    std::vector<std::string>>
-get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
+get_unused_vars(const BlockDesc& block,
+                const std::vector<std::shared_ptr<OperatorBase>>& ops) {
   std::unordered_map<std::string, size_t> var_op_idx_map;
 
   for (size_t i = 0; i < ops.size(); ++i) {
-    auto* op = ops[i];
+    const auto& op = ops[i];
 
     OpInOutInfo info;
     for (auto& name_pair : op->Inputs()) {
@@ -79,7 +80,7 @@ get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
         // var can be gc-ed
         if (!info.IsBuilt()) {
-          info.Build(op);
+          info.Build(op.get());
         }
 
         if (info.IsInArgBufferNeeded(name)) {
@@ -107,7 +108,8 @@ get_unused_vars(const BlockDesc& block, const std::vector<OperatorBase*>& ops) {
   for (auto& name_op_idx_pair : var_op_idx_map) {
     auto& name = name_op_idx_pair.first;
     size_t op_idx = name_op_idx_pair.second;
-    result[ops[op_idx]].emplace_back(name);
+
+    result[ops[op_idx].get()].emplace_back(name);
   }
   return result;
 }
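On the shared_ptr migration in get_unused_vars: the vector owns the operators through shared_ptr, while the returned map keys them by raw pointer obtained with .get(), because the map only identifies operators and should not extend their lifetime. A small self-contained sketch of this owning-container / non-owning-key split (illustrative types, not Paddle's):

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct OperatorBase {
  std::string type;
};

int main() {
  // The vector is the owner: operators stay alive as long as it does.
  std::vector<std::shared_ptr<OperatorBase>> ops;
  ops.push_back(std::make_shared<OperatorBase>(OperatorBase{"matmul"}));
  ops.push_back(std::make_shared<OperatorBase>(OperatorBase{"relu"}));

  // The map only *identifies* operators, so it keys by raw pointer via
  // .get() and must not outlive `ops`.
  std::unordered_map<const OperatorBase*, std::vector<std::string>> unused;
  unused[ops[0].get()].push_back("tmp_0");

  for (const auto& kv : unused)
    std::cout << kv.first->type << " can free " << kv.second[0] << "\n";
  return 0;
}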
@@ -150,8 +152,8 @@ void build_variable_scope(const framework::BlockDesc& block,
   }
 }
 
-std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
-  std::vector<OperatorBase*> ops;
+void create_all_ops(const framework::BlockDesc& block,
+                    std::vector<std::shared_ptr<OperatorBase>>* ops) {
   for (auto& op : block.AllOps()) {
     VLOG(3) << "CreateOp from : " << op->Type();
@@ -166,9 +168,8 @@ std::vector<OperatorBase*> create_all_ops(const framework::BlockDesc& block) {
     }
     auto op_base =
         info.Creator()(op->Type(), inputs_names, outputs_names, op_attr_map);
-    ops.push_back(op_base);
+    ops->emplace_back(std::shared_ptr<OperatorBase>(op_base));
   }
-  return ops;
 }
 
 std::tuple<VariableValueMap, VariableIdMap> build_variable_map(
@@ -234,7 +235,8 @@ void apply_device_guard(const OperatorBase* op_base,
 }
 
 void deal_operator_base(const platform::Place& place,
-                        const VariableScope* var_scope, OperatorBase* op_base,
+                        const VariableScope* var_scope,
+                        std::shared_ptr<OperatorBase> op_base,
                         OpFuncNode* op_func_node) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
@@ -321,8 +323,9 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
       var_name, kernel_type_for_var.place_, new_var_name,
       expected_kernel_key.place_);
   auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type);
-  auto copy_op =
-      copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map);
+  auto copy_op = std::shared_ptr<OperatorBase>(
+      copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map));
   OpFuncNode copy_op_func_node;
   copy_op_func_node.input_index = copy_ins_name2id;
   copy_op_func_node.output_index = copy_out_name2id;
@@ -330,10 +333,10 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
   RuntimeContext copy_runtime_context({}, {});
   copy_runtime_context.inputs.swap(copy_ins_value_map);
   copy_runtime_context.outputs.swap(copy_outs_value_map);
-  InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op,
+  InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op.get(),
                                                         copy_runtime_context);
-  static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape(
-      &copy_infer_shape_ctx);
+  static_cast<const framework::OperatorWithKernel*>(copy_op.get())
+      ->InferShape(&copy_infer_shape_ctx);
   auto kernels_iter = all_op_kernels.find(memcpy_op_type);
   PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
@@ -347,7 +350,7 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
   auto copy_exec_ctx =
       ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context);
   auto copy_expected_kernel_key =
-      dynamic_cast<const framework::OperatorWithKernel*>(copy_op)
+      dynamic_cast<const framework::OperatorWithKernel*>(copy_op.get())
           ->GetExpectedKernelType(copy_exec_ctx);
   auto kernel_iter = kernels.find(copy_expected_kernel_key);
   copy_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
@@ -362,19 +365,20 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
   return std::make_pair(new_var_name, copy_op_func_node);
 }
 
-std::vector<OpFuncNode> apply_data_transform(
-    const OpKernelType& expected_kernel_key, const platform::Place& place,
-    VariableValueMap* ins_map_temp, VariableScope* var_scope,
-    OpFuncNode* op_func_node) {
-  auto& op_base = op_func_node->operator_base_;
+void apply_data_transform(const OpKernelType& expected_kernel_key,
+                          const platform::Place& place,
+                          VariableValueMap* ins_map_temp,
+                          VariableScope* var_scope, OpFuncNode* op_func_node,
+                          std::vector<OpFuncNode>* copy_func_nodes) {
+  auto op_base = op_func_node->operator_base_.get();
   PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
                                        "op_base is null, please pass a valid "
                                        "op_base in apply_data_transform."));
-  auto inputs_names = op_base->Inputs();
+
+  VariableNameMap new_ins(op_base->Inputs());
 
   std::unordered_set<int>
       no_data_transform_index;  // record the no need transform variable index.
-  std::vector<OpFuncNode> copy_func_nodes;  // return all the copy opfuncnode.
 
   for (auto& var_name_item : *ins_map_temp) {
     for (size_t i = 0; i < var_name_item.second.size(); ++i) {
@@ -382,7 +386,7 @@ std::vector<OpFuncNode> apply_data_transform(
       if (!(var->IsType<LoDTensor>() || var->IsType<SelectedRows>())) {
         continue;
       }
-      auto& var_name = inputs_names[var_name_item.first].at(i);
+      auto& var_name = new_ins[var_name_item.first].at(i);
       auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
       if (!tensor_in->IsInitialized()) {
         continue;
@@ -404,8 +408,9 @@ std::vector<OpFuncNode> apply_data_transform(
             var_name_item.first, *op_func_node, var, var_scope);
         op_func_node->input_index[var_name_item.first][i] =
             var_scope->VarId(new_var_name);
-        copy_func_nodes.push_back(copy_op_func_node);
+        copy_func_nodes->emplace_back(copy_op_func_node);
         var_name_item.second[i] = var_scope->Var(new_var_name);
+        new_ins[var_name_item.first][i] = new_var_name;
       } else if (need_dtype_transform_for_var(kernel_type_for_var,
                                               expected_kernel_key)) {
         // TODO(@xiongkun) add dtype judgement here
@@ -421,8 +426,14 @@ std::vector<OpFuncNode> apply_data_transform(
       }
     }
   }
+  // NOTE(zhiqiu): UPDATE the corresponding OperatorBase to make it consistent
+  // with the instruction
+  // (hot fix, it is not a good design here)
+  op_func_node->operator_base_ =
+      std::shared_ptr<OperatorBase>(framework::OpRegistry::CreateOp(
+          op_base->Type(), new_ins, op_base->Outputs(), op_base->Attrs()));
   op_func_node->no_data_transform_index = std::move(no_data_transform_index);
-  return copy_func_nodes;
 }
 
 void build_op_func_list(const platform::Place& place,
@@ -430,16 +441,16 @@ void build_op_func_list(const platform::Place& place,
                         std::vector<OpFuncNode>* vec_func_list,
                         VariableScope* var_scope) {
   auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
+  std::vector<std::shared_ptr<OperatorBase>>
+      ops;  // its elements will be moved to vec_func_list
   // Step 1: create all ops for current block.
-  auto ops = create_all_ops(block);
+  create_all_ops(block, &ops);
   auto unused_var_map = get_unused_vars(block, ops);
 
-  size_t ops_index = 0;
-  for (auto& op : block.AllOps()) {
+  for (size_t i = 0; i < ops.size(); ++i) {
+    auto op = ops[i].get();
     VLOG(6) << "Build OpFuncNode from : " << op->Type();
-    auto op_base = ops[ops_index++];
 
     auto inputs_names = op->Inputs();
     auto outputs_names = op->Outputs();
@@ -466,20 +477,18 @@ void build_op_func_list(const platform::Place& place,
     op_func_node.input_index = ins_name2id;
     op_func_node.output_index = outs_name2id;
 
-    if (dynamic_cast<const framework::OperatorWithKernel*>(op_base) ==
-        nullptr) {
+    if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
       // op is not an OperatorWithKernel, so directly run OperatorBase::Run()
-      deal_operator_base(place, var_scope, op_base, &op_func_node);
+      deal_operator_base(place, var_scope, ops[i], &op_func_node);
     } else {
       // construct RuntimeContext and analyze KernelType
       RuntimeContext runtime_context({}, {});
       runtime_context.inputs.swap(ins_map);
       runtime_context.outputs.swap(outs_map);
-      InterpretercoreInferShapeContext infer_shape_ctx(*op_base,
-                                                       runtime_context);
+      InterpretercoreInferShapeContext infer_shape_ctx(*op, runtime_context);
       // TODO(Aurelius84): In case of control flow ops, they are NOT inherited
       // from OperatorWithKernel.
-      static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
-          &infer_shape_ctx);
+      static_cast<const framework::OperatorWithKernel*>(op)->InferShape(
+          &infer_shape_ctx);
       auto kernels_iter = all_op_kernels.find(op->Type());
       PADDLE_ENFORCE_NE(
@@ -495,13 +504,13 @@ void build_op_func_list(const platform::Place& place,
       auto* dev_ctx = pool.Get(place);
       Scope scope;
       auto expected_kernel_key =
-          dynamic_cast<const framework::OperatorWithKernel*>(op_base)
+          dynamic_cast<const framework::OperatorWithKernel*>(op)
               ->GetExpectedKernelType(
-                  ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
+                  ExecutionContext(*op, scope, *dev_ctx, runtime_context));
 
       // consider device_guard()
       apply_device_guard(
-          op_base, place,
+          op, place,
           &expected_kernel_key);  // change device by the device_guard()
       VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
@@ -510,14 +519,14 @@ void build_op_func_list(const platform::Place& place,
       std::vector<OpFuncNode> copy_op_to_insert;
       // NOTE(xiongkun03): assign op_base here to reduce parameter number of
      // apply_data_transform.
-      op_func_node.operator_base_ = op_base;
-      copy_op_to_insert = apply_data_transform(
-          expected_kernel_key, place, &ins_map_temp, var_scope, &op_func_node);
+      op_func_node.operator_base_ = ops[i];
+      apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope,
+                           &op_func_node, &copy_op_to_insert);
       for (auto& item : copy_op_to_insert) {
         vec_func_list->push_back(item);
       }
       // step 4. Run op kernel
-      VLOG(3) << op_base->Type()
+      VLOG(3) << op->Type()
              << " : expected_kernel_key : " << expected_kernel_key;
       if (platform::is_gpu_place(expected_kernel_key.place_)) {
@@ -533,8 +542,7 @@ void build_op_func_list(const platform::Place& place,
       }
       op_func_node.dev_ctx_ = dev_ctx;
 
-      auto exec_ctx =
-          ExecutionContext(*op_base, scope, *dev_ctx, runtime_context);
+      auto exec_ctx = ExecutionContext(*op, scope, *dev_ctx, runtime_context);
 
       auto kernel_iter = kernels.find(expected_kernel_key);
       PADDLE_ENFORCE_NE(
@@ -547,9 +555,9 @@ void build_op_func_list(const platform::Place& place,
       op_func_node.kernel_func_(exec_ctx);
     }
 
-    vec_func_list->push_back(op_func_node);
+    vec_func_list->emplace_back(op_func_node);
     // gc---------------------------------------------------------------------------
-    auto iter = unused_var_map.find(op_base);
+    auto iter = unused_var_map.find(op);
     if (iter == unused_var_map.end()) {
       continue;
     }
@@ -586,7 +594,7 @@ void build_op_func_list(const platform::Place& place,
     delete garbages;  // free mem
 
-    VLOG(3) << "run " << op_base->Type() << " done.";
+    VLOG(3) << "run " << op->Type() << " done.";
   }
 }
......
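On the "hot fix" in apply_data_transform: once a memcpy op is inserted for an input, the variable actually fed to the kernel changes (e.g. x becomes its device-side copy), so the OperatorBase is re-created from the updated input-name map to keep the operator consistent with the instruction. A standalone sketch of that rename-and-rebuild step; all types and the x_device name are hypothetical:

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Illustrative stand-ins, not Paddle's types.
using VariableNameMap = std::map<std::string, std::vector<std::string>>;

struct Op {
  std::string type;
  VariableNameMap inputs;
};

// Mock of re-creating an operator from an updated input-name map.
Op CreateOp(const std::string& type, const VariableNameMap& ins) {
  return Op{type, ins};
}

int main() {
  Op matmul{"matmul", {{"X", {"x"}}, {"Y", {"y"}}}};

  // A data-transform pass inserted a copy op, so the kernel must read the
  // transformed variable rather than the original one.
  VariableNameMap new_ins = matmul.inputs;
  new_ins["X"][0] = "x_device";  // hypothetical transformed name

  // Rebuild the operator so its inputs match what the instruction runs;
  // reusing the old one would leave it pointing at the pre-copy variable.
  Op fixed = CreateOp(matmul.type, new_ins);
  std::cout << fixed.inputs["X"][0] << "\n";  // prints x_device
  return 0;
}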
@@ -629,5 +629,109 @@ void VariableScopeListener::onCreateScope(Scope* Scope) {}
 void VariableScopeListener::onDeleteScope(Scope* Scope) {}
 void VariableScopeListener::onClear() {}
 
+Instruction::Instruction(size_t id, OpFuncNode&& op_func_node,
+                         const platform::DeviceContext& dev_ctx)
+    : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
+  PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
+                               "Required id >= 0, but received id = %d", id));
+}
+
+size_t Instruction::Id() const { return id_; }
+
+const std::map<std::string, std::vector<int>>& Instruction::Inputs() const {
+  return op_func_node_.input_index;
+}
+
+const std::map<std::string, std::vector<int>>& Instruction::Outputs() const {
+  return op_func_node_.output_index;
+}
+
+const std::unordered_set<int>& Instruction::NoDataTransformVars() const {
+  return op_func_node_.no_data_transform_index;
+}
+
+OpKernelComputeFunc Instruction::KernelFunc() const {
+  return op_func_node_.kernel_func_;
+}
+
+OpFuncType Instruction::KernelType() const { return op_func_node_.type_; }
+
+OperatorBase* Instruction::OpBase() const {
+  auto op_base = op_func_node_.operator_base_;
+  PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
+                                       "op_base shall not be nullptr."));
+  return op_base.get();
+}
+
+NextInstruction& Instruction::NextInstructions() { return next_instruction_; }
+
+const NextInstruction& Instruction::NextInstructions() const {
+  return next_instruction_;
+}
+
+void Instruction::AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
+
+const std::vector<size_t>& Instruction::GCCheckVars() const {
+  return gc_check_var_list_;
+}
+
+void Instruction::ResetContext(const VariableValueMap& in_vars,
+                               const VariableValueMap& out_vars) {
+  runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
+  infershape_ctx_.reset(
+      new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
+  // NOTE: Because execution_ctx_ is constructed with a `Scope&`, we fake an
+  // empty scope here to avoid an illegal local reference.
+  static framework::Scope scope_;
+  execution_ctx_.reset(
+      new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
+}
+
+std::shared_ptr<RuntimeContext> Instruction::InnerRuntimeContext() const {
+  return runtime_ctx_;
+}
+
+std::shared_ptr<InterpretercoreInferShapeContext>
+Instruction::InnerInferShapeContext() const {
+  return infershape_ctx_;
+}
+
+std::shared_ptr<ExecutionContext> Instruction::InnerExecutionContext() const {
+  return execution_ctx_;
+}
+
+const platform::DeviceContext& Instruction::DeviceContext() const {
+  return dev_ctx_;
+}
+
+const std::vector<std::pair<Variable*, Variable*>>& Instruction::InplaceInfo()
+    const {
+  return vec_inplace_in_to_out_;
+}
+
+void Instruction::AddInplace(Variable* in, Variable* out) {
+  vec_inplace_in_to_out_.emplace_back(in, out);
+}
+
+const std::vector<EventInter>& Instruction::InputEvents() const {
+  return intput_events_;
+}
+
+const std::vector<EventInter>& Instruction::OutputEvents() const {
+  return output_events_;
+}
+
+void Instruction::AddInputEvent(size_t var_id,
+                                std::shared_ptr<platform::DeviceEvent> event,
+                                platform::DeviceType waiter_type) {
+  intput_events_.emplace_back(var_id, event, waiter_type);
+}
+
+void Instruction::AddOutputEvent(size_t var_id,
+                                 std::shared_ptr<platform::DeviceEvent> event,
+                                 platform::DeviceType waiter_type) {
+  output_events_.emplace_back(var_id, event, waiter_type);
+}
+
 }  // namespace framework
 }  // namespace paddle
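A note on the static Scope inside Instruction::ResetContext above: ExecutionContext stores a Scope&, so binding it to a short-lived local would leave a dangling reference, while a function-local static outlives every context that refers to it. A minimal sketch of the hazard and the fix (illustrative types, not Paddle's):

#include <iostream>

struct Scope {};

struct ExecutionContext {
  explicit ExecutionContext(const Scope& scope) : scope_(scope) {}
  const Scope& scope_;  // reference member: must not outlive its target
};

ExecutionContext MakeContext() {
  static Scope scope;  // lives for the whole program
  return ExecutionContext(scope);  // safe: the reference stays valid
  // Returning ExecutionContext(Scope{}) instead would dangle immediately.
}

int main() {
  ExecutionContext ctx = MakeContext();
  (void)ctx;
  std::cout << "context built against a stable scope\n";
  return 0;
}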
@@ -271,7 +271,8 @@ enum class OpFuncType {
 class RuntimeInferShapeContext;
 
 struct OpFuncNode {
-  OperatorBase* operator_base_;
+  // TODO(zhiqiu): Better make it unique_ptr
+  std::shared_ptr<OperatorBase> operator_base_;
   std::map<std::string, std::vector<int>> input_index;
   std::map<std::string, std::vector<int>> output_index;
   std::unordered_set<int> no_data_transform_index;
@@ -283,100 +284,62 @@ struct OpFuncNode {
 class Instruction {
  public:
-  Instruction(size_t id, const OpFuncNode& op_func_node,
-              const platform::DeviceContext& dev_ctx)
-      : id_(id), op_func_node_(op_func_node), dev_ctx_(dev_ctx) {
-    PADDLE_ENFORCE_GE(id, 0, platform::errors::PreconditionNotMet(
-                                 "Required id >= 0, but received id = %d", id));
-  }
+  Instruction(size_t id, OpFuncNode&& op_func_node,
+              const platform::DeviceContext& dev_ctx);
 
-  size_t Id() const { return id_; }
+  size_t Id() const;
 
-  const std::map<std::string, std::vector<int>>& Inputs() const {
-    return op_func_node_.input_index;
-  }
+  const std::map<std::string, std::vector<int>>& Inputs() const;
 
-  const std::map<std::string, std::vector<int>>& Outputs() const {
-    return op_func_node_.output_index;
-  }
+  const std::map<std::string, std::vector<int>>& Outputs() const;
 
-  const std::unordered_set<int>& NoDataTransformVars() const {
-    return op_func_node_.no_data_transform_index;
-  }
+  const std::unordered_set<int>& NoDataTransformVars() const;
 
-  OpKernelComputeFunc KernelFunc() const { return op_func_node_.kernel_func_; }
+  OpKernelComputeFunc KernelFunc() const;
 
-  OpFuncType KernelType() const { return op_func_node_.type_; }
+  OpFuncType KernelType() const;
 
-  OperatorBase* OpBase() const {
-    auto* op_base = op_func_node_.operator_base_;
-    PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
-                                         "op_base shall not be nullptr."));
-    return op_base;
-  }
+  OperatorBase* OpBase() const;
 
-  NextInstruction& NextInstructions() { return next_instruction_; }
+  NextInstruction& NextInstructions();
 
-  const NextInstruction& NextInstructions() const { return next_instruction_; }
+  const NextInstruction& NextInstructions() const;
 
-  void AddGCCheckVar(size_t id) { gc_check_var_list_.push_back(id); }
+  void AddGCCheckVar(size_t id);
 
-  const std::vector<size_t>& GCCheckVars() const { return gc_check_var_list_; }
+  const std::vector<size_t>& GCCheckVars() const;
 
   void ResetContext(const VariableValueMap& in_vars,
-                    const VariableValueMap& out_vars) {
-    runtime_ctx_.reset(new RuntimeContext(in_vars, out_vars));
-    infershape_ctx_.reset(
-        new InterpretercoreInferShapeContext(*OpBase(), *runtime_ctx_.get()));
-    // NOTE: Because execution_ctx_ is constructed with a `Scope&`, we fake an
-    // empty scope here to avoid an illegal local reference.
-    static framework::Scope scope_;
-    execution_ctx_.reset(
-        new ExecutionContext(*OpBase(), scope_, dev_ctx_, *runtime_ctx_.get()));
-  }
+                    const VariableValueMap& out_vars);
 
-  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const {
-    return runtime_ctx_;
-  }
+  std::shared_ptr<RuntimeContext> InnerRuntimeContext() const;
 
   std::shared_ptr<InterpretercoreInferShapeContext> InnerInferShapeContext()
-      const {
-    return infershape_ctx_;
-  }
+      const;
 
-  std::shared_ptr<ExecutionContext> InnerExecutionContext() const {
-    return execution_ctx_;
-  }
+  std::shared_ptr<ExecutionContext> InnerExecutionContext() const;
 
-  const platform::DeviceContext& DeviceContext() const { return dev_ctx_; }
+  const platform::DeviceContext& DeviceContext() const;
 
-  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const {
-    return vec_inplace_in_to_out_;
-  }
+  const std::vector<std::pair<Variable*, Variable*>>& InplaceInfo() const;
 
-  void AddInplace(Variable* in, Variable* out) {
-    vec_inplace_in_to_out_.emplace_back(in, out);
-  }
+  void AddInplace(Variable* in, Variable* out);
 
-  const std::vector<EventInter>& InputEvents() const { return intput_events_; }
+  const std::vector<EventInter>& InputEvents() const;
 
-  const std::vector<EventInter>& OutputEvents() const { return output_events_; }
+  const std::vector<EventInter>& OutputEvents() const;
 
   void AddInputEvent(size_t var_id,
                      std::shared_ptr<platform::DeviceEvent> event,
-                     platform::DeviceType waiter_type) {
-    intput_events_.emplace_back(var_id, event, waiter_type);
-  }
+                     platform::DeviceType waiter_type);
 
   void AddOutputEvent(size_t var_id,
                       std::shared_ptr<platform::DeviceEvent> event,
-                      platform::DeviceType waiter_type) {
-    output_events_.emplace_back(var_id, event, waiter_type);
-  }
+                      platform::DeviceType waiter_type);
 
  private:
   size_t id_;
-  const OpFuncNode& op_func_node_;          // not owned
+  OpFuncNode op_func_node_;
   const platform::DeviceContext& dev_ctx_;  // not owned
 
   std::shared_ptr<RuntimeContext> runtime_ctx_;
@@ -403,6 +366,11 @@ static bool IsMemcpyH2D(const Instruction& instr) {
 static bool IsMemcpyD2H(const Instruction& instr) {
   return instr.OpBase()->Type() == kMemcpyD2H;
 }
 
+static bool IsCpuOp(const Instruction& instr) {
+  return platform::is_cpu_place(instr.DeviceContext().GetPlace());
+}
+
 }  // namespace interpreter
 }  // namespace framework
......
@@ -27,33 +27,60 @@ namespace framework {
  *
  * For example: matmul(gpu) -> out_var -> memcpy_d2h
  * out_var should be associated with an event.
+ *
+ * NOTE(zhiqiu): There are two special cases where no event is needed:
+ *  1. the variable is marked as NoDataTransformVar
+ *  2. the variable is marked as NoNeedDataBuffer
  */
-std::vector<size_t> StreamAnalyzer::ParseEventVarIds(
+std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
     const Instruction& cur_instr, const Instruction& next_instr) {
   std::unordered_set<size_t> unique_var_ids;
   for (auto& item : cur_instr.Outputs()) {
     unique_var_ids.insert(item.second.begin(), item.second.end());
   }
 
-  std::vector<size_t> new_event_var_ids;
+  auto is_no_need_buffer = [&next_instr](std::string name) {
+    auto* op = next_instr.OpBase();
+    auto& inferer = op->Info().NoNeedBufferVarsInferer();
+    if (inferer) {
+      auto no_need_buffer_ins =
+          inferer(op->Inputs(), op->Outputs(), op->Attrs());
+      return no_need_buffer_ins.count(name) != 0;
+    }
+    return false;
+  };
+
+  std::vector<size_t> need_event_var_ids;
   for (auto& item : next_instr.Inputs()) {
     for (auto var_id : item.second) {
-      if (unique_var_ids.count(var_id) > 0 &&
-          next_instr.NoDataTransformVars().count(var_id) == 0) {
-        new_event_var_ids.push_back(var_id);
+      if (unique_var_ids.count(var_id) > 0) {
+        if (next_instr.NoDataTransformVars().count(var_id)) {
+          VLOG(4) << "Skip inserting event at variable " << item.first
+                  << " of operator " << next_instr.OpBase()->Type()
+                  << " since it is NoDataTransform";
+          continue;
+        }
+        if (is_no_need_buffer(item.first)) {
+          VLOG(4) << "Skip inserting event at variable " << item.first
+                  << " of operator " << next_instr.OpBase()->Type()
+                  << " since it is NoNeedBufferVar";
+          continue;
+        }
+        need_event_var_ids.push_back(var_id);
       }
     }
  }
-  return new_event_var_ids;
+  return need_event_var_ids;
 }
 
-void StreamAnalyzer::AssociateInputWithEvents(
+void StreamAnalyzer::ConstructEventForVar(
     const std::vector<size_t>& new_event_var_id, Instruction* next_instr,
-    platform::DeviceType waiter_type) {
+    platform::DeviceType waiter_type, const platform::Place& place) {
   for (auto var_id : new_event_var_id) {
     if (var_id2event_.count(var_id) == 0) {
       auto device_event = std::make_shared<platform::DeviceEvent>(
-          place_, platform::GenerateDeviceEventFlag());
+          place, platform::GenerateDeviceEventFlag());
       var_id2event_.emplace(var_id, std::move(device_event));
     }
     // Add events for next_instr.inputs
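On the NoNeedBufferVar check added above: an operator can declare that it consumes only an input's metadata (its shape, for instance) and never touches the data buffer, in which case no event is needed between the producer and that consumer. A self-contained sketch of the test; the Inferer type and the example op are stand-ins for Paddle's NoNeedBufferVarsInferer, not its real signature:

#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>

// Stand-in for the inferer: returns the input names whose buffers the op
// never reads. Paddle's real inferer also receives inputs/outputs/attrs.
using Inferer = std::function<std::unordered_set<std::string>()>;

bool IsNoNeedBuffer(const Inferer& inferer, const std::string& name) {
  if (inferer) {  // not every op registers an inferer
    return inferer().count(name) != 0;
  }
  return false;  // conservative default: assume the buffer is needed
}

int main() {
  // A toy op that only consumes the shape of "X", never its contents.
  Inferer shape_only_inferer = [] {
    return std::unordered_set<std::string>{"X"};
  };
  std::cout << std::boolalpha
            << IsNoNeedBuffer(shape_only_inferer, "X") << "\n"  // true
            << IsNoNeedBuffer(nullptr, "X") << "\n";            // false
  return 0;
}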
@@ -69,21 +96,27 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
   std::vector<size_t> event_var_ids;
   for (auto next_op_id : downstream_ops) {
     auto& next_instr = instructions->at(next_op_id);
-
     if (IsDirectRun(cur_instr, next_instr)) {
+      VLOG(4) << "DirectRun: " << cur_instr.OpBase()->Type() << "->"
+              << next_instr.OpBase()->Type();
       next_instruction.AddDirectRun(next_op_id);
     } else {
       // Always insert events between different stream
-      auto new_event_var_ids = ParseEventVarIds(cur_instr, next_instr);
-      event_var_ids.insert(event_var_ids.end(), new_event_var_ids.begin(),
-                           new_event_var_ids.end());
+      auto need_event_var_ids = GetNeedEventVarIds(cur_instr, next_instr);
+      event_var_ids.insert(event_var_ids.end(), need_event_var_ids.begin(),
+                           need_event_var_ids.end());
 
       auto waiter_type = GetWaiterType(next_instr);
-      AssociateInputWithEvents(new_event_var_ids, &next_instr, waiter_type);
+      ConstructEventForVar(need_event_var_ids, &next_instr, waiter_type,
+                           cur_instr.DeviceContext().GetPlace());
 
       if (waiter_type == platform::kCPU) {  // GPU -> CPU
+        VLOG(4) << "SyncRun: " << cur_instr.OpBase()->Type() << "->"
+                << next_instr.OpBase()->Type();
         next_instruction.AddSyncRun(next_op_id);
       } else {  // GPU -> GPU(different stream)
+        VLOG(4) << "EventRun: " << cur_instr.OpBase()->Type() << "->"
+                << next_instr.OpBase()->Type();
         next_instruction.ADDEventRun(next_op_id);
       }
     }
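The Schedule() logic above picks one of three downstream modes: DirectRun when no synchronization is needed, SyncRun when the waiter is the CPU (the host blocks), and EventRun for cross-stream GPU dependencies (wait on an event instead of blocking). A compact sketch of that decision, with an illustrative enum rather than Paddle's NextInstruction API:

#include <iostream>

enum class RunMode { kDirectRun, kSyncRun, kEventRun };

// direct_run: same context, a CPU op, or a safe memcpy pairing.
// waiter_is_cpu: the downstream op runs on the host.
RunMode Choose(bool direct_run, bool waiter_is_cpu) {
  if (direct_run) return RunMode::kDirectRun;   // no event needed
  if (waiter_is_cpu) return RunMode::kSyncRun;  // GPU -> CPU: host blocks
  return RunMode::kEventRun;  // GPU -> GPU on another stream: event wait
}

int main() {
  std::cout << (Choose(false, true) == RunMode::kSyncRun) << "\n";  // 1
  return 0;
}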
@@ -116,12 +149,15 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
  * NOTE(dev): The following cases are considered as directly run:
  *
  * 1. with same dev_ctx_, such as: CPU -> CPU, GPU -> GPU
- * 2. D2H -> CPU
- * 3. CPU -> H2D
+ * 2. CPU -> any (e.g. CPU op -> var -> GPU op, when the var is
+ *    no-need-buffer or no-need-data-transform)
+ * 3. D2H -> CPU
+ * 4. CPU -> H2D
  */
 bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
                                  const Instruction& next_instr) {
   return (&cur_instr.DeviceContext() == &next_instr.DeviceContext() ||
+          interpreter::IsCpuOp(cur_instr) ||
           interpreter::IsMemcpyD2H(cur_instr) ||
           interpreter::IsMemcpyH2D(next_instr));
 }
......
@@ -35,12 +35,13 @@ class StreamAnalyzer {
   platform::DeviceContext* ParseDeviceContext(const OpFuncNode& op_func_node);
 
  private:
-  std::vector<size_t> ParseEventVarIds(const Instruction& cur_instr,
-                                       const Instruction& next_instr);
+  std::vector<size_t> GetNeedEventVarIds(const Instruction& cur_instr,
+                                         const Instruction& next_instr);
 
-  void AssociateInputWithEvents(const std::vector<size_t>& new_event_var_id,
-                                Instruction* next_instr,
-                                platform::DeviceType waiter_type);
+  void ConstructEventForVar(const std::vector<size_t>& new_event_var_id,
+                            Instruction* next_instr,
+                            platform::DeviceType waiter_type,
+                            const platform::Place& place);
 
   bool IsDirectRun(Instruction& cur_instr,  // NOLINT
                    const Instruction& next_instr);
......