Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
96652265
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
96652265
编写于
6月 27, 2023
作者:
W
winter-wang
提交者:
GitHub
6月 27, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IR] rectify the verify api (#54895)
上级
e49c17d2
变更
18
显示空白变更内容
内联
并排
Showing
18 changed file
with
427 addition
and
389 deletion
+427
-389
paddle/fluid/ir/dialect/kernel_op.cc
paddle/fluid/ir/dialect/kernel_op.cc
+1
-3
paddle/fluid/ir/dialect/kernel_op.h
paddle/fluid/ir/dialect/kernel_op.h
+1
-3
paddle/fluid/ir/dialect/op_gen.py
paddle/fluid/ir/dialect/op_gen.py
+16
-231
paddle/fluid/ir/dialect/op_verify_gen.py
paddle/fluid/ir/dialect/op_verify_gen.py
+275
-0
paddle/ir/core/builtin_op.cc
paddle/ir/core/builtin_op.cc
+54
-61
paddle/ir/core/builtin_op.h
paddle/ir/core/builtin_op.h
+6
-19
paddle/ir/core/dialect.h
paddle/ir/core/dialect.h
+1
-1
paddle/ir/core/ir_context.h
paddle/ir/core/ir_context.h
+9
-12
paddle/ir/core/op_base.h
paddle/ir/core/op_base.h
+16
-0
paddle/ir/core/op_info.cc
paddle/ir/core/op_info.cc
+1
-5
paddle/ir/core/op_info.h
paddle/ir/core/op_info.h
+4
-3
paddle/ir/core/op_info_impl.h
paddle/ir/core/op_info_impl.h
+0
-3
paddle/ir/core/operation.cc
paddle/ir/core/operation.cc
+5
-4
test/cpp/ir/core/ir_infershape_test.cc
test/cpp/ir/core/ir_infershape_test.cc
+1
-3
test/cpp/ir/core/ir_op_test.cc
test/cpp/ir/core/ir_op_test.cc
+4
-6
test/cpp/ir/core/ir_program_test.cc
test/cpp/ir/core/ir_program_test.cc
+12
-13
test/cpp/ir/pass/pass_manager_test.cc
test/cpp/ir/pass/pass_manager_test.cc
+9
-10
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
+12
-12
未找到文件。
paddle/fluid/ir/dialect/kernel_op.cc
浏览文件 @
96652265
...
...
@@ -20,9 +20,7 @@ namespace dialect {
const
char
*
PhiKernelOp
::
attributes_name
[
attributes_num
]
=
{
"base_op"
,
"infermeta_fn"
,
"kernel_fn"
};
void
PhiKernelOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
PhiKernelOp
::
Verify
()
{
VLOG
(
4
)
<<
"Verifying inputs, outputs and attributes for: PhiKernelOp."
;
// Verify inputs type:
...
...
paddle/fluid/ir/dialect/kernel_op.h
浏览文件 @
96652265
...
...
@@ -26,9 +26,7 @@ class PhiKernelOp : public ir::Op<PhiKernelOp> {
static
const
char
*
name
()
{
return
"phi.kernel"
;
}
static
constexpr
uint32_t
attributes_num
=
3
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
void
Verify
();
};
}
// namespace dialect
...
...
paddle/fluid/ir/dialect/op_gen.py
浏览文件 @
96652265
...
...
@@ -16,6 +16,7 @@ import argparse
import
os
import
yaml
from
op_verify_gen
import
gen_verify_func_str
# =====================================
# String Template for h file code gen
...
...
@@ -65,7 +66,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static OpInfoTuple GetOpInfo();
static void Build({build_args});
{build_mutable_attr_is_input}
static void Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes
);
void Verify(
);
{get_inputs_and_outputs}
{exclusive_interface}
}};
...
...
@@ -141,105 +142,6 @@ void {op_name}::Build({build_args}) {{
{build_outputs}
}}
"""
# verify
OP_VERIFY_TEMPLATE
=
"""
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}.";
// Verify inputs type:
PADDLE_ENFORCE_EQ(inputs.size(), {inputs_size},
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", inputs.size()));
{inputs_type_check}
// Verify outputs type:
PADDLE_ENFORCE_EQ(outputs.size(), {outputs_size},
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", outputs.size()));
{outputs_type_check}
// Verify if attributes contain attribute name in attributes_name:
{attributes_check}
}}
"""
GRAD_OP_VERIFY_TEMPLATE
=
"""
void {op_name}::Verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs;
(void)outputs;
(void)attributes;
}}
"""
INPUT_TYPE_CHECK_TEMPLATE
=
"""PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
"""
INPUT_VECTORTYPE_CHECK_TEMPLATE
=
"""if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
=
"""if (inputs[{index}]) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
=
"""if (inputs[{index}]) {{
if (inputs[{index}].type().isa<ir::VectorType>()) {{
for (size_t i = 0; i < inputs[{index}].type().dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}} else {{
PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
"""
OUTPUT_TYPE_CHECK_TEMPLATE
=
"""PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE
=
"""if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
=
"""if (outputs[{index}]) {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
=
"""if (outputs[{index}]) {{
if (outputs[{index}].isa<ir::VectorType>()) {{
for (size_t i = 0; i < outputs[{index}].dyn_cast<ir::VectorType>().size(); i++) {{
PADDLE_ENFORCE_EQ(outputs[{index}].dyn_cast<ir::VectorType>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}} else {{
PADDLE_ENFORCE_EQ(outputs[{index}].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
"""
ATTRIBUTE_CHECK_TEMPLATE
=
"""PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE
=
"""PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}
"""
OP_INFER_SHAPE_TEMPLATE
=
"""
void {op_name}::InferShape( phi::InferMetaContext *infer_meta ) {{
auto fn = PD_INFER_META(phi::{infer_meta_func});
...
...
@@ -1004,8 +906,8 @@ def GenBuildOutputs(
}}
"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE
=
""" std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().
operation()->
attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};
\n
"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE
=
""" {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().
operation()->
attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};
\n
"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE
=
""" std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullIntArrayOp>().attributes().at("value").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};
\n
"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE
=
""" {dtype} {name} = {name}_.owner()->dyn_cast<paddle::dialect::FullOp>().attributes().at("value").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};
\n
"""
CREATE_OUTPUT_METATENSOR_TEMPLATE
=
""" phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
...
...
@@ -1557,134 +1459,17 @@ def OpGenerator(
view
=
view_str
,
)
# =================================== #
# gen Verify func str #
# =================================== #
# generate op verify function: inputs_type_check_str
if
(
len
(
op_input_type_list
)
+
len
(
op_mutable_attribute_name_list
)
)
==
0
:
inputs_type_check_str
=
(
"// Inputs num is 0, not need to check inputs type."
)
else
:
inputs_type_check_str
=
""
for
idx
in
range
(
len
(
op_input_type_list
)):
input_type
=
op_input_type_list
[
idx
]
is_optional
=
op_input_optional_list
[
idx
]
is_vector
=
False
if
input_type
.
startswith
(
"ir::VectorType<"
):
is_vector
=
True
input_type
=
input_type
[
15
:
-
1
]
check_str
=
""
if
is_optional
==
"true"
:
if
is_vector
:
check_str
=
(
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
)
else
:
check_str
=
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
else
:
if
is_vector
:
check_str
=
INPUT_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
else
:
check_str
=
INPUT_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
inputs_type_check_str
+=
check_str
for
idx
in
range
(
len
(
op_mutable_attribute_name_list
)):
mutable_attribute_type
=
op_mutable_attribute_type_list
[
idx
][
0
]
check_str
=
""
if
mutable_attribute_type
==
"paddle::dialect::ScalarAttribute"
:
check_str
=
INPUT_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
+
len
(
op_input_type_list
),
standard
=
"paddle::dialect::DenseTensorType"
,
)
else
:
check_str
=
INPUT_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
+
len
(
op_input_type_list
),
standard
=
"paddle::dialect::DenseTensorType"
,
)
inputs_type_check_str
+=
check_str
# generate op verify function: outputs_type_check_str
if
len
(
op_output_type_list
)
==
0
:
outputs_type_check_str
=
(
"// Outputs num is 0, not need to check outputs type."
)
else
:
outputs_type_check_str
=
""
for
idx
in
range
(
len
(
op_output_type_list
)):
output_type
=
op_output_type_list
[
idx
]
is_optional
=
op_output_optional_list
[
idx
]
is_vector
=
False
if
output_type
.
startswith
(
"ir::VectorType<"
):
is_vector
=
True
output_type
=
output_type
[
15
:
-
1
]
check_str
=
""
if
is_optional
==
"true"
:
if
is_vector
:
check_str
=
(
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
)
else
:
check_str
=
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
else
:
if
is_vector
:
check_str
=
OUTPUT_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
else
:
check_str
=
OUTPUT_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
outputs_type_check_str
+=
check_str
# generate op verify function: attributes_check_str
if
len
(
op_non_mutable_attribute_name_list
)
==
0
:
attributes_check_str
=
(
"// Attributes num is 0, not need to check attributes type."
)
else
:
attributes_check_str
=
""
for
idx
in
range
(
len
(
op_non_mutable_attribute_name_list
)):
attribute_name
=
op_non_mutable_attribute_name_list
[
idx
]
attribute_type
=
op_non_mutable_attribute_type_list
[
idx
]
if
attribute_type
.
startswith
(
"ir::ArrayAttribute<"
):
attribute_type
=
attribute_type
[
19
:
-
1
]
attributes_check_str
+=
(
ATTRIBUTE_VECTOR_CHECK_TEMPLATE
.
format
(
attribute_name
=
attribute_name
,
standard
=
attribute_type
,
)
)
else
:
attributes_check_str
+=
ATTRIBUTE_CHECK_TEMPLATE
.
format
(
attribute_name
=
attribute_name
,
standard
=
attribute_type
)
# generate op verify function
if
"GradOp"
in
op_class_name
or
"Grad_Op"
in
op_class_name
:
op_verify_str
=
GRAD_OP_VERIFY_TEMPLATE
.
format
(
op_name
=
op_class_name
,
)
else
:
op_verify_str
=
OP_VERIFY_TEMPLATE
.
format
(
op_name
=
op_class_name
,
inputs_size
=
len
(
op_input_type_list
)
+
len
(
op_mutable_attribute_type_list
),
outputs_size
=
len
(
op_output_type_list
),
inputs_type_check
=
inputs_type_check_str
,
outputs_type_check
=
outputs_type_check_str
,
attributes_check
=
attributes_check_str
,
# generate op verify function str
op_verify_str
=
gen_verify_func_str
(
op_class_name
,
op_input_type_list
,
op_input_optional_list
,
op_mutable_attribute_name_list
,
op_mutable_attribute_type_list
,
op_non_mutable_attribute_name_list
,
op_non_mutable_attribute_type_list
,
op_output_type_list
,
op_output_optional_list
,
)
op_infer_shape_str
=
""
...
...
paddle/fluid/ir/dialect/op_verify_gen.py
0 → 100644
浏览文件 @
96652265
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# verify
OP_VERIFY_TEMPLATE
=
"""
void {op_name}::Verify() {{
VLOG(4) << "Start Verifying inputs, outputs and attributes for: {op_name}.";
VLOG(4) << "Verifying inputs:";
{{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(input_size, {inputs_size}u,
phi::errors::PreconditionNotMet("The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check}
}}
VLOG(4) << "Verifying attributes:";
{{{attributes_check}
}}
VLOG(4) << "Verifying outputs:";
{{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(output_size, {outputs_size}u,
phi::errors::PreconditionNotMet("The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check}
}}
VLOG(4) << "End Verifying for: {op_name}.";
}}
"""
GRAD_OP_VERIFY_TEMPLATE
=
"""
void {op_name}::Verify() {{}}
"""
INPUT_TYPE_CHECK_TEMPLATE
=
"""
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE
=
"""
if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
=
"""
if (auto val = (*this)->operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
=
"""
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}"""
ATTRIBUTE_CHECK_TEMPLATE
=
"""
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE
=
"""
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}"""
OUTPUT_TYPE_CHECK_TEMPLATE
=
"""
PADDLE_ENFORCE((*this)->result({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE
=
"""
auto output_{index}_type = (*this)->result({index}).type();
if (auto vec_type = output_{index}_type.dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
else {{
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
=
"""
if (auto output_{index} = (*this)->result({index})) {{
PADDLE_ENFORCE(output_{index}.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
=
"""
if (auto output_{index}_type = (*this)->result({index}).type()) {{
if (auto vec_type = output_{index}_type.dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}
else {{
PADDLE_ENFORCE(output_{index}_type.isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th output."));
}}
}}"""
# generate inputs_type_check_str
def
gen_inputs_type_check_str
(
op_input_type_list
,
op_input_optional_list
,
op_mutable_attribute_name_list
,
op_mutable_attribute_type_list
,
):
if
(
len
(
op_input_type_list
)
+
len
(
op_mutable_attribute_name_list
))
==
0
:
inputs_type_check_str
=
"""
// Inputs num is 0, not need to check inputs type."""
else
:
inputs_type_check_str
=
""
for
idx
in
range
(
len
(
op_input_type_list
)):
input_type
=
op_input_type_list
[
idx
]
is_optional
=
op_input_optional_list
[
idx
]
is_vector
=
False
if
input_type
.
startswith
(
"ir::VectorType<"
):
is_vector
=
True
input_type
=
input_type
[
15
:
-
1
]
check_str
=
""
if
is_optional
==
"true"
:
if
is_vector
:
check_str
=
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
else
:
check_str
=
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
else
:
if
is_vector
:
check_str
=
INPUT_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
else
:
check_str
=
INPUT_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
input_type
)
inputs_type_check_str
+=
check_str
for
idx
in
range
(
len
(
op_mutable_attribute_name_list
)):
mutable_attribute_type
=
op_mutable_attribute_type_list
[
idx
][
0
]
check_str
=
""
if
mutable_attribute_type
==
"paddle::dialect::ScalarAttribute"
:
check_str
=
INPUT_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
+
len
(
op_input_type_list
),
standard
=
"paddle::dialect::DenseTensorType"
,
)
else
:
check_str
=
INPUT_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
+
len
(
op_input_type_list
),
standard
=
"paddle::dialect::DenseTensorType"
,
)
inputs_type_check_str
+=
check_str
return
inputs_type_check_str
# generate attributes_check_str
def
gen_attributes_type_check_str
(
op_non_mutable_attribute_name_list
,
op_non_mutable_attribute_type_list
):
if
len
(
op_non_mutable_attribute_name_list
)
==
0
:
attributes_check_str
=
"""
// Attributes num is 0, not need to check attributes type."""
else
:
attributes_check_str
=
"""
auto& attributes = this->attributes();"""
for
idx
in
range
(
len
(
op_non_mutable_attribute_name_list
)):
attribute_name
=
op_non_mutable_attribute_name_list
[
idx
]
attribute_type
=
op_non_mutable_attribute_type_list
[
idx
]
if
attribute_type
.
startswith
(
"ir::ArrayAttribute<"
):
attribute_type
=
attribute_type
[
19
:
-
1
]
attributes_check_str
+=
ATTRIBUTE_VECTOR_CHECK_TEMPLATE
.
format
(
attribute_name
=
attribute_name
,
standard
=
attribute_type
,
)
else
:
attributes_check_str
+=
ATTRIBUTE_CHECK_TEMPLATE
.
format
(
attribute_name
=
attribute_name
,
standard
=
attribute_type
)
return
attributes_check_str
# generate outputs_type_check_str
def
gen_outputs_type_check_str
(
op_output_type_list
,
op_output_optional_list
):
if
len
(
op_output_type_list
)
==
0
:
outputs_type_check_str
=
"""
// Outputs num is 0, not need to check outputs type."""
else
:
outputs_type_check_str
=
""
for
idx
in
range
(
len
(
op_output_type_list
)):
output_type
=
op_output_type_list
[
idx
]
is_optional
=
op_output_optional_list
[
idx
]
is_vector
=
False
if
output_type
.
startswith
(
"ir::VectorType<"
):
is_vector
=
True
output_type
=
output_type
[
15
:
-
1
]
check_str
=
""
if
is_optional
==
"true"
:
if
is_vector
:
check_str
=
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
else
:
check_str
=
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
else
:
if
is_vector
:
check_str
=
OUTPUT_VECTORTYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
else
:
check_str
=
OUTPUT_TYPE_CHECK_TEMPLATE
.
format
(
index
=
idx
,
standard
=
output_type
)
outputs_type_check_str
+=
check_str
return
outputs_type_check_str
# generate op verify function
def
gen_verify_func_str
(
op_class_name
,
op_input_type_list
,
op_input_optional_list
,
op_mutable_attribute_name_list
,
op_mutable_attribute_type_list
,
op_non_mutable_attribute_name_list
,
op_non_mutable_attribute_type_list
,
op_output_type_list
,
op_output_optional_list
,
):
if
"GradOp"
in
op_class_name
or
"Grad_Op"
in
op_class_name
:
return
GRAD_OP_VERIFY_TEMPLATE
.
format
(
op_name
=
op_class_name
)
inputs_type_check_str
=
gen_inputs_type_check_str
(
op_input_type_list
,
op_input_optional_list
,
op_mutable_attribute_name_list
,
op_mutable_attribute_type_list
,
)
attributes_type_check_str
=
gen_attributes_type_check_str
(
op_non_mutable_attribute_name_list
,
op_non_mutable_attribute_type_list
)
outputs_type_check_str
=
gen_outputs_type_check_str
(
op_output_type_list
,
op_output_optional_list
)
return
OP_VERIFY_TEMPLATE
.
format
(
op_name
=
op_class_name
,
inputs_size
=
len
(
op_input_type_list
)
+
len
(
op_mutable_attribute_type_list
),
inputs_type_check
=
inputs_type_check_str
,
attributes_check
=
attributes_type_check_str
,
outputs_size
=
len
(
op_output_type_list
),
outputs_type_check
=
outputs_type_check_str
,
)
paddle/ir/core/builtin_op.cc
浏览文件 @
96652265
...
...
@@ -23,7 +23,7 @@ namespace ir {
const
char
*
ModuleOp
::
attributes_name
[
attributes_num
]
=
{
"program"
};
Program
*
ModuleOp
::
program
()
{
const
AttributeMap
&
attr
=
operation
()
->
attributes
();
const
AttributeMap
&
attr
=
this
->
attributes
();
auto
iter
=
attr
.
find
(
"program"
);
if
(
iter
==
attr
.
end
()
||
!
iter
->
second
)
return
nullptr
;
return
static_cast
<
Program
*>
(
...
...
@@ -52,20 +52,19 @@ void ModuleOp::Destroy() {
}
}
void
ModuleOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
ModuleOp
::
Verify
()
{
VLOG
(
4
)
<<
"Verifying inputs, outputs and attributes for: ModuleOp."
;
// Verify inputs
type
:
IR_ENFORCE
(
inputs
.
size
()
==
0
,
"The size of inputs must be equal to 0."
);
// Verify inputs:
IR_ENFORCE
(
num_operands
()
==
0u
,
"The size of inputs must be equal to 0."
);
// Verify if attributes contain attribute name in attributes_name:
// Verify attributes:
auto
&
attributes
=
this
->
attributes
();
auto
iter
=
attributes
.
find
(
"program"
);
IR_ENFORCE
(
iter
!=
attributes
.
end
()
&&
iter
->
second
.
isa
<
PointerAttribute
>
(),
"Type of attribute: program is not right."
);
// Verify outputs
type
:
IR_ENFORCE
(
outputs
.
size
()
==
0
,
"The size of out
puts must be equal to 0."
);
// Verify outputs:
IR_ENFORCE
(
num_results
()
==
0u
,
"The size of in
puts must be equal to 0."
);
}
const
char
*
GetParameterOp
::
attributes_name
[
attributes_num
]
=
{
...
...
@@ -80,20 +79,19 @@ void GetParameterOp::Build(Builder &builder,
argument
.
output_types
.
emplace_back
(
type
);
}
void
GetParameterOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
GetParameterOp
::
Verify
()
{
VLOG
(
4
)
<<
"Verifying inputs, outputs and attributes for: GetParameterOp."
;
// Verify inputs
type
:
IR_ENFORCE
(
inputs
.
size
()
==
0
,
"The size of inputs must be equal to 0."
);
// Verify inputs:
IR_ENFORCE
(
num_operands
()
==
0u
,
"The size of inputs must be equal to 0."
);
// Verify if attributes contain attribute name in attributes_name:
auto
&
attributes
=
this
->
attributes
();
auto
iter
=
attributes
.
find
(
"parameter_name"
);
IR_ENFORCE
(
iter
!=
attributes
.
end
()
&&
iter
->
second
.
isa
<
StrAttribute
>
(),
"Type of attribute: parameter_name is not right."
);
// Verify outputs type:
IR_ENFORCE
(
outputs
.
size
()
==
1
,
"The size of outputs must be equal to 1."
);
IR_ENFORCE
(
num_results
()
==
1u
,
"The size of outputs must be equal to 1."
);
}
const
char
*
SetParameterOp
::
attributes_name
[
attributes_num
]
=
{
...
...
@@ -107,20 +105,19 @@ void SetParameterOp::Build(Builder &builder, // NOLINT
argument
.
AddAttribute
(
attributes_name
[
0
],
ir
::
StrAttribute
::
get
(
builder
.
ir_context
(),
name
));
}
void
SetParameterOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
SetParameterOp
::
Verify
()
{
VLOG
(
4
)
<<
"Verifying inputs, outputs and attributes for: SetParameterOp."
;
// Verify inputs
type
:
IR_ENFORCE
(
inputs
.
size
()
==
1
,
"The size of outputs must be equal to 1."
);
// Verify inputs:
IR_ENFORCE
(
num_operands
()
==
1
,
"The size of outputs must be equal to 1."
);
// Verify if attributes contain attribute name in attributes_name:
// Verify attributes:
auto
&
attributes
=
this
->
attributes
();
auto
iter
=
attributes
.
find
(
"parameter_name"
);
IR_ENFORCE
(
iter
!=
attributes
.
end
()
&&
iter
->
second
.
isa
<
StrAttribute
>
(),
"Type of attribute: parameter_name is not right."
);
// Verify outputs
type
:
IR_ENFORCE
(
outputs
.
size
()
==
0
,
"The size of outputs must be equal to 0."
);
// Verify outputs:
IR_ENFORCE
(
num_results
()
==
0u
,
"The size of outputs must be equal to 0."
);
}
void
CombineOp
::
Build
(
Builder
&
builder
,
...
...
@@ -135,58 +132,56 @@ void CombineOp::Build(Builder &builder,
ir
::
VectorType
::
get
(
builder
.
ir_context
(),
inputs_type
));
}
void
CombineOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
CombineOp
::
Verify
()
{
// outputs.size() == 1
IR_ENFORCE
(
outputs
.
size
()
==
1
,
"The size %d of outputs must be equal to 1."
,
outputs
.
size
());
IR_ENFORCE
(
num_results
()
==
1u
,
"The size of outputs must be equal to 1."
);
// output_type == Vector<Type>
auto
output_type
=
(
*
this
)
->
result
(
0
).
type
().
dyn_cast
<
VectorType
>
();
IR_ENFORCE
(
output_type
,
"The type of outputs[0] must be equal to VectorType."
);
// outputs[0].type == Vector<Type>
IR_ENFORCE
(
outputs
[
0
].
isa
<
ir
::
VectorType
>
(),
"The type %s of outputs[0] must be equal to VectorType."
,
outputs
[
0
]);
ir
::
VectorType
output_type
=
outputs
[
0
].
dyn_cast
<
ir
::
VectorType
>
();
// inputs.size() == outputs[0].size()
IR_ENFORCE
(
output_type
.
size
()
==
inputs
.
size
(),
"The size %d of outputs[0] must be equal to size %d of inputs."
,
auto
input_num
=
num_operands
();
IR_ENFORCE
(
output_type
.
size
()
==
input_num
,
"The size %d of output must be equal to size %d of inputs."
,
output_type
.
size
(),
input
s
.
size
()
);
input
_num
);
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
i
++
)
{
IR_ENFORCE
(
output_type
[
i
]
==
inputs
[
i
].
type
(),
for
(
size_t
i
=
0
;
i
<
input_num
;
++
i
)
{
auto
type
=
(
*
this
)
->
operand
(
i
).
type
();
IR_ENFORCE
(
output_type
[
i
]
==
type
,
"The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d]."
,
output_type
[
i
],
i
,
inputs
[
i
].
type
()
,
type
,
i
);
}
}
const
char
*
SliceOp
::
attributes_name
[
attributes_num
]
=
{
"index"
};
void
SliceOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
SliceOp
::
Verify
()
{
// inputs.size() == 1
IR_ENFORCE
(
inputs
.
size
()
==
1
,
"The size %d of inputs must be equal to 1."
,
inputs
.
size
()
);
auto
input_size
=
num_operands
();
IR_ENFORCE
(
input_size
==
1
,
"The size %d of inputs must be equal to 1."
,
input_size
);
// inputs[0].type == Vector<Type>
IR_ENFORCE
(
inputs
[
0
].
type
().
isa
<
ir
::
VectorType
>
(),
auto
input_type
=
(
*
this
)
->
operand
(
0
).
type
().
dyn_cast
<
ir
::
VectorType
>
();
IR_ENFORCE
(
input_type
,
"The type %s of inputs[0] must be equal to VectorType."
,
inputs
[
0
].
type
());
ir
::
VectorType
input_type
=
inputs
[
0
].
type
().
dyn_cast
<
ir
::
VectorType
>
();
input_type
);
auto
output_size
=
num_results
();
// outputs.size() == 1
IR_ENFORCE
(
output
s
.
size
()
==
1
,
IR_ENFORCE
(
output
_size
==
1
,
"The size %d of outputs must be equal to 1."
,
output
s
.
size
()
);
output
_size
);
// attributes contains index: Int32
auto
&
attributes
=
this
->
attributes
();
IR_ENFORCE
(
attributes
.
count
(
"index"
)
!=
0
,
"The attributes must contains index."
);
const
ir
::
Attribute
&
attr
=
attributes
.
at
(
"index"
);
...
...
@@ -203,12 +198,13 @@ void SliceOp::Verify(const std::vector<ir::OpResult> &inputs,
input_type
.
size
());
// inputs[index].type == outputs[0].type
auto
output_type
=
(
*
this
)
->
result
(
0
).
type
();
IR_ENFORCE
(
input_type
[
index
]
==
output
s
[
0
]
,
input_type
[
index
]
==
output
_type
,
"The type %s of inputs[%d] must be equal to type %s of outputs[0]."
,
input_type
[
index
],
index
,
output
s
[
0
]
);
output
_type
);
}
const
char
*
ConstantOp
::
attributes_name
[
attributes_num
]
=
{
"value"
};
...
...
@@ -221,16 +217,13 @@ void ConstantOp::Build(Builder &builder,
argument
.
output_types
.
push_back
(
output_type
);
}
void
ConstantOp
::
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
IR_ENFORCE
(
inputs
.
size
()
==
0
,
"The size of inputs must be equal to 0."
);
IR_ENFORCE
(
outputs
.
size
()
==
1
,
"The size of outputs must be equal to 1."
);
IR_ENFORCE
(
attributes
.
count
(
"value"
)
>
0
,
"Type of attribute: value is not right."
);
void
ConstantOp
::
Verify
()
{
IR_ENFORCE
(
num_operands
()
==
0
,
"The size of inputs must be equal to 0."
);
IR_ENFORCE
(
num_results
()
==
1
,
"The size of outputs must be equal to 1."
);
IR_ENFORCE
(
attributes
().
count
(
"value"
)
>
0
,
"must has value attribute"
);
}
Attribute
ConstantOp
::
value
()
{
return
operation
()
->
attributes
().
at
(
"value"
);
}
Attribute
ConstantOp
::
value
()
{
return
attributes
().
at
(
"value"
);
}
}
// namespace ir
...
...
paddle/ir/core/builtin_op.h
浏览文件 @
96652265
...
...
@@ -30,10 +30,7 @@ class IR_API ModuleOp : public ir::Op<ModuleOp> {
static
const
char
*
name
()
{
return
"builtin.module"
;
}
static
constexpr
uint32_t
attributes_num
=
1
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
void
Verify
();
Program
*
program
();
Block
*
block
();
...
...
@@ -58,9 +55,7 @@ class IR_API GetParameterOp : public ir::Op<GetParameterOp> {
OperationArgument
&
argument
,
// NOLINT
const
std
::
string
&
name
,
Type
type
);
static
void
Verify
(
const
std
::
vector
<
OpResult
>
&
inputs
,
const
std
::
vector
<
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
void
Verify
();
};
///
...
...
@@ -77,9 +72,7 @@ class IR_API SetParameterOp : public ir::Op<SetParameterOp> {
OperationArgument
&
argument
,
// NOLINT
OpResult
parameter
,
const
std
::
string
&
name
);
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
void
Verify
();
};
///
...
...
@@ -99,9 +92,7 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
OperationArgument
&
argument
,
// NOLINT
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
);
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
void
Verify
();
};
///
...
...
@@ -116,9 +107,7 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
static
constexpr
uint32_t
attributes_num
=
1
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
);
void
Verify
();
};
class
IR_API
ConstantLikeTrait
:
public
OpTraitBase
<
ConstantLikeTrait
>
{
...
...
@@ -143,9 +132,7 @@ class IR_API ConstantOp : public Op<ConstantOp, ConstantLikeTrait> {
Attribute
value
,
Type
output_type
);
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
AttributeMap
&
attributes
);
void
Verify
();
Attribute
value
();
};
...
...
paddle/ir/core/dialect.h
浏览文件 @
96652265
...
...
@@ -100,7 +100,7 @@ class IR_API Dialect {
ConcreteOp
::
GetTraitSet
(),
ConcreteOp
::
attributes_num
,
ConcreteOp
::
attributes_name
,
ConcreteOp
::
Verify
);
ConcreteOp
::
Verify
Invariants
);
}
void
RegisterOp
(
const
std
::
string
&
name
,
OpInfoImpl
*
op_info
);
...
...
paddle/ir/core/ir_context.h
浏览文件 @
96652265
...
...
@@ -32,6 +32,7 @@ class InterfaceValue;
class
Type
;
class
OpResult
;
class
Attribute
;
class
Operation
;
using
OpInfoMap
=
std
::
unordered_map
<
std
::
string
,
OpInfo
>
;
...
...
@@ -102,18 +103,14 @@ class IR_API IrContext {
///
/// \brief Register an op infomation to IrContext
///
void
RegisterOpInfo
(
Dialect
*
dialect
,
void
RegisterOpInfo
(
Dialect
*
dialect
,
TypeId
op_id
,
const
char
*
name
,
std
::
vector
<
InterfaceValue
>
&&
interface_map
,
const
std
::
vector
<
TypeId
>
&
trait_set
,
size_t
attributes_num
,
const
char
**
attributes_name
,
void
(
*
verify
)(
const
std
::
vector
<
OpResult
>
&
inputs
,
const
std
::
vector
<
Type
>
&
outputs
,
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
attributes
));
void
(
*
verify
)(
Operation
*
));
///
/// \brief Get registered operaiton infomation.
...
...
paddle/ir/core/op_base.h
浏览文件 @
96652265
...
...
@@ -78,6 +78,12 @@ class IR_API OpBase {
IrContext
*
ir_context
()
const
{
return
operation_
->
ir_context
();
}
uint32_t
num_results
()
const
{
return
operation_
->
num_results
();
}
uint32_t
num_operands
()
const
{
return
operation_
->
num_operands
();
}
const
AttributeMap
&
attributes
()
const
{
return
operation_
->
attributes
();
}
private:
Operation
*
operation_
;
// Not owned
};
...
...
@@ -205,6 +211,16 @@ class Op : public OpBase {
ConstructInterfacesOrTraits
<
ConcreteOp
,
TraitList
>::
trait
(
p_first_trait
);
return
trait_set
;
}
static
constexpr
bool
HasNoDataMembers
()
{
class
EmptyOp
:
public
Op
<
EmptyOp
,
TraitOrInterface
...
>
{};
return
sizeof
(
ConcreteOp
)
==
sizeof
(
EmptyOp
);
}
static
void
VerifyInvariants
(
Operation
*
op
)
{
static_assert
(
HasNoDataMembers
(),
"Op class shouldn't define new data members"
);
op
->
dyn_cast
<
ConcreteOp
>
().
Verify
();
}
};
}
// namespace ir
paddle/ir/core/op_info.cc
浏览文件 @
96652265
...
...
@@ -35,11 +35,7 @@ const char *OpInfo::name() const { return impl_ ? impl_->name() : nullptr; }
TypeId
OpInfo
::
id
()
const
{
return
impl_
?
impl_
->
id
()
:
TypeId
();
}
void
OpInfo
::
Verify
(
const
std
::
vector
<
OpResult
>
&
inputs
,
const
std
::
vector
<
Type
>
&
outputs
,
const
AttributeMap
&
attributes
)
{
impl_
->
verify
()(
inputs
,
outputs
,
attributes
);
}
void
OpInfo
::
Verify
(
Operation
*
operation
)
const
{
impl_
->
verify
()(
operation
);
}
void
*
OpInfo
::
GetInterfaceImpl
(
TypeId
interface_id
)
const
{
return
impl_
?
impl_
->
GetInterfaceImpl
(
interface_id
)
:
nullptr
;
...
...
paddle/ir/core/op_info.h
浏览文件 @
96652265
...
...
@@ -25,6 +25,9 @@ class OpResult;
class
Type
;
class
Attribute
;
class
Dialect
;
class
Operation
;
typedef
void
(
*
VerifyPtr
)(
Operation
*
op
);
class
IR_API
OpInfo
{
public:
...
...
@@ -49,9 +52,7 @@ class IR_API OpInfo {
TypeId
id
()
const
;
void
Verify
(
const
std
::
vector
<
OpResult
>
&
inputs
,
const
std
::
vector
<
Type
>
&
outputs
,
const
std
::
unordered_map
<
std
::
string
,
Attribute
>
&
attributes
);
void
Verify
(
Operation
*
)
const
;
template
<
typename
Trait
>
bool
HasTrait
()
const
{
...
...
paddle/ir/core/op_info_impl.h
浏览文件 @
96652265
...
...
@@ -25,9 +25,6 @@
namespace
ir
{
class
Dialect
;
typedef
void
(
*
VerifyPtr
)(
const
std
::
vector
<
OpResult
>
&
inputs
,
const
std
::
vector
<
Type
>
&
outputs
,
const
AttributeMap
&
attributes
);
///
/// \brief OpInfoImpl class.
...
...
paddle/ir/core/operation.cc
浏览文件 @
96652265
...
...
@@ -46,10 +46,6 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const
std
::
vector
<
ir
::
Type
>
&
output_types
,
ir
::
OpInfo
op_info
,
size_t
num_regions
)
{
// 0. Verify
if
(
op_info
)
{
op_info
.
Verify
(
inputs
,
output_types
,
attributes
);
}
// 1. Calculate the required memory size for OpResults + Operation +
// OpOperands.
uint32_t
num_results
=
output_types
.
size
();
...
...
@@ -100,6 +96,11 @@ Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
base_ptr
+=
sizeof
(
Region
);
}
}
// 0. Verify
if
(
op_info
)
{
op_info
.
Verify
(
op
);
}
return
op
;
}
...
...
test/cpp/ir/core/ir_infershape_test.cc
浏览文件 @
96652265
...
...
@@ -45,9 +45,7 @@ class OperationTest
static
const
char
*
name
()
{
return
"test.operation2"
;
}
static
constexpr
uint32_t
attributes_num
=
2
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{}
static
void
Verify
()
{}
static
void
InferShape
(
phi
::
InferMetaContext
*
infer_meta
)
{
auto
fn
=
PD_INFER_META
(
phi
::
CreateInferMeta
);
fn
(
infer_meta
);
...
...
test/cpp/ir/core/ir_op_test.cc
浏览文件 @
96652265
...
...
@@ -90,9 +90,8 @@ class Operation1 : public ir::Op<Operation1> {
static
const
char
*
name
()
{
return
"test.operation1"
;
}
static
constexpr
uint32_t
attributes_num
=
2
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
Verify
()
{
auto
&
attributes
=
this
->
attributes
();
if
(
attributes
.
count
(
"op1_attr1"
)
==
0
||
!
attributes
.
at
(
"op1_attr1"
).
isa
<
ir
::
StrAttribute
>
())
{
throw
(
"Type of attribute: parameter_name is not right."
);
...
...
@@ -133,9 +132,8 @@ class Operation2
static
const
char
*
name
()
{
return
"test.operation2"
;
}
static
constexpr
uint32_t
attributes_num
=
2
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
Verify
()
{
auto
&
attributes
=
this
->
attributes
();
if
(
attributes
.
count
(
"op2_attr1"
)
==
0
||
(
!
attributes
.
at
(
"op2_attr1"
).
isa
<
ir
::
StrAttribute
>
()))
{
throw
(
"Type of attribute: parameter_name is not right."
);
...
...
test/cpp/ir/core/ir_program_test.cc
浏览文件 @
96652265
...
...
@@ -38,22 +38,21 @@ class AddOp : public ir::Op<AddOp> {
static
const
char
*
name
()
{
return
"test.add"
;
}
static
constexpr
const
char
**
attributes_name
=
nullptr
;
static
constexpr
uint32_t
attributes_num
=
0
;
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
if
(
inputs
.
size
()
!=
2
)
{
throw
(
"The size of inputs must be equal to 2."
);
}
if
(
outputs
.
size
()
!=
1
)
{
throw
(
"The size of outputs must be equal to 1."
);
}
}
void
Verify
();
static
void
Build
(
ir
::
Builder
&
builder
,
// NOLINT
ir
::
OperationArgument
&
argument
,
// NOLINT
ir
::
OpResult
l_operand
,
ir
::
OpResult
r_operand
,
ir
::
Type
sum_type
);
};
void
AddOp
::
Verify
()
{
if
(
num_operands
()
!=
2
)
{
throw
(
"The size of inputs must be equal to 2."
);
}
if
(
num_results
()
!=
1
)
{
throw
(
"The size of outputs must be equal to 1."
);
}
}
void
AddOp
::
Build
(
ir
::
Builder
&
,
ir
::
OperationArgument
&
argument
,
ir
::
OpResult
l_operand
,
...
...
@@ -262,9 +261,9 @@ TEST(program_test, builder) {
ir
::
Type
full_op_output
=
full_op
->
result
(
0
).
type
();
EXPECT_EQ
(
program
.
block
()
->
size
(),
1u
);
EXPECT_EQ
(
program
.
block
()
->
back
(),
full_op
.
operation
());
EXPECT_EQ
(
full_op
->
num_operands
(),
0u
);
EXPECT_EQ
(
full_op
->
num_results
(),
1u
);
EXPECT_EQ
(
full_op
->
attributes
().
size
(),
4u
);
EXPECT_EQ
(
full_op
.
num_operands
(),
0u
);
EXPECT_EQ
(
full_op
.
num_results
(),
1u
);
EXPECT_EQ
(
full_op
.
attributes
().
size
(),
4u
);
EXPECT_EQ
(
full_op_output
.
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
().
offset
()
==
0
,
true
);
...
...
test/cpp/ir/pass/pass_manager_test.cc
浏览文件 @
96652265
...
...
@@ -65,22 +65,21 @@ class AddOp : public ir::Op<AddOp> {
static
const
char
*
name
()
{
return
"test.add"
;
}
static
constexpr
const
char
**
attributes_name
=
nullptr
;
static
constexpr
uint32_t
attributes_num
=
0
;
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
if
(
inputs
.
size
()
!=
2
)
{
throw
(
"The size of inputs must be equal to 2."
);
}
if
(
outputs
.
size
()
!=
1
)
{
throw
(
"The size of outputs must be equal to 1."
);
}
}
void
Verify
();
static
void
Build
(
ir
::
Builder
&
builder
,
// NOLINT
ir
::
OperationArgument
&
argument
,
// NOLINT
ir
::
OpResult
l_operand
,
ir
::
OpResult
r_operand
,
ir
::
Type
sum_type
);
};
void
AddOp
::
Verify
()
{
if
(
num_operands
()
!=
2
)
{
throw
(
"The size of inputs must be equal to 2."
);
}
if
(
num_results
()
!=
1
)
{
throw
(
"The size of outputs must be equal to 1."
);
}
}
void
AddOp
::
Build
(
ir
::
Builder
&
,
ir
::
OperationArgument
&
argument
,
ir
::
OpResult
l_operand
,
...
...
test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc
浏览文件 @
96652265
...
...
@@ -48,9 +48,11 @@ class Operation1 : public ir::Op<Operation1> {
static
const
char
*
name
()
{
return
"test.Operation1"
;
}
static
constexpr
uint32_t
attributes_num
=
2
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Verify
(
const
std
::
vector
<
ir
::
OpResult
>
&
inputs
,
const
std
::
vector
<
ir
::
Type
>
&
outputs
,
const
ir
::
AttributeMap
&
attributes
)
{
void
Verify
();
static
void
InferShape
()
{
VLOG
(
2
)
<<
"This is op2's InferShape interface."
;
}
};
void
Operation1
::
Verify
()
{
auto
&
attributes
=
this
->
attributes
();
if
(
attributes
.
count
(
"op2_attr1"
)
==
0
||
(
!
attributes
.
at
(
"op2_attr1"
).
isa
<
ir
::
StrAttribute
>
()))
{
throw
(
"Type of attribute: parameter_name is not right."
);
...
...
@@ -59,9 +61,7 @@ class Operation1 : public ir::Op<Operation1> {
(
!
attributes
.
at
(
"op2_attr2"
).
isa
<
ir
::
StrAttribute
>
()))
{
throw
(
"Type of attribute: parameter_name is not right."
);
}
}
static
void
InferShape
()
{
VLOG
(
2
)
<<
"This is op2's InferShape interface."
;
}
};
}
const
char
*
Operation1
::
attributes_name
[
attributes_num
]
=
{
"op2_attr1"
,
"op2_attr2"
};
IR_DECLARE_EXPLICIT_TYPE_ID
(
Operation1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录