未验证 提交 a15fec8b 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Fix BuildOperatorBase bug (#56552)

* fix inplace with set_parameter

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* refine code

* refine code

* fix bug

* refine op_gem

* fix bug
上级 9ad06e06
...@@ -151,7 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ ...@@ -151,7 +151,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }}; std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }}; std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }}; std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{"{kernel_key_dtype}"}}, {{{inplace}}}, {{{view}}}); paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}"); return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}} }}
...@@ -1020,6 +1020,8 @@ def OpGenerator( ...@@ -1020,6 +1020,8 @@ def OpGenerator(
kernel_key_dtype = '", "'.join( kernel_key_dtype = '", "'.join(
op_kernel_map['data_type']['candidates'] op_kernel_map['data_type']['candidates']
) )
if kernel_key_dtype != "":
kernel_key_dtype = '"' + kernel_key_dtype + '"'
inplace_str = "" inplace_str = ""
view_str = "" view_str = ""
......
...@@ -704,6 +704,51 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase( ...@@ -704,6 +704,51 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
attr_map[name] = val.dyn_cast<ir::DoubleAttribute>().data(); attr_map[name] = val.dyn_cast<ir::DoubleAttribute>().data();
} else if (val.isa<ir::Int64Attribute>()) { } else if (val.isa<ir::Int64Attribute>()) {
attr_map[name] = val.dyn_cast<ir::Int64Attribute>().data(); attr_map[name] = val.dyn_cast<ir::Int64Attribute>().data();
} else if (val.isa<ir::ArrayAttribute>()) {
auto array_list = val.dyn_cast<ir::ArrayAttribute>().AsVector();
PADDLE_ENFORCE(
array_list.size() > 0,
paddle::platform::errors::Fatal("Attribute %s is empty", name));
if (array_list[0].isa<ir::Int32Attribute>()) {
std::vector<int> vec_int;
for (auto attribute : array_list) {
vec_int.push_back(attribute.dyn_cast<ir::Int32Attribute>().data());
}
attr_map[name] = vec_int;
} else if (array_list[0].isa<ir::Int64Attribute>()) {
std::vector<int> vec_int64;
for (auto attribute : array_list) {
vec_int64.push_back(attribute.dyn_cast<ir::Int64Attribute>().data());
}
attr_map[name] = vec_int64;
} else if (array_list[0].isa<ir::BoolAttribute>()) {
std::vector<int> vec_bool;
for (auto attribute : array_list) {
vec_bool.push_back(attribute.dyn_cast<ir::BoolAttribute>().data());
}
attr_map[name] = vec_bool;
} else if (array_list[0].isa<ir::FloatAttribute>()) {
std::vector<int> vec_float;
for (auto attribute : array_list) {
vec_float.push_back(attribute.dyn_cast<ir::FloatAttribute>().data());
}
attr_map[name] = vec_float;
} else if (array_list[0].isa<ir::DoubleAttribute>()) {
std::vector<int> vec_double;
for (auto attribute : array_list) {
vec_double.push_back(
attribute.dyn_cast<ir::DoubleAttribute>().data());
}
attr_map[name] = vec_double;
} else {
std::stringstream ss;
val.Print(ss);
VLOG(1) << "type not support " << ss.str() << std::endl;
PADDLE_THROW("Type[%s] in attribute map not support yet", ss.str());
}
} else if (val.isa<paddle::dialect::DataTypeAttribute>()) {
attr_map[name] = paddle::framework::TransToProtoVarType(
val.dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else { } else {
std::stringstream ss; std::stringstream ss;
val.Print(ss); val.Print(ss);
......
...@@ -248,7 +248,7 @@ phi::KernelKey GetKernelKey( ...@@ -248,7 +248,7 @@ phi::KernelKey GetKernelKey(
auto attr_map = op->attributes(); auto attr_map = op->attributes();
auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype;
if (!data_type_info.empty() && !data_type_info[0].empty()) { if (!data_type_info.empty()) {
// only support single input and attribute // only support single input and attribute
auto slot_name = data_type_info[0]; auto slot_name = data_type_info[0];
auto& input_map = op_info_parser->InputName2Id(); auto& input_map = op_info_parser->InputName2Id();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册