未验证 提交 5b97278e 编写于 作者: Z zhangbo9674 提交者: GitHub

add inplace view info to OpYamlInfoInterface (#54551)

上级 a56eba3a
......@@ -110,7 +110,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}});
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info);
}}
"""
......@@ -386,6 +386,10 @@ class OpInfoParser:
else:
self.infer_shape_func = None
# parse inplace && view
self.inplace_map = self.parse_op_inplace_info()
self.view_map = self.parse_op_view_info()
def cross_check(self, name_list, type_list, optional_list=None):
assert len(name_list) == len(
type_list
......@@ -396,7 +400,9 @@ class OpInfoParser:
), "type list size != optional list size."
def parse_op_phi_name(self):
if self.parse_op_inplace_info() is None:
if (self.parse_op_inplace_info() is None) and (
self.parse_op_view_info() is None
):
return [self.op_yaml_item['name']]
else:
if self.op_yaml_item['name'][-1] == "_":
......@@ -412,6 +418,11 @@ class OpInfoParser:
return self.op_yaml_item['inplace']
return None
def parse_op_view_info(self):
if 'view' in self.op_yaml_item:
return self.op_yaml_item['view']
return None
def parse_mutable_attribute(self):
"""
{'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'}
......@@ -1256,6 +1267,8 @@ def OpGenerator(
# others
op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map
op_inplace_map = op_info.inplace_map
op_view_map = op_info.view_map
op_interfaces = ["OpYamlInfoInterface"]
op_traits = []
......@@ -1472,12 +1485,25 @@ def OpGenerator(
if op_infer_meta_map is not None:
infer_meta_func_str = op_infer_meta_map['func']
infer_meta_param_str = '", "'.join(op_infer_meta_map['param'])
kernel_func_str = ""
kernel_param_str = ""
if op_kernel_map is not None:
kernel_func_str = '", "'.join(op_kernel_map['func'])
kernel_param_str = '", "'.join(op_kernel_map['param'])
inplace_str = ""
view_str = ""
if op_name[-1] == "_":
if op_inplace_map is not None:
for key, value in op_inplace_map.items():
inplace_str += '{"' + key + '", "' + value + '"},'
inplace_str = inplace_str[:-1]
if op_view_map is not None:
for key, value in op_view_map.items():
view_str += '{"' + key + '", "' + value + '"},'
view_str = view_str[:-1]
op_info_func_str = OP_INFO_TEMPLATE.format(
op_name=op_class_name,
inputs=inputs_info_str,
......@@ -1487,6 +1513,8 @@ def OpGenerator(
infer_meta_param=infer_meta_param_str,
kernel_func=kernel_func_str,
kernel_param=kernel_param_str,
inplace=inplace_str,
view=view_str,
)
# =================================== #
......
......@@ -144,14 +144,20 @@ struct OpRunTimeInfo {
std::vector<std::string> infer_meta_param;
std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param;
std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view;
OpRunTimeInfo(std::string infer_meta_func,
std::vector<std::string> infer_meta_param,
std::vector<std::string> kernel_func,
std::vector<std::string> kernel_param)
std::vector<std::string> kernel_param,
std::vector<std::pair<std::string, std::string>> inplace,
std::vector<std::pair<std::string, std::string>> view)
: infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param),
kernel_func(kernel_func),
kernel_param(kernel_param) {}
kernel_param(kernel_param),
inplace(inplace),
view(view) {}
};
} // namespace dialect
......
......@@ -259,12 +259,22 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
return kernel
def delete_bracket(name: str):
if name[0] == "(":
name = name.lstrip("(")
if name[-1] == ")":
name = name.rstrip(")")
return name
def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
inplace_map = {}
inplace_cfg = inplace_cfg.lstrip("(").rstrip(")")
pairs = parse_plain_list(inplace_cfg)
for pair in pairs:
in_name, out_name = parse_plain_list(pair, sep="->")
in_name = delete_bracket(in_name)
out_name = delete_bracket(out_name)
inplace_map[out_name] = in_name
return inplace_map
......@@ -521,11 +531,17 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
inplace_pairs = parse_inplace(op_name, op_entry["inplace"])
else:
inplace_pairs = None
# view
if "view" in op_entry:
view_pairs = parse_inplace(op_name, op_entry["view"])
else:
view_pairs = None
op.update(
{
"infer_meta": infer_meta,
"kernel": kernel,
"inplace": inplace_pairs,
"view": view_pairs,
}
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册