未验证 提交 96652265 编写于 作者: W winter-wang 提交者: GitHub

[IR] rectify the verify api (#54895)

上级 e49c17d2
......@@ -20,9 +20,7 @@ namespace dialect {
const char *PhiKernelOp::attributes_name[attributes_num] = {
"base_op", "infermeta_fn", "kernel_fn"};
void PhiKernelOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void PhiKernelOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: PhiKernelOp.";
// Verify inputs type:
......
......@@ -26,9 +26,7 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> {
static const char *name() { return "phi.kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
};
} // namespace dialect
......
......@@ -16,6 +16,7 @@ import argparse
import os
import yaml
from op_verify_gen import gen_verify_func_str
# =====================================
# String Template for h file code gen
......@@ -65,7 +66,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static OpInfoTuple GetOpInfo();
static void Build({build_args});
{build_mutable_attr_is_input}
static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
void Verify();
{get_inputs_and_outputs}
{exclusive_interface}
}};
......@@ -141,105 +142,6 @@ void {op_name}::Build({build_args}) {{
{build_outputs}
}}
"""
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}.";
// Verify inputs type:
PADDLE_ENFORCE_EQ(inputs.size(), {inputs_size},
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", inputs.size()));
{inputs_type_check}
// Verify outputs type:
PADDLE_ENFORCE_EQ(outputs.size(), {outputs_size},
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", outputs.size()));
{outputs_type_check}
// Verify if attributes contain attribute name in attributes_name:
{attributes_check}
}}
"""
GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs;
(void)outputs;
(void)attributes;
}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (inputs[{index}]) {{
if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
"""
OUTPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
"""
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}
"""
OP_INFER_SHAPE_TEMPLATE = """
void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func});
......@@ -1004,8 +906,8 @@ def GenBuildOutputs(
}}
"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().operation()->attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
......@@ -1557,134 +1459,17 @@ def OpGenerator(
view=view_str,
)
# =================================== #
# gen Verify func str #
# =================================== #
# generate op verify function: inputs_type_check_str
if (
len(op_input_type_list) + len(op_mutable_attribute_name_list)
) == 0:
inputs_type_check_str = (
"// Inputs num is 0, not need to check inputs type."
)
else:
inputs_type_check_str = ""
for idx in range(len(op_input_type_list)):
input_type = op_input_type_list[idx]
is_optional = op_input_optional_list[idx]
is_vector = False
if input_type.startswith("ir::VectorType<"):
is_vector = True
input_type = input_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = (
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
)
else:
check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
if is_vector:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
inputs_type_check_str += check_str
for idx in range(len(op_mutable_attribute_name_list)):
mutable_attribute_type = op_mutable_attribute_type_list[idx][0]
check_str = ""
if mutable_attribute_type == "paddle::dialect::ScalarAttribute":
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
else:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
# generate op verify function: outputs_type_check_str
if len(op_output_type_list) == 0:
outputs_type_check_str = (
"// Outputs num is 0, not need to check outputs type."
)
else:
outputs_type_check_str = ""
for idx in range(len(op_output_type_list)):
output_type = op_output_type_list[idx]
is_optional = op_output_optional_list[idx]
is_vector = False
if output_type.startswith("ir::VectorType<"):
is_vector = True
output_type = output_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = (
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
)
else:
check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
if is_vector:
check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
outputs_type_check_str += check_str
# generate op verify function: attributes_check_str
if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = (
"// Attributes num is 0, not need to check attributes type."
)
else:
attributes_check_str = ""
for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_name = op_non_mutable_attribute_name_list[idx]
attribute_type = op_non_mutable_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1]
attributes_check_str += (
ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name,
standard=attribute_type,
)
)
else:
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
)
# generate op verify function
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
)
else:
op_verify_str = OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list)
+ len(op_mutable_attribute_type_list),
outputs_size=len(op_output_type_list),
inputs_type_check=inputs_type_check_str,
outputs_type_check=outputs_type_check_str,
attributes_check=attributes_check_str,
# generate op verify function str
op_verify_str = gen_verify_func_str(
op_class_name,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_output_type_list,
op_output_optional_list,
)
op_infer_shape_str = ""
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{
VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}.";
VLOG(4) << "Verifying inputs:";
{{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(input_size, {inputs_size}u,
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check}
}}
VLOG(4) << "Verifying attributes:";
{{{attributes_check}
}}
VLOG(4) << "Verifying outputs:";
{{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(output_size, {outputs_size}u,
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check}
}}
VLOG(4) << "End Verifying for: {op_name}.";
}}
"""
GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::Verify() {{}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """
if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}"""
ATTRIBUTE_CHECK_TEMPLATE = """
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}"""
OUTPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->result({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """
auto output_{index}_type = (*this)->result({index}).type();
if (auto vec_type = output_{index}_type.dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
else {{
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto output_{index} = (*this)->result({index})) {{
PADDLE_ENFORCE(output_{index}.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto output_{index}_type = (*this)->result({index}).type()) {{
if (auto vec_type = output_{index}_type.dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
else {{
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}"""
# generate inputs_type_check_str
def gen_inputs_type_check_str(
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
):
if (len(op_input_type_list) + len(op_mutable_attribute_name_list)) == 0:
inputs_type_check_str = """
// Inputs num is 0, not need to check inputs type."""
else:
inputs_type_check_str = ""
for idx in range(len(op_input_type_list)):
input_type = op_input_type_list[idx]
is_optional = op_input_optional_list[idx]
is_vector = False
if input_type.startswith("ir::VectorType<"):
is_vector = True
input_type = input_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
if is_vector:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
else:
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
inputs_type_check_str += check_str
for idx in range(len(op_mutable_attribute_name_list)):
mutable_attribute_type = op_mutable_attribute_type_list[idx][0]
check_str = ""
if mutable_attribute_type == "paddle::dialect::ScalarAttribute":
check_str = INPUT_TYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
else:
check_str = INPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx + len(op_input_type_list),
standard="paddle::dialect::DenseTensorType",
)
inputs_type_check_str += check_str
return inputs_type_check_str
# generate attributes_check_str
def gen_attributes_type_check_str(
op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list
):
if len(op_non_mutable_attribute_name_list) == 0:
attributes_check_str = """
// Attributes num is 0, not need to check attributes type."""
else:
attributes_check_str = """
auto& attributes = this->attributes();"""
for idx in range(len(op_non_mutable_attribute_name_list)):
attribute_name = op_non_mutable_attribute_name_list[idx]
attribute_type = op_non_mutable_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1]
attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name,
standard=attribute_type,
)
else:
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
)
return attributes_check_str
# generate outputs_type_check_str
def gen_outputs_type_check_str(op_output_type_list, op_output_optional_list):
if len(op_output_type_list) == 0:
outputs_type_check_str = """
// Outputs num is 0, not need to check outputs type."""
else:
outputs_type_check_str = ""
for idx in range(len(op_output_type_list)):
output_type = op_output_type_list[idx]
is_optional = op_output_optional_list[idx]
is_vector = False
if output_type.startswith("ir::VectorType<"):
is_vector = True
output_type = output_type[15:-1]
check_str = ""
if is_optional == "true":
if is_vector:
check_str = OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
if is_vector:
check_str = OUTPUT_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
else:
check_str = OUTPUT_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=output_type
)
outputs_type_check_str += check_str
return outputs_type_check_str
# generate op verify function
def gen_verify_func_str(
op_class_name,
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_output_type_list,
op_output_optional_list,
):
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
return GRAD_OP_VERIFY_TEMPLATE.format(op_name=op_class_name)
inputs_type_check_str = gen_inputs_type_check_str(
op_input_type_list,
op_input_optional_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
)
attributes_type_check_str = gen_attributes_type_check_str(
op_non_mutable_attribute_name_list, op_non_mutable_attribute_type_list
)
outputs_type_check_str = gen_outputs_type_check_str(
op_output_type_list, op_output_optional_list
)
return OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list)
+ len(op_mutable_attribute_type_list),
inputs_type_check=inputs_type_check_str,
attributes_check=attributes_type_check_str,
outputs_size=len(op_output_type_list),
outputs_type_check=outputs_type_check_str,
)
......@@ -23,7 +23,7 @@ namespace ir {
const char *ModuleOp::attributes_name[attributes_num] = {"program"};
Program *ModuleOp::program() {
const AttributeMap &attr = operation()->attributes();
const AttributeMap &attr = this->attributes();
auto iter = attr.find("program");
if (iter == attr.end() || !iter->second) return nullptr;
return static_cast<Program *>(
......@@ -52,20 +52,19 @@ void ModuleOp::Destroy() {
}
}
void ModuleOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void ModuleOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: ModuleOp.";
// Verify inputs type:
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name:
// Verify attributes:
auto &attributes = this->attributes();
auto iter = attributes.find("program");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<PointerAttribute>(),
"Type of attribute: program is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
// Verify outputs:
IR_ENFORCE(num_results() == 0u, "The size of inputs must be equal to 0.");
}
const char *GetParameterOp::attributes_name[attributes_num] = {
......@@ -80,20 +79,19 @@ void GetParameterOp::Build(Builder &builder,
argument.output_types.emplace_back(type);
}
void GetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void GetParameterOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: GetParameterOp.";
// Verify inputs type:
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
// Verify inputs:
IR_ENFORCE(num_operands() == 0u, "The size of inputs must be equal to 0.");
// Verify if attributes contain attribute name in attributes_name:
auto &attributes = this->attributes();
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
}
const char *SetParameterOp::attributes_name[attributes_num] = {
......@@ -107,20 +105,19 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument.AddAttribute(attributes_name[0],
ir::StrAttribute::get(builder.ir_context(), name));
}
void SetParameterOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void SetParameterOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: SetParameterOp.";
// Verify inputs type:
IR_ENFORCE(inputs.size() == 1, "The size of outputs must be equal to 1.");
// Verify inputs:
IR_ENFORCE(num_operands() == 1, "The size of outputs must be equal to 1.");
// Verify if attributes contain attribute name in attributes_name:
// Verify attributes:
auto &attributes = this->attributes();
auto iter = attributes.find("parameter_name");
IR_ENFORCE(iter != attributes.end() && iter->second.isa<StrAttribute>(),
"Type of attribute: parameter_name is not right.");
// Verify outputs type:
IR_ENFORCE(outputs.size() == 0, "The size of outputs must be equal to 0.");
// Verify outputs:
IR_ENFORCE(num_results() == 0u, "The size of outputs must be equal to 0.");
}
void CombineOp::Build(Builder &builder,
......@@ -135,58 +132,56 @@ void CombineOp::Build(Builder &builder,
ir::VectorType::get(builder.ir_context(), inputs_type));
}
void CombineOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void CombineOp::Verify() {
// outputs.size() == 1
IR_ENFORCE(outputs.size() == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());
IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1.");
// output_type == Vector<Type>
auto output_type = (*this)->result(0).type().dyn_cast<VectorType>();
IR_ENFORCE(output_type,
"The type of outputs[0] must be equal to VectorType.");
// outputs[0].type == Vector<Type>
IR_ENFORCE(outputs[0].isa<ir::VectorType>(),
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]);
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size()
IR_ENFORCE(output_type.size() == inputs.size(),
"The size %d of outputs[0] must be equal to size %d of inputs.",
auto input_num = num_operands();
IR_ENFORCE(output_type.size() == input_num,
"The size %d of output must be equal to size %d of inputs.",
output_type.size(),
inputs.size());
input_num);
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) {
IR_ENFORCE(output_type[i] == inputs[i].type(),
for (size_t i = 0; i < input_num; ++i) {
auto type = (*this)->operand(i).type();
IR_ENFORCE(output_type[i] == type,
"The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].",
output_type[i],
i,
inputs[i].type(),
type,
i);
}
}
const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void SliceOp::Verify() {
// inputs.size() == 1
IR_ENFORCE(inputs.size() == 1,
"The size %d of inputs must be equal to 1.",
inputs.size());
auto input_size = num_operands();
IR_ENFORCE(
input_size == 1, "The size %d of inputs must be equal to 1.", input_size);
// inputs[0].type == Vector<Type>
IR_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
auto input_type = (*this)->operand(0).type().dyn_cast<ir::VectorType>();
IR_ENFORCE(input_type,
"The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type());
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();
input_type);
auto output_size = num_results();
// outputs.size() == 1
IR_ENFORCE(outputs.size() == 1,
IR_ENFORCE(output_size == 1,
"The size %d of outputs must be equal to 1.",
outputs.size());
output_size);
// attributes contains index: Int32
auto &attributes = this->attributes();
IR_ENFORCE(attributes.count("index") != 0,
"The attributes must contains index.");
const ir::Attribute &attr = attributes.at("index");
......@@ -203,12 +198,13 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
input_type.size());
// inputs[index].type == outputs[0].type
auto output_type = (*this)->result(0).type();
IR_ENFORCE(
input_type[index] == outputs[0],
input_type[index] == output_type,
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
index,
outputs[0]);
output_type);
}
const char *ConstantOp::attributes_name[attributes_num] = {"value"};
......@@ -221,16 +217,13 @@ void ConstantOp::Build(Builder &builder,
argument.output_types.push_back(output_type);
}
void ConstantOp::Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
IR_ENFORCE(inputs.size() == 0, "The size of inputs must be equal to 0.");
IR_ENFORCE(outputs.size() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes.count("value") > 0,
"Type of attribute: value is not right.");
void ConstantOp::Verify() {
IR_ENFORCE(num_operands() == 0, "The size of inputs must be equal to 0.");
IR_ENFORCE(num_results() == 1, "The size of outputs must be equal to 1.");
IR_ENFORCE(attributes().count("value") > 0, "must has value attribute");
}
Attribute ConstantOp::value() { return operation()->attributes().at("value"); }
Attribute ConstantOp::value() { return attributes().at("value"); }
} // namespace ir
......
......@@ -30,10 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> {
static const char *name() { return "builtin.module"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
Program *program();
Block *block();
......@@ -58,9 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
OperationArgument &argument, // NOLINT
const std::string &name,
Type type);
static void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
};
///
......@@ -77,9 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
OperationArgument &argument, // NOLINT
OpResult parameter,
const std::string &name);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
};
///
......@@ -99,9 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
OperationArgument &argument, // NOLINT
const std::vector<ir::OpResult> &inputs);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
};
///
......@@ -116,9 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
void Verify();
};
class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
......@@ -143,9 +132,7 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
Attribute value,
Type output_type);
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const AttributeMap &attributes);
void Verify();
Attribute value();
};
......
......@@ -100,7 +100,7 @@ class IR_API Dialect {
ConcreteOp::GetTraitSet(),
ConcreteOp::attributes_num,
ConcreteOp::attributes_name,
ConcreteOp::Verify);
ConcreteOp::VerifyInvariants);
}
void RegisterOp(const std::string &name, OpInfoImpl *op_info);
......
......@@ -32,6 +32,7 @@ class InterfaceValue;
class Type;
class OpResult;
class Attribute;
class Operation;
using OpInfoMap = std::unordered_map<std::string, OpInfo>;
......@@ -102,18 +103,14 @@ class IR_API IrContext {
///
/// \brief Register an op infomation to IrContext
///
void RegisterOpInfo(
Dialect *dialect,
void RegisterOpInfo(Dialect *dialect,
TypeId op_id,
const char *name,
std::vector<InterfaceValue> &&interface_map,
const std::vector<TypeId> &trait_set,
size_t attributes_num,
const char **attributes_name,
void (*verify)(
const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes));
void (*verify)(Operation *));
///
/// \brief Get registered operaiton infomation.
......
......@@ -78,6 +78,12 @@ class IR_API OpBase {
IrContext *ir_context() const { return operation_->ir_context(); }
uint32_t num_results() const { return operation_->num_results(); }
uint32_t num_operands() const { return operation_->num_operands(); }
const AttributeMap &attributes() const { return operation_->attributes(); }
private:
Operation *operation_; // Not owned
};
......@@ -205,6 +211,16 @@ class Op : public OpBase {
ConstructInterfacesOrTraits<ConcreteOp, TraitList>::trait(p_first_trait);
return trait_set;
}
static constexpr bool HasNoDataMembers() {
class EmptyOp : public Op<EmptyOp, TraitOrInterface...> {};
return sizeof(ConcreteOp) == sizeof(EmptyOp);
}
static void VerifyInvariants(Operation *op) {
static_assert(HasNoDataMembers(),
"Op class shouldn't define new data members");
op->dyn_cast<ConcreteOp>().Verify();
}
};
} // namespace ir
......@@ -35,11 +35,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId OpInfo::id() const { return impl_ ? impl_->id() : TypeId(); }
void OpInfo::Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const AttributeMap &attributes) {
impl_->verify()(inputs, outputs, attributes);
}
void OpInfo::Verify(Operation *operation) const { impl_->verify()(operation); }
void *OpInfo::GetInterfaceImpl(TypeId interface_id) const {
return impl_ ? impl_->GetInterfaceImpl(interface_id) : nullptr;
......
......@@ -25,6 +25,9 @@ class OpResult;
class Type;
class Attribute;
class Dialect;
class Operation;
typedef void (*VerifyPtr)(Operation *op);
class IR_API OpInfo {
public:
......@@ -49,9 +52,7 @@ class IR_API OpInfo {
TypeId id() const;
void Verify(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const std::unordered_map<std::string, Attribute> &attributes);
void Verify(Operation *) const;
template <typename Trait>
bool HasTrait() const {
......
......@@ -25,9 +25,6 @@
namespace ir {
class Dialect;
typedef void (*VerifyPtr)(const std::vector<OpResult> &inputs,
const std::vector<Type> &outputs,
const AttributeMap &attributes);
///
/// \brief OpInfoImpl class.
......
......@@ -46,10 +46,6 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &output_types,
ir::OpInfo op_info,
size_t num_regions) {
// 0. Verify
if (op_info) {
op_info.Verify(inputs, output_types, attributes);
}
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
uint32_t num_results = output_types.size();
......@@ -100,6 +96,11 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr += sizeof(Region);
}
}
// 0. Verify
if (op_info) {
op_info.Verify(op);
}
return op;
}
......
......@@ -45,9 +45,7 @@ class OperationTest
static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {}
static void Verify() {}
static void InferShape(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::CreateInferMeta);
fn(infer_meta);
......
......@@ -90,9 +90,8 @@ class Operation1 : public ir::Op<Operation1> {
static const char *name() { return "test.operation1"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void Verify() {
auto &attributes = this->attributes();
if (attributes.count("op1_attr1") == 0 ||
!attributes.at("op1_attr1").isa<ir::StrAttribute>()) {
throw("Type of attribute: parameter_name is not right.");
......@@ -133,9 +132,8 @@ class Operation2
static const char *name() { return "test.operation2"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void Verify() {
auto &attributes = this->attributes();
if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
......
......@@ -38,22 +38,21 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (inputs.size() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
void Verify();
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult l_operand,
ir::OpResult r_operand,
ir::Type sum_type);
};
void AddOp::Verify() {
if (num_operands() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (num_results() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
void AddOp::Build(ir::Builder &,
ir::OperationArgument &argument,
ir::OpResult l_operand,
......@@ -262,9 +261,9 @@ TEST(program_test, builder) {
ir::Type full_op_output = full_op->result(0).type();
EXPECT_EQ(program.block()->size(), 1u);
EXPECT_EQ(program.block()->back(), full_op.operation());
EXPECT_EQ(full_op->num_operands(), 0u);
EXPECT_EQ(full_op->num_results(), 1u);
EXPECT_EQ(full_op->attributes().size(), 4u);
EXPECT_EQ(full_op.num_operands(), 0u);
EXPECT_EQ(full_op.num_results(), 1u);
EXPECT_EQ(full_op.attributes().size(), 4u);
EXPECT_EQ(
full_op_output.dyn_cast<paddle::dialect::DenseTensorType>().offset() == 0,
true);
......
......@@ -65,22 +65,21 @@ class AddOp : public ir::Op<AddOp> {
static const char *name() { return "test.add"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
if (inputs.size() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (outputs.size() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
void Verify();
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult l_operand,
ir::OpResult r_operand,
ir::Type sum_type);
};
void AddOp::Verify() {
if (num_operands() != 2) {
throw("The size of inputs must be equal to 2.");
}
if (num_results() != 1) {
throw("The size of outputs must be equal to 1.");
}
}
void AddOp::Build(ir::Builder &,
ir::OperationArgument &argument,
ir::OpResult l_operand,
......
......@@ -48,9 +48,11 @@ class Operation1 : public ir::Op<Operation1> {
static const char *name() { return "test.Operation1"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
void Verify();
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
};
void Operation1::Verify() {
auto &attributes = this->attributes();
if (attributes.count("op2_attr1") == 0 ||
(!attributes.at("op2_attr1").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
......@@ -59,9 +61,7 @@ class Operation1 : public ir::Op<Operation1> {
(!attributes.at("op2_attr2").isa<ir::StrAttribute>())) {
throw("Type of attribute: parameter_name is not right.");
}
}
static void InferShape() { VLOG(2) << "This is op2's InferShape interface."; }
};
}
const char *Operation1::attributes_name[attributes_num] = {"op2_attr1",
"op2_attr2"};
IR_DECLARE_EXPLICIT_TYPE_ID(Operation1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册