diff --git a/paddle/fluid/dialect/CMakeLists.txt b/paddle/fluid/dialect/CMakeLists.txt index 24c18e24c23b05ac1dc014e6a4ac6b3fe9a27458..a20f84627e8ca612701d1aa8639c4ee8f7e58996 100644 --- a/paddle/fluid/dialect/CMakeLists.txt +++ b/paddle/fluid/dialect/CMakeLists.txt @@ -8,16 +8,17 @@ set(op_forward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml ) set(op_forward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_ops.parsed.yaml + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml ) set(op_backward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml ) set(op_backward_yaml_file2 - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_backward.parsed.yaml + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) +set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/dialect/pd_op.yaml) set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2} + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} ) set(op_namespace paddle,dialect) set(dialect_name pd) diff --git a/paddle/fluid/dialect/legacy_pd_op.h b/paddle/fluid/dialect/legacy_pd_op.h deleted file mode 100644 index 44d55d57225e91f542de0f460ba05e3a48d1d920..0000000000000000000000000000000000000000 --- a/paddle/fluid/dialect/legacy_pd_op.h +++ /dev/null @@ -1,97 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/ir/core/op_base.h" - -namespace paddle { -namespace dialect { - -#define OPNAME(op_name) "pd." #op_name - -#define REIGSTER_EMPTY_OP(op_name, className) \ - class className : public ir::Op { \ - public: \ - static const char *name() { return OPNAME(op_name); } \ - static constexpr const char **attributes_name = nullptr; \ - static constexpr uint32_t attributes_num = 0; \ - static void verify(const std::vector &inputs, \ - const std::vector &outputs, \ - const ir::AttributeMap &attributes) { \ - LOG(WARNING) << "This is a fake verify"; \ - } \ - }; - -// TODO(zhangbo): As operators are supplemented and defined, they are gradually -// removed. 
-REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d -REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed -REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: batch_norm -REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_ -REIGSTER_EMPTY_OP(elementwise_add, - ElementwiseAddOp); // To be customized: add (elementwise_add) -REIGSTER_EMPTY_OP(pool2d, Pool2DOp); // To be customized: pool2d -REIGSTER_EMPTY_OP( - flatten_contiguous_range, - FlattenContiguousRangeOp); // flatten (flatten_contiguous_range) -REIGSTER_EMPTY_OP(matmul_v2, - MatmulV2Op); // To be customized: matmul (matmul_v2) -REIGSTER_EMPTY_OP(reshape2, Reshape2Op); // To be customized: reshape -REIGSTER_EMPTY_OP(softmax_with_cross_entropy, - SoftmaxWithCrossEntropyOp); // cross_entropy_with_softmax - // (softmax_with_cross_entropy) -REIGSTER_EMPTY_OP(reduce_mean, - ReduceMeanOp); // To be customized: mean (reduce_mean) -REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); // topk (top_k_v2) -REIGSTER_EMPTY_OP(fill_constant, - FillConstantOp); // To be customized: full (fill_constant) -REIGSTER_EMPTY_OP(reduce_mean_grad, - ReduceMeanGradOp); // To be customized: reduce_mean_grad -REIGSTER_EMPTY_OP( - softmax_with_cross_entropy_grad, - SoftmaxWithCrossEntropyGradOp); // cross_entropy_with_softmax_grad - // (softmax_with_cross_entropy_grad) -REIGSTER_EMPTY_OP( - elementwise_add_grad, - ElementwiseAddGradOp); // To be customized: add_grad (elementwise_add_grad) -REIGSTER_EMPTY_OP( - matmul_v2_grad, - MatmulV2GradOp); // To be customized: matmul_grad (matmul_v2_grad) -REIGSTER_EMPTY_OP( - flatten_contiguous_range_grad, - FlattenContiguousRangeGradOp); // flatten_grad - // (flatten_contiguous_range_grad) -REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad -REIGSTER_EMPTY_OP(batch_norm_grad, - BatchNormGradOp); // To be customized: batch_norm_grad -REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad 
-REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum) -REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2 -REIGSTER_EMPTY_OP(add, AddOp); -REIGSTER_EMPTY_OP(add_grad, AddGradOp); -REIGSTER_EMPTY_OP(matmul, MatMulOp); -REIGSTER_EMPTY_OP(matmul_grad, MatMulGradOp); -REIGSTER_EMPTY_OP(reshape, ReshapeOp); -REIGSTER_EMPTY_OP(reshape_grad, ReshapeGradOp); -REIGSTER_EMPTY_OP(mean, MeanOp); -REIGSTER_EMPTY_OP(cross_entropy_with_softmax, CrossEntropyOp); -REIGSTER_EMPTY_OP(cross_entropy_with_softmax_grad, CrossEntropyGradOp); -REIGSTER_EMPTY_OP(topk, TopKOp); -REIGSTER_EMPTY_OP(topk_grad, TopKGradOp); -REIGSTER_EMPTY_OP(full, FullOp); -REIGSTER_EMPTY_OP(add_n, AddNOp); - -} // namespace dialect -} // namespace paddle diff --git a/paddle/fluid/dialect/op_gen.py b/paddle/fluid/dialect/op_gen.py index 163f7149e11596a0b6fcb142623734ceb921dc6a..acb3b2721278f3ef1a3602f4141ff07e7d502d16 100644 --- a/paddle/fluid/dialect/op_gen.py +++ b/paddle/fluid/dialect/op_gen.py @@ -29,7 +29,11 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST {op_declare} #else +#include + #include "paddle/ir/core/op_base.h" +#include "paddle/fluid/dialect/utils.h" +#include "paddle/fluid/dialect/pd_interface.h" {input} #endif @@ -45,6 +49,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static const char *name() {{ return "{dialect_op_name}"; }} {attribute_declare} static constexpr uint32_t attributes_num = {attribute_num}; + static OpInfoTuple GetOpInfo(); static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); {get_inputs_and_outputs} }}; @@ -79,6 +84,46 @@ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """ const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }}; """ +# get op input info +OP_INFO_TEMPLATE = """ +OpInfoTuple {op_name}::GetOpInfo() {{ + std::vector inputs = {{ {inputs} }}; + std::vector attributes = {{ {attributes} }}; + std::vector outputs = {{ {outputs} }}; + return 
std::make_tuple(inputs, attributes, outputs); +}} +""" + +OP_INPUT_INFO_TEMPLATE = """ +std::vector {op_name}::inputs_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_INPUT_INFO_TEMPLATE = ( + """OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})""" +) + +# get op output info +OP_OUTPUT_INFO_TEMPLATE = """ +std::vector {op_name}::outputs_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_OUTPUT_INFO_TEMPLATE = ( + """OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})""" +) + +# get op attribute info +OP_ATTRIBUTE_INFO_TEMPLATE = """ +std::vector {op_name}::attributes_info() {{ + return {{ {impl} }}; +}} +""" +CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = ( + """OpAttributeInfo("{name}", "{typename}", "{data_type}")""" +) + +# verify OP_VERIFY_TEMPLATE = """ void {op_name}::verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}."; @@ -158,10 +203,14 @@ OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{ }} """ -ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true, +ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true, + phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}.")); + PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true, phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); """ -ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa(), true, +ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true, + phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}.")); + PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa(), true, phi::errors::PreconditionNotMet("Type of attribute: 
{attribute_name} is not right.")); for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast()[i].isa<{standard}>(), true, @@ -170,32 +219,65 @@ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.at("{attribute """ +def to_phi_and_fluid_op_name(op_item): + # Template: - op : phi_name (fluid_name) + names = op_item.split('(') + if len(names) == 1: + phi_fluid_name = names[0].strip() + return phi_fluid_name, phi_fluid_name + else: + phi_name = names[0].strip() + fluid_name = names[1].split(')')[0].strip() + return phi_name, fluid_name + + # ===================================== -# Parse Op information from Yaml item +# Parse Op Compat From Yaml +# ===================================== +class OpCompatParser: + def __init__(self, ops_compat_yaml_file): + self.ops_compat_yaml_file = ops_compat_yaml_file + with open(self.ops_compat_yaml_file, "r") as f: + self.ops_compat = yaml.safe_load(f) + + def get_compat(self, op_name): + for compat in self.ops_compat: + phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op']) + if op_name == phi_name: + return compat + return None + + +# ===================================== +# Parse Op Information From Yaml # ===================================== class OpInfoParser: - def __init__(self, op_yaml_item): + def __init__(self, op_yaml_item, op_compat_item): self.op_yaml_item = op_yaml_item + self.op_compat_item = op_compat_item self.op_phi_name = self.parse_op_phi_name() - + # parse inputs self.input_name_list = self.parse_input_name_list() self.input_type_list = self.parse_input_type_list() self.input_optional_list = self.parse_input_optional_list() + self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list() self.cross_check( self.input_name_list, self.input_type_list, self.input_optional_list ) - + # parse outputs self.output_name_list = self.parse_output_name_list() self.output_type_list = 
self.parse_output_type_list() self.output_optional_list = self.parse_output_optional_list() + self.output_intermediate_list = self.parse_output_intermediate_list() self.cross_check( self.output_name_list, self.output_type_list, self.output_optional_list, ) - + # parse attributes self.attribute_name_list = self.parse_attribute_name_list() self.attribute_type_list = self.parse_attribute_type_list() + self.attribute_data_type_list = self.parse_attribute_data_type_list() self.cross_check(self.attribute_name_list, self.attribute_type_list) def cross_check(self, name_list, type_list, optional_list=None): @@ -229,9 +311,21 @@ class OpInfoParser: def parse_input_optional_list(self): optional_list = [] for input_info in self.op_yaml_item['inputs']: - optional_list.append(input_info['optional']) + if input_info['optional']: + optional_list.append("true") + else: + optional_list.append("false") return optional_list + def parse_input_no_need_buffer_list(self): + no_need_buffer_list = [] + for input_info in self.op_yaml_item['inputs']: + if input_info['no_need_buffer']: + no_need_buffer_list.append("true") + else: + no_need_buffer_list.append("false") + return no_need_buffer_list + def parse_output_name_list(self): name_list = [] for output_info in self.op_yaml_item['outputs']: @@ -255,11 +349,26 @@ class OpInfoParser: optional_list = [] for output_info in self.op_yaml_item['outputs']: if 'optional' in output_info: - optional_list.append(output_info['optional']) + if output_info['optional']: + optional_list.append("true") + else: + optional_list.append("false") else: - optional_list.append(False) + optional_list.append("false") return optional_list + def parse_output_intermediate_list(self): + intermediate_list = [] + for output_info in self.op_yaml_item['outputs']: + if 'intermediate' in output_info: + if output_info['intermediate']: + intermediate_list.append("true") + else: + intermediate_list.append("false") + else: + intermediate_list.append("false") + return 
intermediate_list + def parse_attribute_name_list(self): name_list = [] for attribute_info in self.op_yaml_item['attrs']: @@ -301,8 +410,31 @@ class OpInfoParser: type_list.append(attr_types_map[attribute_info['typename']]) return type_list + def parse_attribute_data_type_list(self): + data_type_list = [] + for attribute_info in self.op_yaml_item['attrs']: + if 'data_type' in attribute_info: + data_type_list.append(attribute_info['data_type']) + else: + data_type_list.append("") + return data_type_list + def parse_op_phi_name(self): - return self.op_yaml_item['name'] + if self.parse_op_inplace_info() is None: + return [self.op_yaml_item['name']] + else: + if self.op_yaml_item['name'][-1] == "_": + return [self.op_yaml_item['name']] + else: + return [ + self.op_yaml_item['name'], + self.op_yaml_item['name'] + "_", + ] + + def parse_op_inplace_info(self): + if 'inplace' in self.op_yaml_item: + return self.op_yaml_item['inplace'] + return None def to_pascal_case(s): @@ -314,10 +446,11 @@ def to_pascal_case(s): # ===================================== -# Generate op definition files +# Generate Op Definition Files # ===================================== def OpGenerator( op_yaml_files, + op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, @@ -330,6 +463,8 @@ def OpGenerator( os.remove(op_def_cc_file) # (2) Prepare: Get all op item in all op_yaml_files + op_compat_parser = OpCompatParser(op_compat_yaml_file) + op_yaml_items = [] for yaml_file in op_yaml_files: with open(yaml_file, "r") as f: @@ -337,7 +472,9 @@ def OpGenerator( op_yaml_items = op_yaml_items + ops op_info_items = [] for op in op_yaml_items: - op_info_items.append(OpInfoParser(op)) + op_info_items.append( + OpInfoParser(op, op_compat_parser.get_compat(op['name'])) + ) # (3) CodeGen: Traverse op_info_items and generate ops_name_list = [] # all op class name store in this list @@ -345,177 +482,241 @@ def OpGenerator( ops_defined_list = [] # all op class defined store in this list for op_info in 
op_info_items: # get op info - op_name = op_info.op_phi_name - op_class_name = to_pascal_case(op_name) + "Op" - op_dialect_name = dialect_name + "." + op_name op_input_name_list = op_info.input_name_list op_input_type_list = op_info.input_type_list op_input_optional_list = op_info.input_optional_list + op_input_no_need_buffer_list = op_info.input_no_need_buffer_list op_output_name_list = op_info.output_name_list op_output_type_list = op_info.output_type_list op_output_optional_list = op_info.output_optional_list + op_output_intermediate_list = op_info.output_intermediate_list op_attribute_name_list = op_info.attribute_name_list op_attribute_type_list = op_info.attribute_type_list - op_interfaces = [] + op_attribute_data_type_list = op_info.attribute_data_type_list + op_interfaces = ["GetOpInfoInterface"] op_traits = [] - # gen interface/trait str - op_interfaces_str = "" - if len(op_interfaces) > 0: - op_interfaces_str = "," + ",".join(op_interfaces) - op_traits_str = "" - if len(op_interfaces) > 0: - op_traits_str = "," + ",".join(op_traits) - - op_get_inputs_outputs_str = "" - for idx in range(len(op_input_name_list)): - op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name=op_input_name_list[idx], input_index=idx - ) - for idx in range(len(op_output_name_list)): - op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( - output_name=op_output_name_list[idx], output_index=idx - ) + # If op has inplace info, we will generate inplace op and non-inplace op. + for op_name in op_info.op_phi_name: + op_class_name = to_pascal_case(op_name) + "Op" + op_dialect_name = dialect_name + "." 
+ op_name + + # gen interface/trait str + op_interfaces_str = "" + if len(op_interfaces) > 0: + op_interfaces_str = "," + ",".join(op_interfaces) + op_traits_str = "" + if len(op_traits) > 0: + op_traits_str = "," + ",".join(op_traits) + + op_get_inputs_outputs_str = "" + for idx in range(len(op_input_name_list)): + op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( + input_name=op_input_name_list[idx], + input_index=idx, + ) + for idx in range(len(op_output_name_list)): + op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( + output_name=op_output_name_list[idx], + output_index=idx, + ) - # gen op_declare_str/op_defined_str - if len(op_attribute_name_list) == 0: - op_declare_str = OP_DECLARE_TEMPLATE.format( - op_name=op_class_name, - dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, - traits=op_traits_str, - attribute_declare=op_0_attribute_declare_str, - attribute_num=0, - get_inputs_and_outputs=op_get_inputs_outputs_str, - ) - op_defined_str = "" - else: - op_declare_str = OP_DECLARE_TEMPLATE.format( - op_name=op_class_name, - dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, - traits=op_traits_str, - attribute_declare=op_n_attribute_declare_str.format( - attribute_num=len(op_attribute_name_list) - ), - attribute_num=len(op_attribute_name_list), - get_inputs_and_outputs=op_get_inputs_outputs_str, - ) - attribute_names_str = ( - '"' + '", "'.join(op_attribute_name_list) + '"' - ) - op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( - op_name=op_class_name, - attribute_num=len(op_attribute_name_list), - attribute_names=attribute_names_str, - ) + # gen op_declare_str/op_defined_str + if len(op_attribute_name_list) == 0: + op_declare_str = OP_DECLARE_TEMPLATE.format( + op_name=op_class_name, + dialect_op_name=op_dialect_name, + interfaces=op_interfaces_str, + traits=op_traits_str, + attribute_declare=op_0_attribute_declare_str, + attribute_num=0, + get_inputs_and_outputs=op_get_inputs_outputs_str, + ) + 
 op_defined_str = "" + else: + op_declare_str = OP_DECLARE_TEMPLATE.format( + op_name=op_class_name, + dialect_op_name=op_dialect_name, + interfaces=op_interfaces_str, + traits=op_traits_str, + attribute_declare=op_n_attribute_declare_str.format( + attribute_num=len(op_attribute_name_list) + ), + attribute_num=len(op_attribute_name_list), + get_inputs_and_outputs=op_get_inputs_outputs_str, + ) + attribute_names_str = ( + '"' + '", "'.join(op_attribute_name_list) + '"' + ) + op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format( + op_name=op_class_name, + attribute_num=len(op_attribute_name_list), + attribute_names=attribute_names_str, + ) - # generate op verify function: inputs_type_check_str - if len(op_input_type_list) == 0: - inputs_type_check_str = ( - "// Inputs num is 0, not need to check inputs type." - ) - else: - inputs_type_check_str = "" - for idx in range(len(op_input_type_list)): - input_type = op_input_type_list[idx] - is_optional = op_input_optional_list[idx] - is_vector = False - if input_type.startswith("ir::VectorType<"): - is_vector = True - input_type = input_type[15:-1] - check_str = "" - if is_optional: - if is_vector: - check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - else: - check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + # generate get op info function: inputs + inputs_info_str = "" + if len(op_input_name_list) > 0: + input_info_list = [] + for idx in range(len(op_input_name_list)): + input_info_list.append( + CONSTRUCT_INPUT_INFO_TEMPLATE.format( + name=op_input_name_list[idx], + typename=op_input_type_list[idx], + optional=op_input_optional_list[idx], + no_need_buffer=op_input_no_need_buffer_list[idx], + ) ) - else: - if is_vector: - check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + inputs_info_str = ", ".join(input_info_list) + + # generate get op info function: outputs + outputs_info_str = "" + if 
len(op_output_name_list) > 0: + output_info_list = [] + for idx in range(len(op_output_name_list)): + output_info_list.append( + CONSTRUCT_OUTPUT_INFO_TEMPLATE.format( + name=op_output_name_list[idx], + typename=op_output_type_list[idx], + optional=op_output_optional_list[idx], + intermediate=op_output_intermediate_list[idx], + ) ) - else: - check_str = INPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type + outputs_info_str = ", ".join(output_info_list) + + # generate get op info function: attributes + attribute_info_str = "" + if len(op_attribute_name_list) > 0: + attribute_info_list = [] + for idx in range(len(op_attribute_name_list)): + attribute_info_list.append( + CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( + name=op_attribute_name_list[idx], + typename=op_attribute_type_list[idx], + data_type=op_attribute_data_type_list[idx], + ) ) - inputs_type_check_str += check_str + attribute_info_str = ", ".join(attribute_info_list) - # generate op verify function: outputs_type_check_str - if len(op_output_type_list) == 0: - outputs_type_check_str = ( - "// Outputs num is 0, not need to check outputs type." + op_info_func_str = OP_INFO_TEMPLATE.format( + op_name=op_class_name, + inputs=inputs_info_str, + attributes=attribute_info_str, + outputs=outputs_info_str, ) - else: - outputs_type_check_str = "" - for idx in range(len(op_output_type_list)): - output_type = op_output_type_list[idx] - is_optional = op_output_optional_list[idx] - is_vector = False - if output_type.startswith("ir::VectorType<"): - is_vector = True - output_type = output_type[15:-1] - check_str = "" - if is_optional: - if is_vector: - check_str = ( - OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + + # generate op verify function: inputs_type_check_str + if len(op_input_type_list) == 0: + inputs_type_check_str = ( + "// Inputs num is 0, not need to check inputs type." 
+ ) + else: + inputs_type_check_str = "" + for idx in range(len(op_input_type_list)): + input_type = op_input_type_list[idx] + is_optional = op_input_optional_list[idx] + is_vector = False + if input_type.startswith("ir::VectorType<"): + is_vector = True + input_type = input_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = ( + INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + ) + else: + check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + if is_vector: + check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + check_str = INPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + inputs_type_check_str += check_str + + # generate op verify function: outputs_type_check_str + if len(op_output_type_list) == 0: + outputs_type_check_str = ( + "// Outputs num is 0, not need to check outputs type." + ) + else: + outputs_type_check_str = "" + for idx in range(len(op_output_type_list)): + output_type = op_output_type_list[idx] + is_optional = op_output_optional_list[idx] + is_vector = False + if output_type.startswith("ir::VectorType<"): + is_vector = True + output_type = output_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = ( + OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + ) + else: + check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( index=idx, standard=output_type ) - ) else: - check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) + if is_vector: + check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + else: + check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + outputs_type_check_str += check_str + + # generate op verify function: attributes_check_str + if len(op_attribute_name_list) == 0: + 
attributes_check_str = ( + "// Attributes num is 0, not need to check attributes type." + ) else: - if is_vector: - check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type + attributes_check_str = "" + for idx in range(len(op_attribute_name_list)): + attribute_name = op_attribute_name_list[idx] + attribute_type = op_attribute_type_list[idx] + if attribute_type.startswith("ir::ArrayAttribute<"): + attribute_type = attribute_type[19:-1] + attributes_check_str += ( + ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( + attribute_name=attribute_name, + standard=attribute_type, + ) ) else: - check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type + attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( + attribute_name=attribute_name, standard=attribute_type ) - outputs_type_check_str += check_str - # generate op verify function: attributes_check_str - if len(op_attribute_name_list) == 0: - attributes_check_str = ( - "// Attributes num is 0, not need to check attributes type." 
+ # generate op verify function + op_verify_str = OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list), + outputs_size=len(op_output_type_list), + inputs_type_check=inputs_type_check_str, + outputs_type_check=outputs_type_check_str, + attributes_check=attributes_check_str, ) - else: - attributes_check_str = "" - for idx in range(len(op_attribute_name_list)): - attribute_name = op_attribute_name_list[idx] - attribute_type = op_attribute_type_list[idx] - if attribute_type.startswith("ir::ArrayAttribute<"): - attribute_type = attribute_type[19:-1] - attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - else: - attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - - # generate op verify function - op_verify_str = OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - inputs_size=len(op_input_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) - ops_name_list.append(op_class_name) - ops_declare_list.append(op_declare_str) - ops_defined_list.append(op_defined_str) - ops_defined_list.append(op_verify_str) + ops_name_list.append(op_class_name) + ops_declare_list.append(op_declare_str) + ops_defined_list.append(op_defined_str) + ops_defined_list.append(op_info_func_str) + ops_defined_list.append(op_verify_str) # (4) Generate head file str op_namespaces_prev = "" @@ -588,6 +789,7 @@ if __name__ == "__main__": # auto code generate OpGenerator( op_yaml_files, + op_compat_yaml_file, namespaces, dialect_name, op_def_h_file, diff --git a/paddle/fluid/dialect/pd_dialect.cc b/paddle/fluid/dialect/pd_dialect.cc index c5d648d3838742c3171ddfa2f9ab042b46b9ea70..e9802e790e6813c9bba4e759b7960d9d5d8d9922 100644 --- a/paddle/fluid/dialect/pd_dialect.cc +++ 
b/paddle/fluid/dialect/pd_dialect.cc @@ -16,7 +16,6 @@ #include "paddle/fluid/dialect/pd_attribute.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/dialect/CMakeLists.txt. -#include "paddle/fluid/dialect/legacy_pd_op.h" #include "paddle/fluid/dialect/pd_op.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/pd_type_storage.h" @@ -111,42 +110,6 @@ void PaddleDialect::initialize() { >(); RegisterInterfaces(); - RegisterOps(); } void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { diff --git a/paddle/fluid/dialect/pd_interface.h b/paddle/fluid/dialect/pd_interface.h new file mode 100644 index 0000000000000000000000000000000000000000..45c32bc837023d17c4e29db6d14ae42131f978e4 --- /dev/null +++ b/paddle/fluid/dialect/pd_interface.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/fluid/dialect/utils.h" +#include "paddle/ir/core/op_base.h" + +using OpInfoTuple = std::tuple, + std::vector, + std::vector>; + +namespace paddle { +namespace dialect { +class GetOpInfoInterface : public ir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *)) + : get_op_info_(get_op_info) {} + OpInfoTuple (*get_op_info_)(ir::Operation *); + }; + + template + struct Model : public Concept { + static OpInfoTuple GetOpInfo(ir::Operation *op) { + ConcreteOp concret_op = op->dyn_cast(); + if (concret_op == nullptr) throw("concret_op is nullptr"); + return concret_op.GetOpInfo(); + } + + Model() : Concept(GetOpInfo) {} + }; + + GetOpInfoInterface(ir::Operation *op, Concept *impl) + : ir::OpInterfaceBase(op), impl_(impl) {} + + OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); } + + private: + Concept *impl_; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/dialect/pd_op.yaml b/paddle/fluid/dialect/pd_op.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7ca8646a038c6c37075b7806e65b0456f8b1cd1b --- /dev/null +++ b/paddle/fluid/dialect/pd_op.yaml @@ -0,0 +1,52 @@ +- name: feed + inputs: + - typename: Tensor[] + name: x + optional: false + no_need_buffer: false + data_transform: {} + attrs: + - {typename: int, name: col} + outputs: + - {typename: Tensor, name: out, optional: false, intermediate: false} + no_need_buffer: null + data_transform: null + infer_meta: + func: null + param: null + kernel: + func: null + param: null + backend: null + layout: null + data_type: null + dispatch: null + force_backend: null + inplace: null + backward: null +- name: fetch + inputs: + - typename: Tensor + name: x + optional: false + no_need_buffer: false + data_transform: {} + attrs: + - {typename: int, name: col} + outputs: + - {typename: 'Tensor[]', name: out, optional: false, intermediate: false} + no_need_buffer: null + 
data_transform: null + infer_meta: + func: null + param: null + kernel: + func: null + param: null + backend: null + layout: null + data_type: null + dispatch: null + force_backend: null + inplace: null + backward: null diff --git a/paddle/fluid/dialect/utils.h b/paddle/fluid/dialect/utils.h index 56c6db01a8250cfb494df7c747caca7946520bac..bfc28f06267635af68b211c1c8367587c21b08b0 100644 --- a/paddle/fluid/dialect/utils.h +++ b/paddle/fluid/dialect/utils.h @@ -132,5 +132,45 @@ inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout( } } +struct OpInputInfo { + std::string name; + std::string type_name; + bool optional = false; + bool no_need_buffer = false; + OpInputInfo(std::string name, + std::string type_name, + bool optional, + bool no_need_buffer) + : name(name), + type_name(type_name), + optional(optional), + no_need_buffer(no_need_buffer) {} +}; + +struct OpOutputInfo { + std::string name; + std::string type_name; + bool optional = false; + bool intermediate = false; + OpOutputInfo(std::string name, + std::string type_name, + bool optional, + bool intermediate) + : name(name), + type_name(type_name), + optional(optional), + intermediate(intermediate) {} +}; + +struct OpAttributeInfo { + std::string name; + std::string type_name; + std::string data_type; + OpAttributeInfo(std::string name, + std::string type_name, + std::string data_type) + : name(name), type_name(type_name), data_type(data_type) {} +}; + } // namespace dialect } // namespace paddle diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 4dba3fdd74ec800a3c555e4117a4e475771e2279..e4ab20dc5ee8c34b42d12c71ccf1f9c2f5a813ce 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -296,7 +296,7 @@ - op : einsum args : (Tensor[] x, str equation) - output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()} + output : Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()} infer_meta : func : EinsumRawInferMeta 
param : [x, equation] diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index af83acd5bfa87f79702f1d69135175a393bae09d..65796d827e0e9c987e3b1ca2b3ad609a267cd4eb 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -50,10 +50,7 @@ class InferShapeInterface : public ir::OpInterfaceBase { concret_op.InferShape(); } - Model() : Concept(InferShape) { - static_assert(sizeof(Model) == sizeof(Concept), - "sizeof(Model) != sizeof(Concept)"); - } + Model() : Concept(InferShape) {} }; InferShapeInterface(ir::Operation *op, Concept *impl) diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index c16858ba1a219f225ab4b9a6922fa327cf6e2d6c..8b78b739a659433283bc0a7ddc1fe7a76c8557b9 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -15,6 +15,7 @@ #include #include "paddle/fluid/dialect/pd_dialect.h" +#include "paddle/fluid/dialect/pd_interface.h" #include "paddle/fluid/dialect/pd_type.h" #include "paddle/fluid/dialect/utils.h" #include "paddle/ir/core/builtin_attribute.h" @@ -177,7 +178,21 @@ TEST(program_test, program) { EXPECT_EQ(*(dst_tensor->data() + i), data_a[i] + data_b[i]); } - // (7) Def SetParameterOp(c, "c") + // (7) Def AbsOp(b) + ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs"); + std::vector operands = {op1->GetResultByIndex(0)}; + std::unordered_map abs_op_attribute; + std::vector output_types = {dense_tensor_dtype}; + ir::OperationArgument abs_argument(abs_info); + abs_argument.addOperands(operands.begin(), operands.end()); + abs_argument.addAttributes(abs_op_attribute.begin(), abs_op_attribute.end()); + abs_argument.addTypes(output_types.begin(), output_types.end()); + ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument)); + paddle::dialect::GetOpInfoInterface interface = + abs_op->dyn_cast(); + EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true); + + // (8) Def SetParameterOp(c, "c") 
std::string op4_name = builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name()); ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);