From 96652265aaf0cf2e5320aea4738b5daaa75ada5e Mon Sep 17 00:00:00 2001 From: winter-wang <78149749+winter-wang@users.noreply.github.com> Date: Tue, 27 Jun 2023 20:53:38 +0800 Subject: [PATCH] [IR] rectify the verify api (#54895) --- paddle/fluid/ir/dialect/kernel_op.cc | 4 +- paddle/fluid/ir/dialect/kernel_op.h | 4 +- paddle/fluid/ir/dialect/op_gen.py | 247 +--------------- paddle/fluid/ir/dialect/op_verify_gen.py | 275 ++++++++++++++++++ paddle/ir/core/builtin_op.cc | 115 ++++---- paddle/ir/core/builtin_op.h | 25 +- paddle/ir/core/dialect.h | 2 +- paddle/ir/core/ir_context.h | 21 +- paddle/ir/core/op_base.h | 16 + paddle/ir/core/op_info.cc | 6 +- paddle/ir/core/op_info.h | 7 +- paddle/ir/core/op_info_impl.h | 3 - paddle/ir/core/operation.cc | 9 +- test/cpp/ir/core/ir_infershape_test.cc | 4 +- test/cpp/ir/core/ir_op_test.cc | 10 +- test/cpp/ir/core/ir_program_test.cc | 25 +- test/cpp/ir/pass/pass_manager_test.cc | 19 +- .../pattern_rewrite/pattern_rewrite_test.cc | 24 +- 18 files changed, 427 insertions(+), 389 deletions(-) create mode 100644 paddle/fluid/ir/dialect/op_verify_gen.py diff --git a/paddle/fluid/ir/dialect/kernel_op.cc b/paddle/fluid/ir/dialect/kernel_op.cc index b7bb3d663b7..30a2a24d07f 100644 --- a/paddle/fluid/ir/dialect/kernel_op.cc +++ b/paddle/fluid/ir/dialect/kernel_op.cc @@ -20,9 +20,7 @@ namespace dialect { const char *PhiKernelOp::attributes_name[attributes_num] = { "base_op", "infermeta_fn", "kernel_fn"}; -void PhiKernelOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { +void PhiKernelOp::Verify() { VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp."; // Verify inputs type: diff --git a/paddle/fluid/ir/dialect/kernel_op.h b/paddle/fluid/ir/dialect/kernel_op.h index b3b0fe4187a..34fe2590267 100644 --- a/paddle/fluid/ir/dialect/kernel_op.h +++ b/paddle/fluid/ir/dialect/kernel_op.h @@ -26,9 +26,7 @@ class PhiKernelOp : public ir::Op { static const char *name() { return "phi.kernel"; } static constexpr uint32_t attributes_num = 3; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes); + void Verify(); }; } // namespace dialect diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index d1ea4a0c9da..65eabda7747 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -16,6 +16,7 @@ import argparse import os import yaml +from op_verify_gen import gen_verify_func_str # ===================================== # String Template for h file code gen @@ -65,7 +66,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ static OpInfoTuple GetOpInfo(); static void Build({build_args}); {build_mutable_attr_is_input} - static void Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes); + void Verify(); {get_inputs_and_outputs} {exclusive_interface} }}; @@ -141,105 +142,6 @@ void {op_name}::Build({build_args}) {{ {build_outputs} }} """ - -# verify -OP_VERIFY_TEMPLATE = """ -void {op_name}::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ - VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}."; - - // Verify inputs type: - PADDLE_ENFORCE_EQ(inputs.size(), {inputs_size}, - phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", inputs.size())); - {inputs_type_check} - // Verify 
outputs type: - PADDLE_ENFORCE_EQ(outputs.size(), {outputs_size}, - phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", outputs.size())); - {outputs_type_check} - // Verify if attributes contain attribute name in attributes_name: - {attributes_check} -}} -""" - -GRAD_OP_VERIFY_TEMPLATE = """ -void {op_name}::Verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) {{ - (void)inputs; - (void)outputs; - (void)attributes; -}} -""" - -INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); - """ -INPUT_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}].type().isa()) {{ - for (size_t i = 0; i < inputs[{index}].type().dyn_cast().size(); i++) {{ - PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast()[i].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); - }} - }} else {{ - PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); - }} - """ -INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{ - PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); - }} - """ -INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{ - if (inputs[{index}].type().isa()) {{ - for (size_t i = 0; i < inputs[{index}].type().dyn_cast().size(); i++) {{ - PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast()[i].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); - }} - }} else {{ - PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); - }} - }} - """ - -OUTPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); - """ -OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}].isa()) {{ - for (size_t i = 0; i < outputs[{index}].dyn_cast().size(); i++) {{ - PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast()[i].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); - }} - }} else {{ - PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); - }} - """ -OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{ - PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); - }} - """ -OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{ - if (outputs[{index}].isa()) {{ - for (size_t i = 0; i < outputs[{index}].dyn_cast().size(); i++) {{ - PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast()[i].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); - }} - }} else {{ - PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); - }} - }} - """ - -ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), 
true, - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); - """ -ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa(), true, - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); - for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ - PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast()[i].isa<{standard}>(), true, - phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); - }} - """ OP_INFER_SHAPE_TEMPLATE = """ void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{ auto fn = PD_INFER_META(phi::{infer_meta_func}); @@ -1004,8 +906,8 @@ def GenBuildOutputs( }} """ - CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().operation()->attributes().at("value").dyn_cast().data().GetData(); (void){name};\n""" - CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().operation()->attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" + CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().GetData(); (void){name};\n""" + CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast().attributes().at("value").dyn_cast().data().to<{dtype}>(); (void){name};\n""" CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; phi::MetaTensor meta_{name}(&dense_{name}); @@ -1557,135 +1459,18 @@ def OpGenerator( view=view_str, ) - # =================================== # - # gen Verify func str # - # =================================== # - # generate op verify function: inputs_type_check_str - if ( - len(op_input_type_list) + len(op_mutable_attribute_name_list) - ) == 0: - inputs_type_check_str = ( - "// Inputs num is 0, not need to check inputs type." - ) - else: - inputs_type_check_str = "" - for idx in range(len(op_input_type_list)): - input_type = op_input_type_list[idx] - is_optional = op_input_optional_list[idx] - is_vector = False - if input_type.startswith("ir::VectorType<"): - is_vector = True - input_type = input_type[15:-1] - check_str = "" - if is_optional == "true": - if is_vector: - check_str = ( - INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - ) - else: - check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - else: - if is_vector: - check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - else: - check_str = INPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=input_type - ) - inputs_type_check_str += check_str - - for idx in range(len(op_mutable_attribute_name_list)): - mutable_attribute_type = op_mutable_attribute_type_list[idx][0] - check_str = "" - if mutable_attribute_type == "paddle::dialect::ScalarAttribute": - check_str = INPUT_TYPE_CHECK_TEMPLATE.format( - index=idx + len(op_input_type_list), - standard="paddle::dialect::DenseTensorType", - ) - else: - check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx + len(op_input_type_list), - standard="paddle::dialect::DenseTensorType", - ) - inputs_type_check_str += check_str - # generate op verify function: outputs_type_check_str - if len(op_output_type_list) == 0: - outputs_type_check_str = ( - "// Outputs num is 0, not need to check outputs type." 
- ) - else: - outputs_type_check_str = "" - for idx in range(len(op_output_type_list)): - output_type = op_output_type_list[idx] - is_optional = op_output_optional_list[idx] - is_vector = False - if output_type.startswith("ir::VectorType<"): - is_vector = True - output_type = output_type[15:-1] - check_str = "" - if is_optional == "true": - if is_vector: - check_str = ( - OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) - ) - else: - check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) - else: - if is_vector: - check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) - else: - check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( - index=idx, standard=output_type - ) - outputs_type_check_str += check_str - # generate op verify function: attributes_check_str - if len(op_non_mutable_attribute_name_list) == 0: - attributes_check_str = ( - "// Attributes num is 0, not need to check attributes type." - ) - else: - attributes_check_str = "" - for idx in range(len(op_non_mutable_attribute_name_list)): - attribute_name = op_non_mutable_attribute_name_list[idx] - attribute_type = op_non_mutable_attribute_type_list[idx] - if attribute_type.startswith("ir::ArrayAttribute<"): - attribute_type = attribute_type[19:-1] - attributes_check_str += ( - ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( - attribute_name=attribute_name, - standard=attribute_type, - ) - ) - else: - attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( - attribute_name=attribute_name, standard=attribute_type - ) - # generate op verify function - if "GradOp" in op_class_name or "Grad_Op" in op_class_name: - op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - ) - else: - op_verify_str = OP_VERIFY_TEMPLATE.format( - op_name=op_class_name, - inputs_size=len(op_input_type_list) - + len(op_mutable_attribute_type_list), - outputs_size=len(op_output_type_list), - inputs_type_check=inputs_type_check_str, - outputs_type_check=outputs_type_check_str, - attributes_check=attributes_check_str, - ) + # generate op verify function str + op_verify_str = gen_verify_func_str( + op_class_name, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_output_type_list, + op_output_optional_list, + ) op_infer_shape_str = "" if op_info.infer_shape_func: diff --git a/paddle/fluid/ir/dialect/op_verify_gen.py b/paddle/fluid/ir/dialect/op_verify_gen.py new file mode 100644 index 00000000000..12714e4af4d --- /dev/null +++ b/paddle/fluid/ir/dialect/op_verify_gen.py @@ -0,0 +1,275 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# verify +OP_VERIFY_TEMPLATE = """ +void {op_name}::Verify() {{ + VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}."; + VLOG(4) << "Verifying inputs:"; + {{ + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ(input_size, {inputs_size}u, + phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check} + }} + VLOG(4) << "Verifying attributes:"; + {{{attributes_check} + }} + VLOG(4) << "Verifying outputs:"; + {{ + auto output_size = num_results(); + PADDLE_ENFORCE_EQ(output_size, {outputs_size}u, + phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check} + }} + VLOG(4) << "End Verifying for: {op_name}."; +}} +""" + +GRAD_OP_VERIFY_TEMPLATE = """ +void {op_name}::Verify() {{}} +""" + +INPUT_TYPE_CHECK_TEMPLATE = """ + PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));""" +INPUT_VECTORTYPE_CHECK_TEMPLATE = """ + if (auto vec_type = (*this)->operand({index}).type().dyn_cast()) {{ + for (size_t i = 0; i < vec_type.size(); ++i) {{ + PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + }} + }} + else {{ + PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + }}""" +INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ + if (auto val = (*this)->operand({index})) {{ + PADDLE_ENFORCE(val.type().isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + }}""" +INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ + if (auto val = (*this)->operand({index})) {{ + if (auto vec_type = val.type().dyn_cast()) {{ + for (size_t i = 0; i < vec_type.size(); i++) {{ + PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + }} + }} + else {{ + PADDLE_ENFORCE(val.type().isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); + }} + }}""" +ATTRIBUTE_CHECK_TEMPLATE = """ + PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), + phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));""" +ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """ + PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa(), + phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); + for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ + PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast()[i].isa<{standard}>(), + phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); + }}""" +OUTPUT_TYPE_CHECK_TEMPLATE = """ + PADDLE_ENFORCE((*this)->result({index}).type().isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));""" +OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """ + auto output_{index}_type = (*this)->result({index}).type(); + if (auto vec_type = output_{index}_type.dyn_cast()) {{ + for (size_t i = 0; i < vec_type.size(); i++) {{ + PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + }} + }} + else {{ + 
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + }}""" +OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ + if (auto output_{index} = (*this)->result({index})) {{ + PADDLE_ENFORCE(output_{index}.type().isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + }}""" +OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ + if (auto output_{index}_type = (*this)->result({index}).type()) {{ + if (auto vec_type = output_{index}_type.dyn_cast()) {{ + for (size_t i = 0; i < vec_type.size(); ++i) {{ + PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + }} + }} + else {{ + PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(), + phi::errors::PreconditionNotMet("Type validation failed for the {index}th output.")); + }} + }}""" + + +# generate inputs_type_check_str +def gen_inputs_type_check_str( + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, +): + if (len(op_input_type_list) + len(op_mutable_attribute_name_list)) == 0: + inputs_type_check_str = """ + // Inputs num is 0, not need to check inputs type.""" + else: + inputs_type_check_str = "" + for idx in range(len(op_input_type_list)): + input_type = op_input_type_list[idx] + is_optional = op_input_optional_list[idx] + is_vector = False + if input_type.startswith("ir::VectorType<"): + is_vector = True + input_type = input_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + if is_vector: + check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + else: + check_str = INPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=input_type + ) + inputs_type_check_str += check_str + for idx in range(len(op_mutable_attribute_name_list)): + mutable_attribute_type = op_mutable_attribute_type_list[idx][0] + check_str = "" + if mutable_attribute_type == "paddle::dialect::ScalarAttribute": + check_str = INPUT_TYPE_CHECK_TEMPLATE.format( + index=idx + len(op_input_type_list), + standard="paddle::dialect::DenseTensorType", + ) + else: + check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx + len(op_input_type_list), + standard="paddle::dialect::DenseTensorType", + ) + inputs_type_check_str += check_str + return inputs_type_check_str + + +# generate attributes_check_str +def gen_attributes_type_check_str( + op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list +): + if len(op_non_mutable_attribute_name_list) == 0: + attributes_check_str = """ + // Attributes num is 0, not need to check attributes type.""" + else: + attributes_check_str = """ + auto& attributes = this->attributes();""" + for idx in range(len(op_non_mutable_attribute_name_list)): + attribute_name = op_non_mutable_attribute_name_list[idx] + attribute_type = op_non_mutable_attribute_type_list[idx] + if attribute_type.startswith("ir::ArrayAttribute<"): + attribute_type = attribute_type[19:-1] + attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format( + attribute_name=attribute_name, + standard=attribute_type, + ) + else: + attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format( + attribute_name=attribute_name, standard=attribute_type 
+ ) + return attributes_check_str + + +# generate outputs_type_check_str +def gen_outputs_type_check_str(op_output_type_list, op_output_optional_list): + if len(op_output_type_list) == 0: + outputs_type_check_str = """ + // Outputs num is 0, not need to check outputs type.""" + else: + outputs_type_check_str = "" + for idx in range(len(op_output_type_list)): + output_type = op_output_type_list[idx] + is_optional = op_output_optional_list[idx] + is_vector = False + if output_type.startswith("ir::VectorType<"): + is_vector = True + output_type = output_type[15:-1] + check_str = "" + if is_optional == "true": + if is_vector: + check_str = OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + else: + check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + else: + if is_vector: + check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + else: + check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format( + index=idx, standard=output_type + ) + outputs_type_check_str += check_str + return outputs_type_check_str + + +# generate op verify function +def gen_verify_func_str( + op_class_name, + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_output_type_list, + op_output_optional_list, +): + if "GradOp" in op_class_name or "Grad_Op" in op_class_name: + return GRAD_OP_VERIFY_TEMPLATE.format(op_name=op_class_name) + + inputs_type_check_str = gen_inputs_type_check_str( + op_input_type_list, + op_input_optional_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + ) + attributes_type_check_str = gen_attributes_type_check_str( + op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list + ) + + outputs_type_check_str = gen_outputs_type_check_str( + op_output_type_list, op_output_optional_list + ) + + return OP_VERIFY_TEMPLATE.format( + op_name=op_class_name, + inputs_size=len(op_input_type_list) + + len(op_mutable_attribute_type_list), + inputs_type_check=inputs_type_check_str, + attributes_check=attributes_type_check_str, + outputs_size=len(op_output_type_list), + outputs_type_check=outputs_type_check_str, + ) diff --git a/paddle/ir/core/builtin_op.cc b/paddle/ir/core/builtin_op.cc index ed49b347780..091f0fdebf2 100644 --- a/paddle/ir/core/builtin_op.cc +++ b/paddle/ir/core/builtin_op.cc @@ -23,7 +23,7 @@ namespace ir { const char *ModuleOp::attributes_name[attributes_num] = {"program"}; Program *ModuleOp::program() { - const AttributeMap &attr = operation()->attributes(); + const AttributeMap &attr = this->attributes(); auto iter = attr.find("program"); if (iter == attr.end() || !iter->second) return nullptr; return static_cast( @@ -52,20 +52,19 @@ void ModuleOp::Destroy() { } } -void ModuleOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { +void ModuleOp::Verify() { VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; - // Verify inputs type: - IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); + // Verify inputs: + IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); - // Verify if attributes contain attribute name in attributes_name: + // Verify attributes: + auto &attributes = this->attributes(); auto iter = attributes.find("program"); IR_ENFORCE(iter != attributes.end() && iter->second.isa(), "Type of 
attribute: program is not right."); - // Verify outputs type: - IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); + // Verify outputs: + IR_ENFORCE(num_results() == 0u, "The size of inputs must be equal to 0."); } const char *GetParameterOp::attributes_name[attributes_num] = { @@ -80,20 +79,19 @@ void GetParameterOp::Build(Builder &builder, argument.output_types.emplace_back(type); } -void GetParameterOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { +void GetParameterOp::Verify() { VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; - // Verify inputs type: - IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); + // Verify inputs: + IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0."); // Verify if attributes contain attribute name in attributes_name: + auto &attributes = this->attributes(); auto iter = attributes.find("parameter_name"); IR_ENFORCE(iter != attributes.end() && iter->second.isa(), "Type of attribute: parameter_name is not right."); // Verify outputs type: - IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1."); + IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); } const char *SetParameterOp::attributes_name[attributes_num] = { @@ -107,20 +105,19 @@ void SetParameterOp::Build(Builder &builder, // NOLINT argument.AddAttribute(attributes_name[0], ir::StrAttribute::get(builder.ir_context(), name)); } -void SetParameterOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { +void SetParameterOp::Verify() { VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; - // Verify inputs type: - IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1."); + // Verify inputs: + IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1."); - // Verify if attributes contain attribute name in attributes_name: + // Verify attributes: + auto &attributes = this->attributes(); auto iter = attributes.find("parameter_name"); IR_ENFORCE(iter != attributes.end() && iter->second.isa(), "Type of attribute: parameter_name is not right."); - // Verify outputs type: - IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); + // Verify outputs: + IR_ENFORCE(num_results() == 0u, "The size of outputs must be equal to 0."); } void CombineOp::Build(Builder &builder, @@ -135,58 +132,56 @@ void CombineOp::Build(Builder &builder, ir::VectorType::get(builder.ir_context(), inputs_type)); } -void CombineOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { +void CombineOp::Verify() { // outputs.size() == 1 - IR_ENFORCE(outputs.size() == 1, - "The size %d of outputs must be equal to 1.", - outputs.size()); + IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); + + // output_type == Vector + auto output_type = (*this)->result(0).type().dyn_cast(); + IR_ENFORCE(output_type, + "The type of outputs[0] must be equal to VectorType."); - // outputs[0].type == Vector - IR_ENFORCE(outputs[0].isa(), - "The type %s of outputs[0] must be equal to VectorType.", - outputs[0]); - ir::VectorType output_type = outputs[0].dyn_cast(); // inputs.size() == outputs[0].size() - IR_ENFORCE(output_type.size() == inputs.size(), - "The size %d of outputs[0] must be equal to size %d of inputs.", + auto input_num = num_operands(); + 
IR_ENFORCE(output_type.size() == input_num, + "The size %d of output must be equal to size %d of inputs.", output_type.size(), - inputs.size()); + input_num); // forall i in inputs.size(): inputs[i].type == outputs[0][i].type - for (size_t i = 0; i < inputs.size(); i++) { - IR_ENFORCE(output_type[i] == inputs[i].type(), + for (size_t i = 0; i < input_num; ++i) { + auto type = (*this)->operand(i).type(); + IR_ENFORCE(output_type[i] == type, "The type %s of outputs[0][%d] must be " "equal to type %s of inputs[%d].", output_type[i], i, - inputs[i].type(), + type, i); } } const char *SliceOp::attributes_name[attributes_num] = {"index"}; -void SliceOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { +void SliceOp::Verify() { // inputs.size() == 1 - IR_ENFORCE(inputs.size() == 1, - "The size %d of inputs must be equal to 1.", - inputs.size()); + auto input_size = num_operands(); + IR_ENFORCE( + input_size == 1, "The size %d of inputs must be equal to 1.", input_size); // inputs[0].type == Vector - IR_ENFORCE(inputs[0].type().isa(), + auto input_type = (*this)->operand(0).type().dyn_cast(); + IR_ENFORCE(input_type, "The type %s of inputs[0] must be equal to VectorType.", - inputs[0].type()); - ir::VectorType input_type = inputs[0].type().dyn_cast(); + input_type); + auto output_size = num_results(); // outputs.size() == 1 - IR_ENFORCE(outputs.size() == 1, + IR_ENFORCE(output_size == 1, "The size %d of outputs must be equal to 1.", - outputs.size()); + output_size); // attributes contains index: Int32 + auto &attributes = this->attributes(); IR_ENFORCE(attributes.count("index") != 0, "The attributes must contains index."); const ir::Attribute &attr = attributes.at("index"); @@ -203,12 +198,13 @@ void SliceOp::Verify(const std::vector &inputs, input_type.size()); // inputs[index].type == outputs[0].type + auto output_type = (*this)->result(0).type(); IR_ENFORCE( - input_type[index] == outputs[0], + input_type[index] == output_type, "The type %s of inputs[%d] must be equal to type %s of outputs[0].", input_type[index], index, - outputs[0]); + output_type); } const char *ConstantOp::attributes_name[attributes_num] = {"value"}; @@ -221,16 +217,13 @@ void ConstantOp::Build(Builder &builder, argument.output_types.push_back(output_type); } -void ConstantOp::Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { - IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); - IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1."); - IR_ENFORCE(attributes.count("value") > 0, - "Type of attribute: value is not right."); +void ConstantOp::Verify() { + IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0."); + IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1."); + IR_ENFORCE(attributes().count("value") > 0, "must has value attribute"); } -Attribute ConstantOp::value() { return operation()->attributes().at("value"); } +Attribute ConstantOp::value() { return attributes().at("value"); } } // namespace ir diff --git a/paddle/ir/core/builtin_op.h b/paddle/ir/core/builtin_op.h index 56cfafd35ff..27f264ff218 100644 --- a/paddle/ir/core/builtin_op.h +++ b/paddle/ir/core/builtin_op.h @@ -30,10 +30,7 @@ class IR_API ModuleOp : public ir::Op { static const char *name() { return "builtin.module"; } static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - 
const std::vector &outputs, - const ir::AttributeMap &attributes); - + void Verify(); Program *program(); Block *block(); @@ -58,9 +55,7 @@ class IR_API GetParameterOp : public ir::Op { OperationArgument &argument, // NOLINT const std::string &name, Type type); - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes); + void Verify(); }; /// @@ -77,9 +72,7 @@ class IR_API SetParameterOp : public ir::Op { OperationArgument &argument, // NOLINT OpResult parameter, const std::string &name); - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes); + void Verify(); }; /// @@ -99,9 +92,7 @@ class IR_API CombineOp : public ir::Op { OperationArgument &argument, // NOLINT const std::vector &inputs); - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes); + void Verify(); }; /// @@ -116,9 +107,7 @@ class IR_API SliceOp : public ir::Op { static constexpr uint32_t attributes_num = 1; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes); + void Verify(); }; class IR_API ConstantLikeTrait : public OpTraitBase { @@ -143,9 +132,7 @@ class IR_API ConstantOp : public Op { Attribute value, Type output_type); - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const AttributeMap &attributes); + void Verify(); Attribute value(); }; diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index c5f9f86fc76..be67898dd98 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -100,7 +100,7 @@ class IR_API Dialect { ConcreteOp::GetTraitSet(), ConcreteOp::attributes_num, ConcreteOp::attributes_name, - ConcreteOp::Verify); + ConcreteOp::VerifyInvariants); } void RegisterOp(const std::string &name, OpInfoImpl *op_info); diff --git a/paddle/ir/core/ir_context.h b/paddle/ir/core/ir_context.h index 1ff5bb6e525..7abea0284a9 100644 --- a/paddle/ir/core/ir_context.h +++ b/paddle/ir/core/ir_context.h @@ -32,6 +32,7 @@ class InterfaceValue; class Type; class OpResult; class Attribute; +class Operation; using OpInfoMap = std::unordered_map; @@ -102,18 +103,14 @@ class IR_API IrContext { /// /// \brief Register an op infomation to IrContext /// - void RegisterOpInfo( - Dialect *dialect, - TypeId op_id, - const char *name, - std::vector &&interface_map, - const std::vector &trait_set, - size_t attributes_num, - const char **attributes_name, - void (*verify)( - const std::vector &inputs, - const std::vector &outputs, - const std::unordered_map &attributes)); + void RegisterOpInfo(Dialect *dialect, + TypeId op_id, + const char *name, + std::vector &&interface_map, + const std::vector &trait_set, + size_t attributes_num, + const char **attributes_name, + void (*verify)(Operation *)); /// /// \brief Get registered operaiton infomation. 
diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index 43644774688..bed27d68d9b 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -78,6 +78,12 @@ class IR_API OpBase { IrContext *ir_context() const { return operation_->ir_context(); } + uint32_t num_results() const { return operation_->num_results(); } + + uint32_t num_operands() const { return operation_->num_operands(); } + + const AttributeMap &attributes() const { return operation_->attributes(); } + private: Operation *operation_; // Not owned }; @@ -205,6 +211,16 @@ class Op : public OpBase { ConstructInterfacesOrTraits::trait(p_first_trait); return trait_set; } + static constexpr bool HasNoDataMembers() { + class EmptyOp : public Op {}; + return sizeof(ConcreteOp) == sizeof(EmptyOp); + } + + static void VerifyInvariants(Operation *op) { + static_assert(HasNoDataMembers(), + "Op class shouldn't define new data members"); + op->dyn_cast().Verify(); + } }; } // namespace ir diff --git a/paddle/ir/core/op_info.cc b/paddle/ir/core/op_info.cc index e2e1d877fa2..6c9b62f56e6 100644 --- a/paddle/ir/core/op_info.cc +++ b/paddle/ir/core/op_info.cc @@ -35,11 +35,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } -void OpInfo::Verify(const std::vector &inputs, - const std::vector &outputs, - const AttributeMap &attributes) { - impl_->verify()(inputs, outputs, attributes); -} +void OpInfo::Verify(Operation *operation) const { impl_->verify()(operation); } void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr; diff --git a/paddle/ir/core/op_info.h b/paddle/ir/core/op_info.h index 485e116cf5a..f92d37d4b33 100644 --- a/paddle/ir/core/op_info.h +++ b/paddle/ir/core/op_info.h @@ -25,6 +25,9 @@ class OpResult; class Type; class Attribute; class Dialect; +class Operation; + +typedef void (*VerifyPtr)(Operation *op); class IR_API OpInfo { public: @@ -49,9 +52,7 @@ class IR_API OpInfo { TypeId id() const; - void Verify(const std::vector &inputs, - const std::vector &outputs, - const std::unordered_map &attributes); + void Verify(Operation *) const; template bool HasTrait() const { diff --git a/paddle/ir/core/op_info_impl.h b/paddle/ir/core/op_info_impl.h index e5d8fd25aaf..52666f1b377 100644 --- a/paddle/ir/core/op_info_impl.h +++ b/paddle/ir/core/op_info_impl.h @@ -25,9 +25,6 @@ namespace ir { class Dialect; -typedef void (*VerifyPtr)(const std::vector &inputs, - const std::vector &outputs, - const AttributeMap &attributes); /// /// \brief OpInfoImpl class. diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index ae23338cb22..01cbafb5d59 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -46,10 +46,6 @@ Operation *Operation::Create(const std::vector &inputs, const std::vector &output_types, ir::OpInfo op_info, size_t num_regions) { - // 0. Verify - if (op_info) { - op_info.Verify(inputs, output_types, attributes); - } // 1. Calculate the required memory size for OpResults + Operation + // OpOperands. uint32_t num_results = output_types.size(); @@ -100,6 +96,11 @@ Operation *Operation::Create(const std::vector &inputs, base_ptr += sizeof(Region); } } + + // 0. 
Verify + if (op_info) { + op_info.Verify(op); + } return op; } diff --git a/test/cpp/ir/core/ir_infershape_test.cc b/test/cpp/ir/core/ir_infershape_test.cc index 26ad377b06b..0053cd77d89 100644 --- a/test/cpp/ir/core/ir_infershape_test.cc +++ b/test/cpp/ir/core/ir_infershape_test.cc @@ -45,9 +45,7 @@ class OperationTest static const char *name() { return "test.operation2"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) {} + static void Verify() {} static void InferShape(phi::InferMetaContext *infer_meta) { auto fn = PD_INFER_META(phi::CreateInferMeta); fn(infer_meta); diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index cb04f440c01..6ab59c6014d 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -90,9 +90,8 @@ class Operation1 : public ir::Op { static const char *name() { return "test.operation1"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { + void Verify() { + auto &attributes = this->attributes(); if (attributes.count("op1_attr1") == 0 || !attributes.at("op1_attr1").isa()) { throw("Type of attribute: parameter_name is not right."); @@ -133,9 +132,8 @@ class Operation2 static const char *name() { return "test.operation2"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { + void Verify() { + auto &attributes = this->attributes(); if (attributes.count("op2_attr1") == 0 || (!attributes.at("op2_attr1").isa())) { throw("Type of attribute: parameter_name is not right."); diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index a55f3eeb347..6e2a8e5acb9 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -38,22 +38,21 @@ class AddOp : public ir::Op { static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { - if (inputs.size() != 2) { - throw("The size of inputs must be equal to 2."); - } - if (outputs.size() != 1) { - throw("The size of outputs must be equal to 1."); - } - } + void Verify(); static void Build(ir::Builder &builder, // NOLINT ir::OperationArgument &argument, // NOLINT ir::OpResult l_operand, ir::OpResult r_operand, ir::Type sum_type); }; +void AddOp::Verify() { + if (num_operands() != 2) { + throw("The size of inputs must be equal to 2."); + } + if (num_results() != 1) { + throw("The size of outputs must be equal to 1."); + } +} void AddOp::Build(ir::Builder &, ir::OperationArgument &argument, ir::OpResult l_operand, @@ -262,9 +261,9 @@ TEST(program_test, builder) { ir::Type full_op_output = full_op->result(0).type(); EXPECT_EQ(program.block()->size(), 1u); EXPECT_EQ(program.block()->back(), full_op.operation()); - EXPECT_EQ(full_op->num_operands(), 0u); - EXPECT_EQ(full_op->num_results(), 1u); - EXPECT_EQ(full_op->attributes().size(), 4u); + EXPECT_EQ(full_op.num_operands(), 0u); + EXPECT_EQ(full_op.num_results(), 1u); + 
EXPECT_EQ(full_op.attributes().size(), 4u); EXPECT_EQ( full_op_output.dyn_cast().offset() == 0, true); diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc index 22cb62dda27..87a5abd6445 100644 --- a/test/cpp/ir/pass/pass_manager_test.cc +++ b/test/cpp/ir/pass/pass_manager_test.cc @@ -65,22 +65,21 @@ class AddOp : public ir::Op { static const char *name() { return "test.add"; } static constexpr const char **attributes_name = nullptr; static constexpr uint32_t attributes_num = 0; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { - if (inputs.size() != 2) { - throw("The size of inputs must be equal to 2."); - } - if (outputs.size() != 1) { - throw("The size of outputs must be equal to 1."); - } - } + void Verify(); static void Build(ir::Builder &builder, // NOLINT ir::OperationArgument &argument, // NOLINT ir::OpResult l_operand, ir::OpResult r_operand, ir::Type sum_type); }; +void AddOp::Verify() { + if (num_operands() != 2) { + throw("The size of inputs must be equal to 2."); + } + if (num_results() != 1) { + throw("The size of outputs must be equal to 1."); + } +} void AddOp::Build(ir::Builder &, ir::OperationArgument &argument, ir::OpResult l_operand, diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index 607108d582b..bd1b83c0664 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -48,20 +48,20 @@ class Operation1 : public ir::Op { static const char *name() { return "test.Operation1"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; - static void Verify(const std::vector &inputs, - const std::vector &outputs, - const ir::AttributeMap &attributes) { - if (attributes.count("op2_attr1") == 0 || - (!attributes.at("op2_attr1").isa())) { - throw("Type of attribute: parameter_name is not right."); - } - if (attributes.count("op2_attr2") == 0 || - (!attributes.at("op2_attr2").isa())) { - throw("Type of attribute: parameter_name is not right."); - } - } + void Verify(); static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } }; +void Operation1::Verify() { + auto &attributes = this->attributes(); + if (attributes.count("op2_attr1") == 0 || + (!attributes.at("op2_attr1").isa())) { + throw("Type of attribute: parameter_name is not right."); + } + if (attributes.count("op2_attr2") == 0 || + (!attributes.at("op2_attr2").isa())) { + throw("Type of attribute: parameter_name is not right."); + } +} const char *Operation1::attributes_name[attributes_num] = {"op2_attr1", "op2_attr2"}; IR_DECLARE_EXPLICIT_TYPE_ID(Operation1) -- GitLab
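
For reference, a minimal sketch (not taken from the patch itself) of how a dialect op looks under the reworked interface: Verify() is now a non-static member that inspects its own operation through num_operands(), num_results(), and OpBase::attributes(), and Dialect::RegisterOp registers ConcreteOp::VerifyInvariants, which casts the Operation* back to the concrete op and calls this member. ExampleOp, its attribute name, and the include paths marked "assumed" are illustrative assumptions, not part of the Paddle code base.

// Sketch only: mirrors the member-function Verify() pattern introduced by this
// patch (see GetParameterOp / AddOp above). The op and include paths are
// illustrative assumptions.
#include "paddle/ir/core/builtin_attribute.h"  // ir::StrAttribute (assumed path)
#include "paddle/ir/core/enforce.h"            // IR_ENFORCE (assumed path)
#include "paddle/ir/core/op_base.h"            // ir::Op / ir::OpBase

namespace example {

class ExampleOp : public ir::Op<ExampleOp> {
 public:
  using Op::Op;
  static const char *name() { return "example.op"; }
  static constexpr uint32_t attributes_num = 1;
  static const char *attributes_name[attributes_num];

  // Before this patch:
  //   static void Verify(const std::vector<ir::OpResult> &inputs,
  //                      const std::vector<ir::Type> &outputs,
  //                      const ir::AttributeMap &attributes);
  // After this patch: a member function with no parameters.
  void Verify();
};

const char *ExampleOp::attributes_name[attributes_num] = {"example_attr"};

void ExampleOp::Verify() {
  // Operand/result counts are read from the bound operation itself.
  IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");
  IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");

  // Attributes come from OpBase::attributes() instead of a function argument.
  auto &attributes = this->attributes();
  IR_ENFORCE(attributes.count("example_attr") > 0 &&
                 attributes.at("example_attr").isa<ir::StrAttribute>(),
             "Type of attribute: example_attr is not right.");
}

}  // namespace example

Because verification now runs on a fully constructed Operation (the call was moved to the end of Operation::Create in this patch), the checks can query operand and result types directly instead of receiving raw inputs/outputs vectors.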