未验证 提交 5b97278e 编写于 作者: Z zhangbo9674 提交者: GitHub

add inplace view info to OpYamlInfoInterface (#54551)

上级 a56eba3a
...@@ -110,7 +110,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ ...@@ -110,7 +110,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }}; std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }}; std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }}; std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}); paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{inplace}}}, {{{view}}});
return std::make_tuple(inputs, attributes, outputs, run_time_info); return std::make_tuple(inputs, attributes, outputs, run_time_info);
}} }}
""" """
...@@ -386,6 +386,10 @@ class OpInfoParser: ...@@ -386,6 +386,10 @@ class OpInfoParser:
else: else:
self.infer_shape_func = None self.infer_shape_func = None
# parse inplace && view
self.inplace_map = self.parse_op_inplace_info()
self.view_map = self.parse_op_view_info()
def cross_check(self, name_list, type_list, optional_list=None): def cross_check(self, name_list, type_list, optional_list=None):
assert len(name_list) == len( assert len(name_list) == len(
type_list type_list
...@@ -396,7 +400,9 @@ class OpInfoParser: ...@@ -396,7 +400,9 @@ class OpInfoParser:
), "type list size != optional list size." ), "type list size != optional list size."
def parse_op_phi_name(self): def parse_op_phi_name(self):
if self.parse_op_inplace_info() is None: if (self.parse_op_inplace_info() is None) and (
self.parse_op_view_info() is None
):
return [self.op_yaml_item['name']] return [self.op_yaml_item['name']]
else: else:
if self.op_yaml_item['name'][-1] == "_": if self.op_yaml_item['name'][-1] == "_":
...@@ -412,6 +418,11 @@ class OpInfoParser: ...@@ -412,6 +418,11 @@ class OpInfoParser:
return self.op_yaml_item['inplace'] return self.op_yaml_item['inplace']
return None return None
def parse_op_view_info(self):
if 'view' in self.op_yaml_item:
return self.op_yaml_item['view']
return None
def parse_mutable_attribute(self): def parse_mutable_attribute(self):
""" """
{'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'} {'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'}
...@@ -1256,6 +1267,8 @@ def OpGenerator( ...@@ -1256,6 +1267,8 @@ def OpGenerator(
# others # others
op_infer_meta_map = op_info.infer_meta_map op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map op_kernel_map = op_info.kernel_map
op_inplace_map = op_info.inplace_map
op_view_map = op_info.view_map
op_interfaces = ["OpYamlInfoInterface"] op_interfaces = ["OpYamlInfoInterface"]
op_traits = [] op_traits = []
...@@ -1472,12 +1485,25 @@ def OpGenerator( ...@@ -1472,12 +1485,25 @@ def OpGenerator(
if op_infer_meta_map is not None: if op_infer_meta_map is not None:
infer_meta_func_str = op_infer_meta_map['func'] infer_meta_func_str = op_infer_meta_map['func']
infer_meta_param_str = '", "'.join(op_infer_meta_map['param']) infer_meta_param_str = '", "'.join(op_infer_meta_map['param'])
kernel_func_str = "" kernel_func_str = ""
kernel_param_str = "" kernel_param_str = ""
if op_kernel_map is not None: if op_kernel_map is not None:
kernel_func_str = '", "'.join(op_kernel_map['func']) kernel_func_str = '", "'.join(op_kernel_map['func'])
kernel_param_str = '", "'.join(op_kernel_map['param']) kernel_param_str = '", "'.join(op_kernel_map['param'])
inplace_str = ""
view_str = ""
if op_name[-1] == "_":
if op_inplace_map is not None:
for key, value in op_inplace_map.items():
inplace_str += '{"' + key + '", "' + value + '"},'
inplace_str = inplace_str[:-1]
if op_view_map is not None:
for key, value in op_view_map.items():
view_str += '{"' + key + '", "' + value + '"},'
view_str = view_str[:-1]
op_info_func_str = OP_INFO_TEMPLATE.format( op_info_func_str = OP_INFO_TEMPLATE.format(
op_name=op_class_name, op_name=op_class_name,
inputs=inputs_info_str, inputs=inputs_info_str,
...@@ -1487,6 +1513,8 @@ def OpGenerator( ...@@ -1487,6 +1513,8 @@ def OpGenerator(
infer_meta_param=infer_meta_param_str, infer_meta_param=infer_meta_param_str,
kernel_func=kernel_func_str, kernel_func=kernel_func_str,
kernel_param=kernel_param_str, kernel_param=kernel_param_str,
inplace=inplace_str,
view=view_str,
) )
# =================================== # # =================================== #
......
...@@ -144,14 +144,20 @@ struct OpRunTimeInfo { ...@@ -144,14 +144,20 @@ struct OpRunTimeInfo {
std::vector<std::string> infer_meta_param; std::vector<std::string> infer_meta_param;
std::vector<std::string> kernel_func; std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param; std::vector<std::string> kernel_param;
std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view;
OpRunTimeInfo(std::string infer_meta_func, OpRunTimeInfo(std::string infer_meta_func,
std::vector<std::string> infer_meta_param, std::vector<std::string> infer_meta_param,
std::vector<std::string> kernel_func, std::vector<std::string> kernel_func,
std::vector<std::string> kernel_param) std::vector<std::string> kernel_param,
std::vector<std::pair<std::string, std::string>> inplace,
std::vector<std::pair<std::string, std::string>> view)
: infer_meta_func(infer_meta_func), : infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param), infer_meta_param(infer_meta_param),
kernel_func(kernel_func), kernel_func(kernel_func),
kernel_param(kernel_param) {} kernel_param(kernel_param),
inplace(inplace),
view(view) {}
}; };
} // namespace dialect } // namespace dialect
......
...@@ -259,12 +259,22 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]: ...@@ -259,12 +259,22 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]:
return kernel return kernel
def delete_bracket(name: str):
if name[0] == "(":
name = name.lstrip("(")
if name[-1] == ")":
name = name.rstrip(")")
return name
def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]: def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]:
inplace_map = {} inplace_map = {}
inplace_cfg = inplace_cfg.lstrip("(").rstrip(")") inplace_cfg = inplace_cfg.lstrip("(").rstrip(")")
pairs = parse_plain_list(inplace_cfg) pairs = parse_plain_list(inplace_cfg)
for pair in pairs: for pair in pairs:
in_name, out_name = parse_plain_list(pair, sep="->") in_name, out_name = parse_plain_list(pair, sep="->")
in_name = delete_bracket(in_name)
out_name = delete_bracket(out_name)
inplace_map[out_name] = in_name inplace_map[out_name] = in_name
return inplace_map return inplace_map
...@@ -521,11 +531,17 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): ...@@ -521,11 +531,17 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"):
inplace_pairs = parse_inplace(op_name, op_entry["inplace"]) inplace_pairs = parse_inplace(op_name, op_entry["inplace"])
else: else:
inplace_pairs = None inplace_pairs = None
# view
if "view" in op_entry:
view_pairs = parse_inplace(op_name, op_entry["view"])
else:
view_pairs = None
op.update( op.update(
{ {
"infer_meta": infer_meta, "infer_meta": infer_meta,
"kernel": kernel, "kernel": kernel,
"inplace": inplace_pairs, "inplace": inplace_pairs,
"view": view_pairs,
} }
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册