diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index cc4ca4a2cd790b240696b750f7b3dce73e03a406..14cffd9453645c595652f5e5c071751f4aeeb482 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -110,7 +110,7 @@ OpInfoTuple {op_name}::GetOpInfo() {{ std::vector inputs = {{ {inputs} }}; std::vector attributes = {{ {attributes} }}; std::vector outputs = {{ {outputs} }}; - paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}); + paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}}, {{{inplace}}}, {{{view}}}); return std::make_tuple(inputs, attributes, outputs, run_time_info); }} """ @@ -386,6 +386,10 @@ class OpInfoParser: else: self.infer_shape_func = None + # parse inplace && view + self.inplace_map = self.parse_op_inplace_info() + self.view_map = self.parse_op_view_info() + def cross_check(self, name_list, type_list, optional_list=None): assert len(name_list) == len( type_list @@ -396,7 +400,9 @@ class OpInfoParser: ), "type list size != optional list size." def parse_op_phi_name(self): - if self.parse_op_inplace_info() is None: + if (self.parse_op_inplace_info() is None) and ( + self.parse_op_view_info() is None + ): return [self.op_yaml_item['name']] else: if self.op_yaml_item['name'][-1] == "_": @@ -412,6 +418,11 @@ class OpInfoParser: return self.op_yaml_item['inplace'] return None + def parse_op_view_info(self): + if 'view' in self.op_yaml_item: + return self.op_yaml_item['view'] + return None + def parse_mutable_attribute(self): """ {'axis': 'paddle::dialect::ScalarAttribute', 'rotl': 'paddle::dialect::IntArrayAttribute'} @@ -1256,6 +1267,8 @@ def OpGenerator( # others op_infer_meta_map = op_info.infer_meta_map op_kernel_map = op_info.kernel_map + op_inplace_map = op_info.inplace_map + op_view_map = op_info.view_map op_interfaces = ["OpYamlInfoInterface"] op_traits = [] @@ -1472,12 +1485,25 @@ def OpGenerator( if op_infer_meta_map is not None: infer_meta_func_str = op_infer_meta_map['func'] infer_meta_param_str = '", "'.join(op_infer_meta_map['param']) + kernel_func_str = "" kernel_param_str = "" if op_kernel_map is not None: kernel_func_str = '", "'.join(op_kernel_map['func']) kernel_param_str = '", "'.join(op_kernel_map['param']) + inplace_str = "" + view_str = "" + if op_name[-1] == "_": + if op_inplace_map is not None: + for key, value in op_inplace_map.items(): + inplace_str += '{"' + key + '", "' + value + '"},' + inplace_str = inplace_str[:-1] + if op_view_map is not None: + for key, value in op_view_map.items(): + view_str += '{"' + key + '", "' + value + '"},' + view_str = view_str[:-1] + op_info_func_str = OP_INFO_TEMPLATE.format( op_name=op_class_name, inputs=inputs_info_str, @@ -1487,6 +1513,8 @@ def OpGenerator( infer_meta_param=infer_meta_param_str, kernel_func=kernel_func_str, kernel_param=kernel_param_str, + inplace=inplace_str, + view=view_str, ) # =================================== # diff --git a/paddle/fluid/ir/dialect/utils.h b/paddle/fluid/ir/dialect/utils.h index 46724c0c77304920c9bdd1ce84f82b5379606b2f..5f08ed28d2e8b2cd77373bfbf776d55ac756adf5 100644 --- a/paddle/fluid/ir/dialect/utils.h +++ b/paddle/fluid/ir/dialect/utils.h @@ -144,14 +144,20 @@ struct OpRunTimeInfo { std::vector infer_meta_param; std::vector kernel_func; std::vector kernel_param; + std::vector> inplace; + std::vector> view; OpRunTimeInfo(std::string infer_meta_func, std::vector infer_meta_param, std::vector kernel_func, - std::vector kernel_param) + std::vector kernel_param, + std::vector> inplace, + std::vector> view) : infer_meta_func(infer_meta_func), infer_meta_param(infer_meta_param), kernel_func(kernel_func), - kernel_param(kernel_param) {} + kernel_param(kernel_param), + inplace(inplace), + view(view) {} }; } // namespace dialect diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index adb3c3ab7c4904a95bf18b76ad56ce2fce225ef8..7e09706d21a3b4527c19ea7641a9298269876928 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -259,12 +259,22 @@ def parse_kernel(op_name: str, kernel_config: Dict[str, Any]) -> Dict[str, Any]: return kernel +def delete_bracket(name: str): + if name[0] == "(": + name = name.lstrip("(") + if name[-1] == ")": + name = name.rstrip(")") + return name + + def parse_inplace(op_name: str, inplace_cfg: str) -> Dict[str, str]: inplace_map = {} inplace_cfg = inplace_cfg.lstrip("(").rstrip(")") pairs = parse_plain_list(inplace_cfg) for pair in pairs: in_name, out_name = parse_plain_list(pair, sep="->") + in_name = delete_bracket(in_name) + out_name = delete_bracket(out_name) inplace_map[out_name] = in_name return inplace_map @@ -521,11 +531,17 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): inplace_pairs = parse_inplace(op_name, op_entry["inplace"]) else: inplace_pairs = None + # view + if "view" in op_entry: + view_pairs = parse_inplace(op_name, op_entry["view"]) + else: + view_pairs = None op.update( { "infer_meta": infer_meta, "kernel": kernel, "inplace": inplace_pairs, + "view": view_pairs, } )