未验证 提交 96652265 编写于 作者: W winter-wang 提交者: GitHub

[IR] rectify the verify api (#54895)

上级 e49c17d2
...@@ -20,9 +20,7 @@ namespace dialect { ...@@ -20,9 +20,7 @@ namespace dialect {
const char *PhiKernelOp::attributes_name[attributes_num] = { const char *PhiKernelOp::attributes_name[attributes_num] = {
"base_op", "infermeta_fn", "kernel_fn"}; "base_op", "infermeta_fn", "kernel_fn"};
void PhiKernelOp::Verify(const std::vector<ir::OpResult> &inputs, void PhiKernelOp::Verify() {
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp.";
// Verify inputs type: // Verify inputs type:
......
...@@ -26,9 +26,7 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> { ...@@ -26,9 +26,7 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> {
static const char *name() { return "phi.kernel"; } static const char *name() { return "phi.kernel"; }
static constexpr uint32_t attributes_num = 3; static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
}; };
} // namespace dialect } // namespace dialect
......
...@@ -16,6 +16,7 @@ import argparse ...@@ -16,6 +16,7 @@ import argparse
import os import os
import yaml import yaml
from op_verify_gen import gen_verify_func_str
# ===================================== # =====================================
# String Template for h file code gen # String Template for h file code gen
...@@ -65,7 +66,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{ ...@@ -65,7 +66,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static OpInfoTuple GetOpInfo(); static OpInfoTuple GetOpInfo();
static void Build({build_args}); static void Build({build_args});
{build_mutable_attr_is_input} {build_mutable_attr_is_input}
static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes); void Verify();
{get_inputs_and_outputs} {get_inputs_and_outputs}
{exclusive_interface} {exclusive_interface}
}}; }};
...@@ -141,105 +142,6 @@ void {op_name}::Build({build_args}) {{ ...@@ -141,105 +142,6 @@ void {op_name}::Build({build_args}) {{
{build_outputs} {build_outputs}
}} }}
""" """
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}.";
// Verify inputs type:
PADDLE_ENFORCE_EQ(inputs.size(), {inputs_size},
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", inputs.size()));
{inputs_type_check}
// Verify outputs type:
PADDLE_ENFORCE_EQ(outputs.size(), {outputs_size},
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", outputs.size()));
{outputs_type_check}
// Verify if attributes contain attribute name in attributes_name:
{attributes_check}
}}
"""
GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs;
(void)outputs;
(void)attributes;
}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{
if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
"""
OUTPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
"""
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}
"""
OP_INFER_SHAPE_TEMPLATE = """ OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{ void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func}); auto fn = PD_INFER_META(phi::{infer_meta_func});
...@@ -1004,8 +906,8 @@ def GenBuildOutputs( ...@@ -1004,8 +906,8 @@ def GenBuildOutputs(
}} }}
""" """
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n""" CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n""" CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name}; CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name}); phi::MetaTensor meta_{name}(&dense_{name});
...@@ -1557,135 +1459,18 @@ def OpGenerator( ...@@ -1557,135 +1459,18 @@ def OpGenerator(
view=view_str, view=view_str,
) )
# =================================== # # generate op verify function str
# gen Verify func str # op_verify_str = gen_verify_func_str(
# =================================== # op_class_name,
# generate op verify function: inputs_type_check_str op_input_type_list,
if ( op_input_optional_list,
len(op_input_type_list) + len(op_mutable_attribute_name_list) op_mutable_attribute_name_list,
) == 0: op_mutable_attribute_type_list,
inputs_type_check_str = ( op_non_mutable_attribute_name_list,
"// Inputs num is 0, not need to check inputs type." op_non_mutable_attribute_type_list,
) op_output_type_list,
else: op_output_optional_list,
inputs_type_check_str = "" )
for idx in range(len(op_input_type_list)):
input_type = op_input_type_list[idx]
is_optional = op_input_optional_list[idx]
is_vector = False
if input_type.startswith("ir::VectorType<"):
is_vector = True
input_type = input_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = (
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
)
else:
check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
if is_vector:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
inputs_type_check_str += check_str
for idx in range(len(op_mutable_attribute_name_list)):
mutable_attribute_type = op_mutable_attribute_type_list[idx][0]
check_str = ""
if mutable_attribute_type == "paddle::dialect::ScalarAttribute":
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
else:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
# generate op verify function: outputs_type_check_str
if len(op_output_type_list) == 0:
outputs_type_check_str = (
"// Outputs num is 0, not need to check outputs type."
)
else:
outputs_type_check_str = ""
for idx in range(len(op_output_type_list)):
output_type = op_output_type_list[idx]
is_optional = op_output_optional_list[idx]
is_vector = False
if output_type.startswith("ir::VectorType<"):
is_vector = True
output_type = output_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = (
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
)
else:
check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
if is_vector:
check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
outputs_type_check_str += check_str
# generate op verify function: attributes_check_str
if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = (
"// Attributes num is 0, not need to check attributes type."
)
else:
attributes_check_str = ""
for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_name = op_non_mutable_attribute_name_list[idx]
attribute_type = op_non_mutable_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1]
attributes_check_str += (
ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name,
standard=attribute_type,
)
)
else:
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
)
# generate op verify function
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
)
else:
op_verify_str = OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list)
+ len(op_mutable_attribute_type_list),
outputs_size=len(op_output_type_list),
inputs_type_check=inputs_type_check_str,
outputs_type_check=outputs_type_check_str,
attributes_check=attributes_check_str,
)
op_infer_shape_str = "" op_infer_shape_str = ""
if op_info.infer_shape_func: if op_info.infer_shape_func:
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{
VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}.";
VLOG(4) << "Verifying inputs:";
{{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(input_size, {inputs_size}u,
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check}
}}
VLOG(4) << "Verifying attributes:";
{{{attributes_check}
}}
VLOG(4) << "Verifying outputs:";
{{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(output_size, {outputs_size}u,
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check}
}}
VLOG(4) << "End Verifying for: {op_name}.";
}}
"""
GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """
if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}"""
ATTRIBUTE_CHECK_TEMPLATE = """
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}"""
OUTPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->result({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """
auto output_{index}_type = (*this)->result({index}).type();
if (auto vec_type = output_{index}_type.dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
else {{
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto output_{index} = (*this)->result({index})) {{
PADDLE_ENFORCE(output_{index}.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto output_{index}_type = (*this)->result({index}).type()) {{
if (auto vec_type = output_{index}_type.dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
else {{
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}"""
# generate inputs_type_check_str
def gen_inputs_type_check_str(
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
):
if (len(op_input_type_list) + len(op_mutable_attribute_name_list)) == 0:
inputs_type_check_str = """
// Inputs num is 0, not need to check inputs type."""
else:
inputs_type_check_str = ""
for idx in range(len(op_input_type_list)):
input_type = op_input_type_list[idx]
is_optional = op_input_optional_list[idx]
is_vector = False
if input_type.startswith("ir::VectorType<"):
is_vector = True
input_type = input_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
if is_vector:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
inputs_type_check_str += check_str
for idx in range(len(op_mutable_attribute_name_list)):
mutable_attribute_type = op_mutable_attribute_type_list[idx][0]
check_str = ""
if mutable_attribute_type == "paddle::dialect::ScalarAttribute":
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
else:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
return inputs_type_check_str
# generate attributes_check_str
def gen_attributes_type_check_str(
op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list
):
if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = """
// Attributes num is 0, not need to check attributes type."""
else:
attributes_check_str = """
auto& attributes = this->attributes();"""
for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_name = op_non_mutable_attribute_name_list[idx]
attribute_type = op_non_mutable_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1]
attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name,
standard=attribute_type,
)
else:
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
)
return attributes_check_str
# generate outputs_type_check_str
def gen_outputs_type_check_str(op_output_type_list, op_output_optional_list):
if len(op_output_type_list) == 0:
outputs_type_check_str = """
// Outputs num is 0, not need to check outputs type."""
else:
outputs_type_check_str = ""
for idx in range(len(op_output_type_list)):
output_type = op_output_type_list[idx]
is_optional = op_output_optional_list[idx]
is_vector = False
if output_type.startswith("ir::VectorType<"):
is_vector = True
output_type = output_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
if is_vector:
check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
outputs_type_check_str += check_str
return outputs_type_check_str
# generate op verify function
def gen_verify_func_str(
op_class_name,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_output_type_list,
op_output_optional_list,
):
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
return GRAD_OP_VERIFY_TEMPLATE.format(op_name=op_class_name)
inputs_type_check_str = gen_inputs_type_check_str(
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
)
attributes_type_check_str = gen_attributes_type_check_str(
op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list
)
outputs_type_check_str = gen_outputs_type_check_str(
op_output_type_list, op_output_optional_list
)
return OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list)
+ len(op_mutable_attribute_type_list),
inputs_type_check=inputs_type_check_str,
attributes_check=attributes_type_check_str,
outputs_size=len(op_output_type_list),
outputs_type_check=outputs_type_check_str,
)
...@@ -23,7 +23,7 @@ namespace ir { ...@@ -23,7 +23,7 @@ namespace ir {
const char *ModuleOp::attributes_name[attributes_num] = {"program"}; const char *ModuleOp::attributes_name[attributes_num] = {"program"};
Program *ModuleOp::program() { Program *ModuleOp::program() {
const AttributeMap &attr = operation()->attributes(); const AttributeMap &attr = this->attributes();
auto iter = attr.find("program"); auto iter = attr.find("program");
if (iter == attr.end() || !iter->second) return nullptr; if (iter == attr.end() || !iter->second) return nullptr;
return static_cast<Program *>( return static_cast<Program *>(
...@@ -52,20 +52,19 @@ void ModuleOp::Destroy() { ...@@ -52,20 +52,19 @@ void ModuleOp::Destroy() {
} }
} }
void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs, void ModuleOp::Verify() {
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type: // Verify inputs:
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name: // Verify attributes:
auto &attributes = this->attributes();
auto iter = attributes.find("program"); auto iter = attributes.find("program");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(), IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
"Type of attribute: program is not right."); "Type of attribute: program is not right.");
// Verify outputs type: // Verify outputs:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); IR_ENFORCE(num_results() == 0u, "The size of inputs must be equal to 0.");
} }
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
...@@ -80,20 +79,19 @@ void GetParameterOp::Build(Builder &builder, ...@@ -80,20 +79,19 @@ void GetParameterOp::Build(Builder &builder,
argument.output_types.emplace_back(type); argument.output_types.emplace_back(type);
} }
void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, void GetParameterOp::Verify() {
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs type: // Verify inputs:
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name: // Verify if attributes contain attribute name in attributes_name:
auto &attributes = this->attributes();
auto iter = attributes.find("parameter_name"); auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(), IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right."); "Type of attribute: parameter_name is not right.");
// Verify outputs type: // Verify outputs type:
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
} }
const char *SetParameterOp::attributes_name[attributes_num] = { const char *SetParameterOp::attributes_name[attributes_num] = {
...@@ -107,20 +105,19 @@ void SetParameterOp::Build(Builder &builder, // NOLINT ...@@ -107,20 +105,19 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0], argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name)); ir::StrAttribute::get(builder.ir_context(), name));
} }
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs, void SetParameterOp::Verify() {
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp."; VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs type: // Verify inputs:
IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1."); IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
// Verify if attributes contain attribute name in attributes_name: // Verify attributes:
auto &attributes = this->attributes();
auto iter = attributes.find("parameter_name"); auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(), IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right."); "Type of attribute: parameter_name is not right.");
// Verify outputs type: // Verify outputs:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0."); IR_ENFORCE(num_results() == 0u, "The size of outputs must be equal to 0.");
} }
void CombineOp::Build(Builder &builder, void CombineOp::Build(Builder &builder,
...@@ -135,58 +132,56 @@ void CombineOp::Build(Builder &builder, ...@@ -135,58 +132,56 @@ void CombineOp::Build(Builder &builder,
ir::VectorType::get(builder.ir_context(), inputs_type)); ir::VectorType::get(builder.ir_context(), inputs_type));
} }
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs, void CombineOp::Verify() {
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// outputs.size() == 1 // outputs.size() == 1
IR_ENFORCE(outputs.size() == 1, IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
"The size %d of outputs must be equal to 1.",
outputs.size()); // output_type == Vector<Type>
auto output_type = (*this)->result(0).type().dyn_cast<VectorType>();
IR_ENFORCE(output_type,
"The type of outputs[0] must be equal to VectorType.");
// outputs[0].type == Vector<Type>
IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]);
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size() // inputs.size() == outputs[0].size()
IR_ENFORCE(output_type.size() == inputs.size(), auto input_num = num_operands();
"The size %d of outputs[0] must be equal to size %d of inputs.", IR_ENFORCE(output_type.size() == input_num,
"The size %d of output must be equal to size %d of inputs.",
output_type.size(), output_type.size(),
inputs.size()); input_num);
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type // forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < input_num; ++i) {
IR_ENFORCE(output_type[i] == inputs[i].type(), auto type = (*this)->operand(i).type();
IR_ENFORCE(output_type[i] == type,
"The type %s of outputs[0][%d] must be " "The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].", "equal to type %s of inputs[%d].",
output_type[i], output_type[i],
i, i,
inputs[i].type(), type,
i); i);
} }
} }
const char *SliceOp::attributes_name[attributes_num] = {"index"}; const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::Verify(const std::vector<ir::OpResult> &inputs, void SliceOp::Verify() {
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// inputs.size() == 1 // inputs.size() == 1
IR_ENFORCE(inputs.size() == 1, auto input_size = num_operands();
"The size %d of inputs must be equal to 1.", IR_ENFORCE(
inputs.size()); input_size == 1, "The size %d of inputs must be equal to 1.", input_size);
// inputs[0].type == Vector<Type> // inputs[0].type == Vector<Type>
IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(), auto input_type = (*this)->operand(0).type().dyn_cast<ir::VectorType>();
IR_ENFORCE(input_type,
"The type %s of inputs[0] must be equal to VectorType.", "The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type()); input_type);
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();
auto output_size = num_results();
// outputs.size() == 1 // outputs.size() == 1
IR_ENFORCE(outputs.size() == 1, IR_ENFORCE(output_size == 1,
"The size %d of outputs must be equal to 1.", "The size %d of outputs must be equal to 1.",
outputs.size()); output_size);
// attributes contains index: Int32 // attributes contains index: Int32
auto &attributes = this->attributes();
IR_ENFORCE(attributes.count("index") != 0, IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index."); "The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index"); const ir::Attribute &attr = attributes.at("index");
...@@ -203,12 +198,13 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs, ...@@ -203,12 +198,13 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
input_type.size()); input_type.size());
// inputs[index].type == outputs[0].type // inputs[index].type == outputs[0].type
auto output_type = (*this)->result(0).type();
IR_ENFORCE( IR_ENFORCE(
input_type[index] == outputs[0], input_type[index] == output_type,
"The type %s of inputs[%d] must be equal to type %s of outputs[0].", "The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index], input_type[index],
index, index,
outputs[0]); output_type);
} }
const char *ConstantOp::attributes_name[attributes_num] = {"value"}; const char *ConstantOp::attributes_name[attributes_num] = {"value"};
...@@ -221,16 +217,13 @@ void ConstantOp::Build(Builder &builder, ...@@ -221,16 +217,13 @@ void ConstantOp::Build(Builder &builder,
argument.output_types.push_back(output_type); argument.output_types.push_back(output_type);
} }
void ConstantOp::Verify(const std::vector<ir::OpResult> &inputs, void ConstantOp::Verify() {
const std::vector<ir::Type> &outputs, IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
const ir::AttributeMap &attributes) { IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0."); IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes.count("value") > 0,
"Type of attribute: value is not right.");
} }
Attribute ConstantOp::value() { return operation()->attributes().at("value"); } Attribute ConstantOp::value() { return attributes().at("value"); }
} // namespace ir } // namespace ir
......
...@@ -30,10 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> { ...@@ -30,10 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> {
static const char *name() { return "builtin.module"; } static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
Program *program(); Program *program();
Block *block(); Block *block();
...@@ -58,9 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -58,9 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
OperationArgument &argument, // NOLINT OperationArgument &argument, // NOLINT
const std::string &name, const std::string &name,
Type type); Type type);
static void Verify(const std::vector<OpResult> &inputs, void Verify();
const std::vector<Type> &outputs,
const ir::AttributeMap &attributes);
}; };
/// ///
...@@ -77,9 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -77,9 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
OperationArgument &argument, // NOLINT OperationArgument &argument, // NOLINT
OpResult parameter, OpResult parameter,
const std::string &name); const std::string &name);
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
}; };
/// ///
...@@ -99,9 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> { ...@@ -99,9 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
OperationArgument &argument, // NOLINT OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs); const std::vector<ir::OpResult> &inputs);
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
}; };
/// ///
...@@ -116,9 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> { ...@@ -116,9 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
static constexpr uint32_t attributes_num = 1; static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
}; };
class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> { class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
...@@ -143,9 +132,7 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> { ...@@ -143,9 +132,7 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
Attribute value, Attribute value,
Type output_type); Type output_type);
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const AttributeMap &attributes);
Attribute value(); Attribute value();
}; };
......
...@@ -100,7 +100,7 @@ class IR_API Dialect { ...@@ -100,7 +100,7 @@ class IR_API Dialect {
ConcreteOp::GetTraitSet(), ConcreteOp::GetTraitSet(),
ConcreteOp::attributes_num, ConcreteOp::attributes_num,
ConcreteOp::attributes_name, ConcreteOp::attributes_name,
ConcreteOp::Verify); ConcreteOp::VerifyInvariants);
} }
void RegisterOp(const std::string &name, OpInfoImpl *op_info); void RegisterOp(const std::string &name, OpInfoImpl *op_info);
......
...@@ -32,6 +32,7 @@ class InterfaceValue; ...@@ -32,6 +32,7 @@ class InterfaceValue;
class Type; class Type;
class OpResult; class OpResult;
class Attribute; class Attribute;
class Operation;
using OpInfoMap = std::unordered_map<std::string, OpInfo>; using OpInfoMap = std::unordered_map<std::string, OpInfo>;
...@@ -102,18 +103,14 @@ class IR_API IrContext { ...@@ -102,18 +103,14 @@ class IR_API IrContext {
/// ///
/// \brief Register an op infomation to IrContext /// \brief Register an op infomation to IrContext
/// ///
void RegisterOpInfo( void RegisterOpInfo(Dialect *dialect,
Dialect *dialect, TypeId op_id,
TypeId op_id, const char *name,
const char *name, std::vector<InterfaceValue> &&interface_map,
std::vector<InterfaceValue> &&interface_map, const std::vector<TypeId> &trait_set,
const std::vector<TypeId> &trait_set, size_t attributes_num,
size_t attributes_num, const char **attributes_name,
const char **attributes_name, void (*verify)(Operation *));
void (*verify)(
const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes));
/// ///
/// \brief Get registered operaiton infomation. /// \brief Get registered operaiton infomation.
......
...@@ -78,6 +78,12 @@ class IR_API OpBase { ...@@ -78,6 +78,12 @@ class IR_API OpBase {
IrContext *ir_context() const { return operation_->ir_context(); } IrContext *ir_context() const { return operation_->ir_context(); }
uint32_t num_results() const { return operation_->num_results(); }
uint32_t num_operands() const { return operation_->num_operands(); }
const AttributeMap &attributes() const { return operation_->attributes(); }
private: private:
Operation *operation_; // Not owned Operation *operation_; // Not owned
}; };
...@@ -205,6 +211,16 @@ class Op : public OpBase { ...@@ -205,6 +211,16 @@ class Op : public OpBase {
ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait); ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
return trait_set; return trait_set;
} }
static constexpr bool HasNoDataMembers() {
class EmptyOp : public Op<EmptyOp, TraitOrInterface...> {};
return sizeof(ConcreteOp) == sizeof(EmptyOp);
}
static void VerifyInvariants(Operation *op) {
static_assert(HasNoDataMembers(),
"Op class shouldn't define new data members");
op->dyn_cast<ConcreteOp>().Verify();
}
}; };
} // namespace ir } // namespace ir
...@@ -35,11 +35,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; } ...@@ -35,11 +35,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); } TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); }
void OpInfo::Verify(const std::vector<OpResult> &inputs, void OpInfo::Verify(Operation *operation) const { impl_->verify()(operation); }
const std::vector<Type> &outputs,
const AttributeMap &attributes) {
impl_->verify()(inputs, outputs, attributes);
}
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const { void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr; return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr;
......
...@@ -25,6 +25,9 @@ class OpResult; ...@@ -25,6 +25,9 @@ class OpResult;
class Type; class Type;
class Attribute; class Attribute;
class Dialect; class Dialect;
class Operation;
typedef void (*VerifyPtr)(Operation *op);
class IR_API OpInfo { class IR_API OpInfo {
public: public:
...@@ -49,9 +52,7 @@ class IR_API OpInfo { ...@@ -49,9 +52,7 @@ class IR_API OpInfo {
TypeId id() const; TypeId id() const;
void Verify(const std::vector<OpResult> &inputs, void Verify(Operation *) const;
const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes);
template <typename Trait> template <typename Trait>
bool HasTrait() const { bool HasTrait() const {
......
...@@ -25,9 +25,6 @@ ...@@ -25,9 +25,6 @@
namespace ir { namespace ir {
class Dialect; class Dialect;
typedef void (*VerifyPtr)(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const AttributeMap &attributes);
/// ///
/// \brief OpInfoImpl class. /// \brief OpInfoImpl class.
......
...@@ -46,10 +46,6 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs, ...@@ -46,10 +46,6 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types, const std::vector<ir::Type> &output_types,
ir::OpInfo op_info, ir::OpInfo op_info,
size_t num_regions) { size_t num_regions) {
// 0. Verify
if (op_info) {
op_info.Verify(inputs, output_types, attributes);
}
// 1. Calculate the required memory size for OpResults + Operation + // 1. Calculate the required memory size for OpResults + Operation +
// OpOperands. // OpOperands.
uint32_t num_results = output_types.size(); uint32_t num_results = output_types.size();
...@@ -100,6 +96,11 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs, ...@@ -100,6 +96,11 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr += sizeof(Region); base_ptr += sizeof(Region);
} }
} }
// 0. Verify
if (op_info) {
op_info.Verify(op);
}
return op; return op;
} }
......
...@@ -45,9 +45,7 @@ class OperationTest ...@@ -45,9 +45,7 @@ class OperationTest
static const char *name() { return "test.operation2"; } static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, static void Verify() {}
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {}
static void InferShape(phi::InferMetaContext *infer_meta) { static void InferShape(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::CreateInferMeta); auto fn = PD_INFER_META(phi::CreateInferMeta);
fn(infer_meta); fn(infer_meta);
......
...@@ -90,9 +90,8 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -90,9 +90,8 @@ class Operation1 : public ir::Op<Operation1> {
static const char *name() { return "test.operation1"; } static const char *name() { return "test.operation1"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify() {
const std::vector<ir::Type> &outputs, auto &attributes = this->attributes();
const ir::AttributeMap &attributes) {
if (attributes.count("op1_attr1") == 0 || if (attributes.count("op1_attr1") == 0 ||
!attributes.at("op1_attr1").isa<ir::StrAttribute>()) { !attributes.at("op1_attr1").isa<ir::StrAttribute>()) {
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
...@@ -133,9 +132,8 @@ class Operation2 ...@@ -133,9 +132,8 @@ class Operation2
static const char *name() { return "test.operation2"; } static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify() {
const std::vector<ir::Type> &outputs, auto &attributes = this->attributes();
const ir::AttributeMap &attributes) {
if (attributes.count("op2_attr1") == 0 || if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<ir::StrAttribute>())) { (!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right."); throw("Type of attribute: parameter_name is not right.");
......
...@@ -38,22 +38,21 @@ class AddOp : public ir::Op<AddOp> { ...@@ -38,22 +38,21 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; } static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (inputs.size() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
static void Build(ir::Builder &builder, // NOLINT static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT ir::OperationArgument &argument, // NOLINT
ir::OpResult l_operand, ir::OpResult l_operand,
ir::OpResult r_operand, ir::OpResult r_operand,
ir::Type sum_type); ir::Type sum_type);
}; };
void AddOp::Verify() {
if (num_operands() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (num_results() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
void AddOp::Build(ir::Builder &, void AddOp::Build(ir::Builder &,
ir::OperationArgument &argument, ir::OperationArgument &argument,
ir::OpResult l_operand, ir::OpResult l_operand,
...@@ -262,9 +261,9 @@ TEST(program_test, builder) { ...@@ -262,9 +261,9 @@ TEST(program_test, builder) {
ir::Type full_op_output = full_op->result(0).type(); ir::Type full_op_output = full_op->result(0).type();
EXPECT_EQ(program.block()->size(), 1u); EXPECT_EQ(program.block()->size(), 1u);
EXPECT_EQ(program.block()->back(), full_op.operation()); EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands(), 0u); EXPECT_EQ(full_op.num_operands(), 0u);
EXPECT_EQ(full_op->num_results(), 1u); EXPECT_EQ(full_op.num_results(), 1u);
EXPECT_EQ(full_op->attributes().size(), 4u); EXPECT_EQ(full_op.attributes().size(), 4u);
EXPECT_EQ( EXPECT_EQ(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0, full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
true); true);
......
...@@ -65,22 +65,21 @@ class AddOp : public ir::Op<AddOp> { ...@@ -65,22 +65,21 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; } static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr; static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0; static constexpr uint32_t attributes_num = 0;
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (inputs.size() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
static void Build(ir::Builder &builder, // NOLINT static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT ir::OperationArgument &argument, // NOLINT
ir::OpResult l_operand, ir::OpResult l_operand,
ir::OpResult r_operand, ir::OpResult r_operand,
ir::Type sum_type); ir::Type sum_type);
}; };
void AddOp::Verify() {
if (num_operands() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (num_results() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
void AddOp::Build(ir::Builder &, void AddOp::Build(ir::Builder &,
ir::OperationArgument &argument, ir::OperationArgument &argument,
ir::OpResult l_operand, ir::OpResult l_operand,
......
...@@ -48,20 +48,20 @@ class Operation1 : public ir::Op<Operation1> { ...@@ -48,20 +48,20 @@ class Operation1 : public ir::Op<Operation1> {
static const char *name() { return "test.Operation1"; } static const char *name() { return "test.Operation1"; }
static constexpr uint32_t attributes_num = 2; static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num]; static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs, void Verify();
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
if (attributes.count("op2_attr2") == 0 ||
(!attributes.at("op2_attr2").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
}
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; } static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
}; };
void Operation1::Verify() {
auto &attributes = this->attributes();
if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
if (attributes.count("op2_attr2") == 0 ||
(!attributes.at("op2_attr2").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
}
const char *Operation1::attributes_name[attributes_num] = {"op2_attr1", const char *Operation1::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"}; "op2_attr2"};
IR_DECLARE_EXPLICIT_TYPE_ID(Operation1) IR_DECLARE_EXPLICIT_TYPE_ID(Operation1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册