未验证 提交 cf8e15ba 编写于 作者: H hong 提交者: GitHub

[NewIR]new ir support yaml backend config (#56570)

* update

* fix compile error
上级 e518058a
...@@ -151,8 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ ...@@ -151,8 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }}; std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }}; std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }}; std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{inplace}}}, {{{view}}}); paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}} }}
""" """
...@@ -1013,6 +1012,7 @@ def OpGenerator( ...@@ -1013,6 +1012,7 @@ def OpGenerator(
kernel_func_str = "" kernel_func_str = ""
kernel_param_str = "" kernel_param_str = ""
kernel_key_dtype = "" kernel_key_dtype = ""
kernel_key_backend = ""
if op_kernel_map is not None: if op_kernel_map is not None:
kernel_func_str = '", "'.join(op_kernel_map['func']) kernel_func_str = '", "'.join(op_kernel_map['func'])
kernel_param_str = '", "'.join(op_kernel_map['param']) kernel_param_str = '", "'.join(op_kernel_map['param'])
...@@ -1022,6 +1022,12 @@ def OpGenerator( ...@@ -1022,6 +1022,12 @@ def OpGenerator(
) )
if kernel_key_dtype != "": if kernel_key_dtype != "":
kernel_key_dtype = '"' + kernel_key_dtype + '"' kernel_key_dtype = '"' + kernel_key_dtype + '"'
if 'backend' in op_kernel_map and op_kernel_map['backend']:
kernel_key_backend = '", "'.join(
op_kernel_map['backend']['candidates']
)
if kernel_key_backend != "":
kernel_key_backend = '"' + kernel_key_backend + '"'
inplace_str = "" inplace_str = ""
view_str = "" view_str = ""
...@@ -1045,6 +1051,7 @@ def OpGenerator( ...@@ -1045,6 +1051,7 @@ def OpGenerator(
kernel_func=kernel_func_str, kernel_func=kernel_func_str,
kernel_param=kernel_param_str, kernel_param=kernel_param_str,
kernel_key_dtype=kernel_key_dtype, kernel_key_dtype=kernel_key_dtype,
kernel_key_backend=kernel_key_backend,
inplace=inplace_str, inplace=inplace_str,
view=view_str, view=view_str,
origin_op_name=op_info.op_yaml_item['name'], origin_op_name=op_info.op_yaml_item['name'],
......
...@@ -39,7 +39,7 @@ OpInfoTuple AddNOp::GetOpInfo() { ...@@ -39,7 +39,7 @@ OpInfoTuple AddNOp::GetOpInfo() {
std::vector<paddle::dialect::OpOutputInfo> outputs = { std::vector<paddle::dialect::OpOutputInfo> outputs = {
OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)}; OpOutputInfo("out", "paddle::dialect::DenseTensorType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}); OpRunTimeInfo("", {""}, {""}, {""}, {""}, {}, {}, {});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n"); return std::make_tuple(inputs, attributes, outputs, run_time_info, "add_n");
} }
......
...@@ -77,6 +77,7 @@ struct OpRunTimeInfo { ...@@ -77,6 +77,7 @@ struct OpRunTimeInfo {
std::vector<std::string> kernel_func; std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param; std::vector<std::string> kernel_param;
std::vector<std::string> kernel_key_dtype; std::vector<std::string> kernel_key_dtype;
std::vector<std::string> kernel_key_backend;
std::vector<std::pair<std::string, std::string>> inplace; std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view; std::vector<std::pair<std::string, std::string>> view;
OpRunTimeInfo(const std::string& infer_meta_func, OpRunTimeInfo(const std::string& infer_meta_func,
...@@ -84,6 +85,7 @@ struct OpRunTimeInfo { ...@@ -84,6 +85,7 @@ struct OpRunTimeInfo {
const std::vector<std::string>& kernel_func, const std::vector<std::string>& kernel_func,
const std::vector<std::string>& kernel_param, const std::vector<std::string>& kernel_param,
const std::vector<std::string>& dtype, const std::vector<std::string>& dtype,
const std::vector<std::string>& backend,
const std::vector<std::pair<std::string, std::string>>& inplace, const std::vector<std::pair<std::string, std::string>>& inplace,
const std::vector<std::pair<std::string, std::string>>& view) const std::vector<std::pair<std::string, std::string>>& view)
: infer_meta_func(infer_meta_func), : infer_meta_func(infer_meta_func),
...@@ -91,6 +93,7 @@ struct OpRunTimeInfo { ...@@ -91,6 +93,7 @@ struct OpRunTimeInfo {
kernel_func(kernel_func), kernel_func(kernel_func),
kernel_param(kernel_param), kernel_param(kernel_param),
kernel_key_dtype(dtype), kernel_key_dtype(dtype),
kernel_key_backend(backend),
inplace(inplace), inplace(inplace),
view(view) {} view(view) {}
}; };
......
...@@ -209,6 +209,150 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in, ...@@ -209,6 +209,150 @@ ir::OpResult AddPlaceTransferOp(ir::OpResult in,
} }
} }
phi::DataType GetKernelDataTypeByYamlInfo(
    const ir::Operation* op,
    const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
    const dialect::OpYamlInfoParser* op_info_parser) {
  // Resolve the kernel's data type from the yaml `data_type` slots. Slots are
  // tried in declaration order; the first one that yields a defined dtype
  // wins. A slot may be a dtype literal, an input name, or an attribute name.
  auto& attr_map = op->attributes();
  auto& dtype_slots = op_info_parser->OpRuntimeInfo().kernel_key_dtype;

  phi::DataType resolved = phi::DataType::UNDEFINED;

  for (const auto& slot : dtype_slots) {
    auto& input_map = op_info_parser->InputName2Id();

    auto literal_it = Str2PhiDataType.find(slot);
    if (literal_it != Str2PhiDataType.end()) {
      // Slot names a dtype literal directly (e.g. "float32").
      resolved = literal_it->second;
    } else if (input_map.count(slot)) {
      // Slot names an input; derive the dtype from the mapped value's type.
      int input_idx = input_map.at(slot);
      auto in_type = map_value_pair.at(op->operand_source(input_idx)).type();
      if (in_type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
        resolved = TransToPhiDataType(
            in_type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
                .dtype());
      } else if (in_type.isa<ir::VectorType>()) {
        auto elems = in_type.dyn_cast<ir::VectorType>().data();
        if (elems.empty()) {
          // An empty vector carries no dtype information.
          resolved = phi::DataType::UNDEFINED;
        } else if (elems[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
          // Use the first element as representative of the whole vector.
          resolved = TransToPhiDataType(
              elems[0]
                  .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
                  .dtype());
        } else {
          PADDLE_THROW(phi::errors::Unimplemented(
              "Only support DenseTensorType in vector"));
        }
      } else if (in_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
        resolved = TransToPhiDataType(
            in_type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
                .dtype());
      } else {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Only support DenseTensorType, SelectedRows, VectorType"));
      }
    } else {
      // Slot must be an attribute carrying a DataType value.
      PADDLE_ENFORCE_EQ(attr_map.count(slot),
                        true,
                        phi::errors::PreconditionNotMet(
                            "[%s] MUST in attribute map", slot));
      auto attr_type = op_info_parser->AttrTypeName(slot);
      PADDLE_ENFORCE_EQ(attr_type,
                        "paddle::dialect::DataTypeAttribute",
                        phi::errors::PreconditionNotMet(
                            "Type of [%s] should be DataType", slot));
      resolved = attr_map.at(slot)
                     .dyn_cast<paddle::dialect::DataTypeAttribute>()
                     .data();
    }

    if (resolved != phi::DataType::UNDEFINED) {
      // In yaml definition, data type have an order
      // like: data_type : dtype > x
      // Should break when found a defined data type
      break;
    }
  }

  return resolved;
}
phi::Backend GetKernelBackendByYamlInfo(
    const ir::Operation* op,
    const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair,
    const dialect::OpYamlInfoParser* op_info_parser) {
  // Resolve the kernel's backend from the yaml `backend` slots. Slots are
  // tried in declaration order; the first one producing a defined backend
  // wins. A slot may name an input (backend parsed from its place) or an
  // attribute holding a Place.
  auto& attr_map = op->attributes();
  auto& backend_info = op_info_parser->OpRuntimeInfo().kernel_key_backend;
  phi::Backend kernel_backend = phi::Backend::UNDEFINED;
  for (size_t i = 0; i < backend_info.size(); ++i) {
    auto slot_name = backend_info[i];
    auto& input_map = op_info_parser->InputName2Id();
    if (input_map.count(slot_name)) {
      // parse from input
      int in_index = input_map.at(slot_name);
      auto type = map_value_pair.at(op->operand_source(in_index)).type();
      if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
        kernel_backend = paddle::experimental::ParseBackend(
            type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>().place());
      } else if (type.isa<ir::VectorType>()) {
        auto vec_data = type.dyn_cast<ir::VectorType>().data();
        if (vec_data.empty()) {
          // An empty vector carries no place information.
          kernel_backend = phi::Backend::UNDEFINED;
        } else {
          // Use the first element as representative of the whole vector.
          if (vec_data[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
            kernel_backend = paddle::experimental::ParseBackend(
                vec_data[0]
                    .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
                    .place());
          } else {
            PADDLE_THROW(phi::errors::Unimplemented(
                "Only support DenseTensorType in vector"));
          }
        }
      } else if (type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
        kernel_backend = paddle::experimental::ParseBackend(
            type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
                .place());
      } else {
        PADDLE_THROW(phi::errors::Unimplemented(
            "Only support DenseTensorType, SelectedRows, VectorType"));
      }
    } else {
      // Slot must be an attribute carrying a Place value.
      PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
                        true,
                        phi::errors::PreconditionNotMet(
                            "[%s] MUST in attribute map", slot_name));
      auto attr_type = op_info_parser->AttrTypeName(slot_name);
      // BUGFIX: the message previously said "should be DataType", copied from
      // GetKernelDataTypeByYamlInfo; this check is for PlaceAttribute.
      PADDLE_ENFORCE_EQ(attr_type,
                        "paddle::dialect::PlaceAttribute",
                        phi::errors::PreconditionNotMet(
                            "Type of [%s] should be Place", slot_name));
      kernel_backend = paddle::experimental::ParseBackend(
          attr_map.at(slot_name)
              .dyn_cast<paddle::dialect::PlaceAttribute>()
              .data());
    }
    if (kernel_backend != phi::Backend::UNDEFINED) {
      // In yaml definition, backend have an order
      // like: backend : place > x
      // Should break when found a defined backend
      break;
    }
  }
  return kernel_backend;
}
phi::KernelKey GetKernelKey( phi::KernelKey GetKernelKey(
ir::Operation* op, ir::Operation* op,
const phi::Place& place, const phi::Place& place,
...@@ -245,66 +389,11 @@ phi::KernelKey GetKernelKey( ...@@ -245,66 +389,11 @@ phi::KernelKey GetKernelKey(
// only suppurt non vector input for now // only suppurt non vector input for now
int tensor_input_number = op_info_parser->InputTensorNumber(); int tensor_input_number = op_info_parser->InputTensorNumber();
auto attr_map = op->attributes(); // get datatype info
auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; kernel_data_type =
GetKernelDataTypeByYamlInfo(op, map_value_pair, op_info_parser);
if (!data_type_info.empty()) { kernel_backend =
// only support single input and attribute GetKernelBackendByYamlInfo(op, map_value_pair, op_info_parser);
auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->InputName2Id();
auto find_it = Str2PhiDataType.find(slot_name);
if (find_it != Str2PhiDataType.end()) {
kernel_data_type = find_it->second;
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);
auto type = map_value_pair.at(op->operand_source(in_index)).type();
if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
type.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.dtype());
} else if (type.isa<ir::VectorType>()) {
auto vec_data = type.dyn_cast<ir::VectorType>().data();
if (vec_data.empty()) {
kernel_data_type = phi::DataType::UNDEFINED;
} else {
if (vec_data[0].isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
vec_data[0]
.dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
.dtype());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType in vector"));
}
}
} else if (type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
kernel_data_type = TransToPhiDataType(
type.dyn_cast<paddle::dialect::AllocatedSelectedRowsType>()
.dtype());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support DenseTensorType, SelectedRows, VectorType"));
}
} else {
PADDLE_ENFORCE_EQ(attr_map.count(slot_name),
true,
phi::errors::PreconditionNotMet(
"[%s] MUST in attribute map", slot_name));
auto attr_type = op_info_parser->AttrTypeName(slot_name);
PADDLE_ENFORCE_EQ(attr_type,
"paddle::dialect::DataTypeAttribute",
phi::errors::PreconditionNotMet(
"Type of [%s] should be DataType", slot_name));
kernel_data_type = attr_map.at(slot_name)
.dyn_cast<paddle::dialect::DataTypeAttribute>()
.data();
}
}
// parse all the input tensor // parse all the input tensor
if (tensor_input_number == 0 || op->name() == "pd.full_") { if (tensor_input_number == 0 || op->name() == "pd.full_") {
......
...@@ -488,6 +488,7 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() { ...@@ -488,6 +488,7 @@ OpInfoTuple Conv2dFusionOpTest::GetOpInfo() {
"user_workspace_size"}, "user_workspace_size"},
{"input"}, {"input"},
{}, {},
{},
{}); {});
return std::make_tuple( return std::make_tuple(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册