未验证 提交 9d2dd727 编写于 作者: X xiongkun 提交者: GitHub

Refactor apply transformer (#36899)

* bugfix: ps mode can't set backend automatically

* refactor

* fix

* refact

* refine code

* refine

* push
上级 7ee727a8
......@@ -229,97 +229,41 @@ void apply_device_guard(const OperatorBase* op_base,
}
}
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
// Step 1: create all ops for global block.
auto ops = create_all_ops(global_block);
auto unused_var_map = get_unused_vars(global_block, ops);
size_t ops_index = 0;
for (auto& op : global_block.AllOps()) {
VLOG(6) << "Build OpFuncNode from : " << op->Type();
auto op_base = ops[ops_index++];
auto inputs_names = op->Inputs();
auto outputs_names = op->Outputs();
VariableValueMap ins_map;
VariableIdMap ins_name2id;
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope);
VariableValueMap outs_map;
VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) =
build_variable_map(outputs_names, var_scope);
// step 2: build OpFuncNode
OpFuncNode op_func_node;
op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id;
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT inheritted
// from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
op->Type()));
OpKernelMap& kernels = kernels_iter->second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
Scope scope;
auto expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(op_base)
->GetExpectedKernelType(
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
// consider device_guard()
apply_device_guard(op_base, place, &expected_kernel_key);
VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
// step 3. Insert memcpy_op if needed
VariableValueMap& ins_map_temp = runtime_context.inputs;
std::unordered_set<int> no_data_transform_index;
for (auto& var_name_item : ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i];
auto& var_name = inputs_names[var_name_item.first].at(i);
auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
if (!tensor_in->IsInitialized()) {
continue;
}
auto kernel_type_for_var =
static_cast<const framework::OperatorWithKernel*>(op_base)
->GetKernelTypeForVar(var_name_item.first, *tensor_in,
expected_kernel_key);
// the return value is whether data transformer is needed for this var
bool need_place_transform_for_var(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
if (platform::is_same_place(kernel_type_for_var.place_,
expected_kernel_key.place_) ||
(is_cuda_pinned_place(kernel_type_for_var.place_) &&
is_cpu_place(expected_kernel_key.place_))) {
// record no need data transformer input var_id
VLOG(3) << op->Type() << " found no data_transform var: " << var_name
<< " with id: " << var_name;
no_data_transform_index.emplace(var_scope->VarId(var_name));
return false;
} else {
if (op_base->Type() == "fetch_v2") {
op_base->SetAttr("deepcopy", false);
return true;
}
}
bool need_dtype_transform_for_var(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
return false; // TODO(@xiongkun) add dtype judgement here
}
bool need_layout_transform_for_var(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key) {
return false; // TODO(@xiongkun) add layout judgement here
}
// NOTE(@xiongkun03)
// the difference between var_name and outer_name :
// if "X": ["var1", "var2"], then X is the outer name,
// var1 and var2 is the var_name
std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key, const platform::Place& place,
const std::string& var_name, const std::string& outer_name,
const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) {
auto& ins_name2id = op_func_node.input_index;
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
std::string new_var_name =
var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1);
var_scope->AddVar(new_var_name, nullptr);
......@@ -335,28 +279,24 @@ void build_op_func_list(const platform::Place& place,
: is_gpu_place(expected_kernel_key.place_) ? 1 : -1;
std::map<std::string, std::vector<int>> copy_ins_name2id;
copy_ins_name2id["X"] = ins_name2id.at(var_name_item.first);
copy_ins_name2id["X"] = ins_name2id.at(outer_name);
std::map<std::string, std::vector<int>> copy_out_name2id;
copy_out_name2id["Out"] = {var_scope->VarId(new_var_name)};
op_func_node.input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
VariableValueMap copy_ins_value_map;
copy_ins_value_map["X"] = {var};
VariableValueMap copy_outs_value_map;
copy_outs_value_map["Out"] = {var_scope->Var(new_var_name)};
// memcpy_d2h, memcpy_h2d
auto memcpy_op_type = get_memcpy_type(kernel_type_for_var.place_,
expected_kernel_key.place_);
VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).",
memcpy_op_type, var_name,
kernel_type_for_var.place_, new_var_name,
auto memcpy_op_type =
get_memcpy_type(kernel_type_for_var.place_, expected_kernel_key.place_);
VLOG(3) << string::Sprintf("Insert %s with %s(%s) -> %s(%s).", memcpy_op_type,
var_name, kernel_type_for_var.place_, new_var_name,
expected_kernel_key.place_);
auto& copy_info = OpInfoMap::Instance().Get(memcpy_op_type);
auto copy_op = copy_info.Creator()(memcpy_op_type, copy_in_map,
copy_out_map, attr_map);
auto copy_op =
copy_info.Creator()(memcpy_op_type, copy_in_map, copy_out_map, attr_map);
OpFuncNode copy_op_func_node;
copy_op_func_node.input_index = copy_ins_name2id;
copy_op_func_node.output_index = copy_out_name2id;
......@@ -364,10 +304,10 @@ void build_op_func_list(const platform::Place& place,
RuntimeContext copy_runtime_context({}, {});
copy_runtime_context.inputs.swap(copy_ins_value_map);
copy_runtime_context.outputs.swap(copy_outs_value_map);
InterpretercoreInferShapeContext copy_infer_shape_ctx(
*copy_op, copy_runtime_context);
static_cast<const framework::OperatorWithKernel*>(copy_op)
->InferShape(&copy_infer_shape_ctx);
InterpretercoreInferShapeContext copy_infer_shape_ctx(*copy_op,
copy_runtime_context);
static_cast<const framework::OperatorWithKernel*>(copy_op)->InferShape(
&copy_infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(memcpy_op_type);
PADDLE_ENFORCE_NE(kernels_iter, all_op_kernels.end(),
......@@ -380,12 +320,11 @@ void build_op_func_list(const platform::Place& place,
Scope scope;
auto copy_exec_ctx =
ExecutionContext(*copy_op, scope, *dev_ctx, copy_runtime_context);
auto expected_kernel_key =
auto copy_expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(copy_op)
->GetExpectedKernelType(copy_exec_ctx);
auto kernel_iter = kernels.find(expected_kernel_key);
copy_op_func_node.kernel_func_ =
OpKernelComputeFunc(kernel_iter->second);
auto kernel_iter = kernels.find(copy_expected_kernel_key);
copy_op_func_node.kernel_func_ = OpKernelComputeFunc(kernel_iter->second);
copy_op_func_node.kernel_func_(copy_exec_ctx);
VLOG(3) << "Run " << memcpy_op_type << " done.";
// NOTE(Aurelius84): memcpy_op is expensive operation, so we tag them
......@@ -393,15 +332,147 @@ void build_op_func_list(const platform::Place& place,
copy_op_func_node.type_ = OpFuncType::kQueueSync;
copy_op_func_node.dev_ctx_ = dev_ctx;
copy_op_func_node.operator_base_ = copy_op;
vec_func_list->push_back(copy_op_func_node);
return std::make_pair(new_var_name, copy_op_func_node);
}
std::vector<OpFuncNode> apply_data_transform(
const OpKernelType& expected_kernel_key, const platform::Place& place,
VariableValueMap& ins_map_temp, VariableScope* var_scope,
OpFuncNode& op_func_node) {
auto& op_base = op_func_node.operator_base_;
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base is null, please pass a valid "
"op_base in apply_data_transform."));
auto inputs_names = op_base->Inputs();
std::unordered_set<int>
no_data_transform_index; // record the no need transform variable index.
std::vector<OpFuncNode> copy_func_nodes; // return all the copy opfuncnode.
for (auto& var_name_item : ins_map_temp) {
for (size_t i = 0; i < var_name_item.second.size(); ++i) {
auto var = var_name_item.second[i];
auto& var_name = inputs_names[var_name_item.first].at(i);
auto tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
if (!tensor_in->IsInitialized()) {
continue;
}
auto kernel_type_for_var = // the true kernel type for op_base
static_cast<const framework::OperatorWithKernel*>(op_base)
->GetKernelTypeForVar(var_name_item.first, *tensor_in,
expected_kernel_key);
if (need_place_transform_for_var(kernel_type_for_var,
expected_kernel_key)) {
if (op_base->Type() == "fetch_v2") {
op_base->SetAttr("deepcopy", false);
}
std::string new_var_name;
OpFuncNode copy_op_func_node;
std::tie(new_var_name, copy_op_func_node) =
apply_place_transform_for_var(
kernel_type_for_var, expected_kernel_key, place, var_name,
var_name_item.first, op_func_node, var, var_scope);
op_func_node.input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
copy_func_nodes.push_back(copy_op_func_node);
var_name_item.second[i] = var_scope->Var(new_var_name);
} else if (need_dtype_transform_for_var(kernel_type_for_var,
expected_kernel_key)) {
// TODO(@xiongkun) add dtype judgement here
} else if (need_layout_transform_for_var(kernel_type_for_var,
expected_kernel_key)) {
// TODO(@xiongkun) add layout judgement here
} else {
// record no need data transformer input var_id
VLOG(3) << op_base->Type()
<< " found no data_transform var: " << var_name
<< " with id: " << var_scope->VarId(var_name);
no_data_transform_index.emplace(var_scope->VarId(var_name));
}
}
}
op_func_node.no_data_transform_index = std::move(no_data_transform_index);
// step 4. Run op kernel
return copy_func_nodes;
}
void build_op_func_list(const platform::Place& place,
const framework::ProgramDesc& pdesc,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
auto& global_block = pdesc.Block(0);
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
// Step 1: create all ops for global block.
auto ops = create_all_ops(global_block);
auto unused_var_map = get_unused_vars(global_block, ops);
size_t ops_index = 0;
for (auto& op : global_block.AllOps()) {
VLOG(6) << "Build OpFuncNode from : " << op->Type();
auto op_base = ops[ops_index++];
auto inputs_names = op->Inputs();
auto outputs_names = op->Outputs();
VariableValueMap ins_map;
VariableIdMap ins_name2id;
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope);
VariableValueMap outs_map;
VariableIdMap outs_name2id;
std::tie(outs_map, outs_name2id) =
build_variable_map(outputs_names, var_scope);
// step 2: build OpFuncNode
OpFuncNode op_func_node;
op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id;
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
runtime_context.inputs.swap(ins_map);
runtime_context.outputs.swap(outs_map);
InterpretercoreInferShapeContext infer_shape_ctx(*op_base, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT inheritted
// from OperatorWithKernel.
static_cast<const framework::OperatorWithKernel*>(op_base)->InferShape(
&infer_shape_ctx);
auto kernels_iter = all_op_kernels.find(op->Type());
PADDLE_ENFORCE_NE(
kernels_iter, all_op_kernels.end(),
platform::errors::Unavailable(
"There are no kernels which are registered in the %s operator.",
op->Type()));
OpKernelMap& kernels = kernels_iter->second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
Scope scope;
auto expected_kernel_key =
dynamic_cast<const framework::OperatorWithKernel*>(op_base)
->GetExpectedKernelType(
ExecutionContext(*op_base, scope, *dev_ctx, runtime_context));
// consider device_guard()
apply_device_guard(
op_base, place,
&expected_kernel_key); // change device by the device_guard()
VLOG(3) << "expected_kernel_key : " << expected_kernel_key;
// step 3. apply data transforms and insert memory ops
VariableValueMap& ins_map_temp = runtime_context.inputs;
std::vector<OpFuncNode> copy_op_to_insert;
// NOTE(xiongkun03): assign op_base here to reduce parameter number of
// apply_data_transform.
op_func_node.operator_base_ = op_base;
copy_op_to_insert = apply_data_transform(
expected_kernel_key, place, ins_map_temp, var_scope, op_func_node);
for (auto& item : copy_op_to_insert) {
vec_func_list->push_back(item);
}
// step 4. Run op kernel
VLOG(3) << op_base->Type()
<< " : expected_kernel_key : " << expected_kernel_key;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册