diff --git a/paddle/fluid/ir/dialect/CMakeLists.txt b/paddle/fluid/ir/dialect/CMakeLists.txt index 9cb024be507e10503c269c5691ffdb2a6a9ee57e..a54938ed6fa896df745e6329a39ec0c7bf21c9f7 100644 --- a/paddle/fluid/ir/dialect/CMakeLists.txt +++ b/paddle/fluid/ir/dialect/CMakeLists.txt @@ -2,7 +2,8 @@ set(PD_DIALECT_SOURCE_DIR "${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect") set(PD_DIALECT_BINARY_DIR "${PADDLE_BINARY_DIR}/paddle/fluid/ir/dialect") # Generate pd_dialect files defining op using op_gen_file -set(op_gen_file ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_gen.py) +set(op_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/op_gen.py) set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) set(op_forward_yaml_file1 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml @@ -17,10 +18,9 @@ set(op_backward_yaml_file2 ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml ) set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/pd_op.yaml) -set(op_yaml_file4 - ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/pd_legacy_op.yaml) + set(op_yaml_files - ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3},${op_yaml_file4} + ${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3} ) set(op_namespace paddle,dialect) set(dialect_name pd) diff --git a/paddle/fluid/ir/dialect/kernel_attribute_storage.h b/paddle/fluid/ir/dialect/kernel_attribute_storage.h index 634e2f2dfff17767065979b506994e06b6e9f03c..67896e094eec5bca9e85a6af79b4481670497962 100644 --- a/paddle/fluid/ir/dialect/kernel_attribute_storage.h +++ b/paddle/fluid/ir/dialect/kernel_attribute_storage.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/utils.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/kernel_factory.h" diff --git a/paddle/fluid/ir/dialect/kernel_op.cc b/paddle/fluid/ir/dialect/kernel_op.cc index 34bce0f176dd6ff5796f6233b49ee4f27c65f8ed..e46a874045d73cdd1034fd550936007e22678ccd 100644 --- a/paddle/fluid/ir/dialect/kernel_op.cc +++ b/paddle/fluid/ir/dialect/kernel_op.cc @@ -44,26 +44,14 @@ void PhiKernelOp::Verify() { "Type of attribute: kernel_key is not right.")); } -const std::string PhiKernelOp::op_name() { - return operation() - ->attributes() - .at("op_name") - .dyn_cast() - .data(); +std::string PhiKernelOp::op_name() { + return attributes().at("op_name").dyn_cast().data(); } -const std::string PhiKernelOp::kernel_name() { - return operation() - ->attributes() - .at("kernel_name") - .dyn_cast() - .data(); +std::string PhiKernelOp::kernel_name() { + return attributes().at("kernel_name").dyn_cast().data(); } phi::KernelKey PhiKernelOp::kernel_key() { - return operation() - ->attributes() - .at("kernel_key") - .dyn_cast() - .data(); + return attributes().at("kernel_key").dyn_cast().data(); } } // namespace dialect diff --git a/paddle/fluid/ir/dialect/kernel_op.h b/paddle/fluid/ir/dialect/kernel_op.h index c3a15e3be056d397e9dc06908a4f10f68b1b3f5d..5cbd4b0b434b38acd83e02e5bf90bcb27cfb6bba 100644 --- a/paddle/fluid/ir/dialect/kernel_op.h +++ b/paddle/fluid/ir/dialect/kernel_op.h @@ -27,8 +27,8 @@ class PhiKernelOp : public ir::Op { static const char *name() { return "phi.kernel"; } static constexpr uint32_t attributes_num = 3; static const char *attributes_name[attributes_num]; - const std::string op_name(); - const std::string kernel_name(); + std::string op_name(); + std::string kernel_name(); phi::KernelKey kernel_key(); void Verify(); }; diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_generator/op_gen.py similarity index 97% rename from paddle/fluid/ir/dialect/op_gen.py rename to paddle/fluid/ir/dialect/op_generator/op_gen.py index 8d1c446e686c4eb25ec4e9c58e58a8c0702b5639..29a96b9c386b0def8e17a536e08e51345f3bed1f 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_gen.py @@ -16,6 +16,8 @@ import argparse import os import yaml +from op_interface_gen import gen_exclusive_interface_str, gen_op_infer_meta_str +from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str # ===================================== @@ -29,7 +31,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST #undef GET_OP_LIST {op_declare} #else -// This file is generated by "paddle/fluid/ir/dialect/op_gen.py" +// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" #include @@ -78,17 +80,12 @@ op_n_attribute_declare_str = ( "static const char *attributes_name[{attribute_num}];" ) -OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }} -""" -OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }} -""" - # ===================================== # String Template for cc file code gen # ===================================== -CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_gen.py" +CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/ir/dialect/op_generator/op_gen.py" -#include "{h_file}" +#include "paddle/fluid/ir/dialect/pd_op.h" #include "paddle/fluid/ir/dialect/pd_type.h" #include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/ir/core/builtin_attribute.h" @@ -142,12 +139,6 @@ void {op_name}::Build({build_args}) {{ {build_outputs} }} """ -OP_INFER_SHAPE_TEMPLATE = """ -void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ - auto fn = PD_INFER_META(phi::{infer_meta_func}); - fn(infer_meta); -}} -""" DEFINE_OP_TYPE_ID = """ IR_DEFINE_EXPLICIT_TYPE_ID({op_name}) @@ -1217,12 +1208,10 @@ def OpGenerator( op_interfaces = ["OpYamlInfoInterface"] op_traits = [] - exclusive_interface_str = "" if op_info.infer_meta_func: op_interfaces += ["InferMetaInterface"] - exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );" - ) + + exclusive_interface_str = gen_exclusive_interface_str(op_info) # If op has inplace info, we will generate inplace op and non-inplace op. for op_name in op_info.op_phi_name: @@ -1242,22 +1231,11 @@ def OpGenerator( # =================================== # # gen get input/output methods str # # =================================== # - op_get_inputs_outputs_str = "" - for idx in range(len(op_input_name_list)): - op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name=op_input_name_list[idx], - input_index=idx, - ) - for idx in range(len(op_mutable_attribute_name_list)): - op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( - input_name=op_mutable_attribute_name_list[idx], - input_index=idx + len(op_input_name_list), - ) - for idx in range(len(op_output_name_list)): - op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( - output_name=op_output_name_list[idx], - output_index=idx, - ) + op_get_inputs_outputs_str = gen_op_get_inputs_outputs_str( + op_input_name_list, + op_mutable_attribute_name_list, + op_output_name_list, + ) # =================================== # # gen Build methods str # @@ -1472,12 +1450,7 @@ def OpGenerator( op_output_optional_list, ) - op_infer_meta_str = "" - if op_info.infer_meta_func: - op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( - op_name=op_class_name, - infer_meta_func=op_info.infer_meta_func, - ) + op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name) ops_name_list.append(op_class_name) ops_declare_list.append(op_declare_str) diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..448253f2af6bfbf2631937f2c98ed73329431198 --- /dev/null +++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py @@ -0,0 +1,41 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# generator interfaces + +OP_INFER_SHAPE_TEMPLATE = """ +void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ + auto fn = PD_INFER_META(phi::{infer_meta_func}); + fn(infer_meta); +}} +""" + + +def gen_op_infer_meta_str(op_info, op_class_name): + op_infer_meta_str = "" + if op_info.infer_meta_func: + op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( + op_name=op_class_name, + infer_meta_func=op_info.infer_meta_func, + ) + return op_infer_meta_str + + +def gen_exclusive_interface_str(op_info): + exclusive_interface_str = "" + if op_info.infer_meta_func: + exclusive_interface_str += ( + " static void InferMeta( phi::InferMetaContext *infer_meta );" + ) + return exclusive_interface_str diff --git a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py b/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..2438a0d22490e354ca5b468fe47ce9d824cab9c6 --- /dev/null +++ b/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# generator op member function + +OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }} +""" +OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }} +""" +OP_GET_ATTRIBUTE_TEMPLATE = """ ir::Attribute attribute(const std::string &name) {{ + PADDLE_ENFORCE(attributes().count(name) > 0, + phi::errors::PreconditionNotMet("Attribute is not exist.")); + return attributes().at(name); + }} + template + T attribute(const std::string &name) {{ + PADDLE_ENFORCE(attributes().count(name) > 0 && attributes().at(name).isa(), + phi::errors::PreconditionNotMet("Attribute is not right.")); + return attributes().at(name).dyn_cast(); + }} +""" + + +def gen_op_get_inputs_outputs_str( + op_input_name_list, op_mutable_attribute_name_list, op_output_name_list +): + op_get_inputs_outputs_str = "" + for idx in range(len(op_input_name_list)): + op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( + input_name=op_input_name_list[idx], + input_index=idx, + ) + for idx in range(len(op_mutable_attribute_name_list)): + op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( + input_name=op_mutable_attribute_name_list[idx], + input_index=idx + len(op_input_name_list), + ) + for idx in range(len(op_output_name_list)): + op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format( + output_name=op_output_name_list[idx], + output_index=idx, + ) + op_get_inputs_outputs_str += OP_GET_ATTRIBUTE_TEMPLATE + return op_get_inputs_outputs_str diff --git a/paddle/fluid/ir/dialect/op_verify_gen.py b/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py similarity index 99% rename from paddle/fluid/ir/dialect/op_verify_gen.py rename to paddle/fluid/ir/dialect/op_generator/op_verify_gen.py index 7b65e8dce9181e46f1050e28eb4e96423f32f453..1e7441476fe3e8629b611870ec187a7ddd7dcd14 100644 --- a/paddle/fluid/ir/dialect/op_verify_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_verify_gen.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/paddle/fluid/ir/dialect/pd_attribute_storage.h b/paddle/fluid/ir/dialect/pd_attribute_storage.h index 6270791f725d4f0b5b7311f44a5de060be6fe850..5b8d9bf121b86cb59afa5f15a97e4508768cca09 100644 --- a/paddle/fluid/ir/dialect/pd_attribute_storage.h +++ b/paddle/fluid/ir/dialect/pd_attribute_storage.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/utils.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" diff --git a/paddle/fluid/ir/dialect/pd_legacy_op.yaml b/paddle/fluid/ir/dialect/pd_legacy_op.yaml deleted file mode 100644 index 9aa96732c87ebb0a996d5ff8968f8e1d27ededa1..0000000000000000000000000000000000000000 --- a/paddle/fluid/ir/dialect/pd_legacy_op.yaml +++ /dev/null @@ -1,32 +0,0 @@ -- name: elementwise_add - inputs: - - typename: Tensor - name: x - optional: false - no_need_buffer: false - data_transform: {} - - typename: Tensor - name: y - optional: false - no_need_buffer: false - data_transform: {} - attrs: - - {typename: int, name: axis} - outputs: - - {typename: Tensor, name: out, optional: false, intermediate: false} - no_need_buffer: null - data_transform: null - infer_meta: - func: ElementwiseInferMeta - param: [x, y] - kernel: - func: [add_raw] - param: [x, y] - backend: null - layout: null - data_type: null - dispatch: {add: null} - force_backend: null - inplace: {out: x} - view: null - backward: add_grad diff --git a/paddle/fluid/ir/dialect/pd_type_storage.h b/paddle/fluid/ir/dialect/pd_type_storage.h index dbdb3b374e4d223b89280e67bf27bf858bad2f81..cbba73f95a60f80d78129a9ba76c12ec51452f4f 100644 --- a/paddle/fluid/ir/dialect/pd_type_storage.h +++ b/paddle/fluid/ir/dialect/pd_type_storage.h @@ -17,6 +17,7 @@ #include #include "paddle/ir/core/type.h" +#include "paddle/ir/core/type_base.h" #include "paddle/ir/core/utils.h" #include "paddle/phi/core/tensor_meta.h" diff --git a/paddle/ir/core/attribute.cc b/paddle/ir/core/attribute.cc index 77e768720e36b0b556a91f692d5531b3973a3db1..0eff9964292dfd68841f624b331462a84f9ca80f 100644 --- a/paddle/ir/core/attribute.cc +++ b/paddle/ir/core/attribute.cc @@ -13,8 +13,20 @@ // limitations under the License. #include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/dialect.h" namespace ir { IrContext *Attribute::ir_context() const { return dialect().ir_context(); } + +TypeId Attribute::type_id() { return storage_->abstract_attribute().type_id(); } + +const AbstractAttribute &Attribute::abstract_attribute() { + return storage_->abstract_attribute(); +} + +const Dialect &Attribute::dialect() const { + return storage_->abstract_attribute().dialect(); +} + } // namespace ir diff --git a/paddle/ir/core/attribute.h b/paddle/ir/core/attribute.h index 0c0070a969814071ce8ec9bc41976cd926ba6757..d675fe7cd2fd56bad2f966356eb568f048bfc1a5 100644 --- a/paddle/ir/core/attribute.h +++ b/paddle/ir/core/attribute.h @@ -14,10 +14,15 @@ #pragma once -#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/cast_utils.h" +#include "paddle/ir/core/type_id.h" namespace ir { +class AttributeStorage; +class AbstractAttribute; +class IrContext; +class Dialect; + /// /// \brief Unified interface of the Attribute class. Derivation of all Attribute /// classes only derives interfaces, not members. @@ -46,17 +51,13 @@ class IR_API Attribute { /// /// \brief Some Attribute attribute acquisition interfaces. /// - TypeId type_id() { return storage_->abstract_attribute().type_id(); } + TypeId type_id(); - const AbstractAttribute &abstract_attribute() { - return storage_->abstract_attribute(); - } + const AbstractAttribute &abstract_attribute(); const Storage *storage() const { return storage_; } - const Dialect &dialect() const { - return storage_->abstract_attribute().dialect(); - } + const Dialect &dialect() const; IrContext *ir_context() const; diff --git a/paddle/ir/core/builder.h b/paddle/ir/core/builder.h index a290a2fb7df8bffd967f95e015ad93b7a9b6cc58..74856cdaf7c0ca4c5a43a9a3cd983c4843ec355d 100644 --- a/paddle/ir/core/builder.h +++ b/paddle/ir/core/builder.h @@ -17,6 +17,7 @@ #include #include "paddle/ir/core/block.h" +#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/operation.h" namespace ir { diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/ir/core/builtin_attribute_storage.h index bf68dc387cb31777a8cecf1d97fc638d3657b0a9..50a051c9cdf7708a1e6cf40852ad4c8a2d9e0a3e 100644 --- a/paddle/ir/core/builtin_attribute_storage.h +++ b/paddle/ir/core/builtin_attribute_storage.h @@ -19,6 +19,7 @@ #include #include "paddle/ir/core/attribute.h" +#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/utils.h" namespace ir { diff --git a/paddle/ir/core/builtin_type_storage.h b/paddle/ir/core/builtin_type_storage.h index 5c6f255461b1e7f57eab6398144b89704e4d7d58..7e6382b005fa6ddffc970a840b96dc263937ba1d 100644 --- a/paddle/ir/core/builtin_type_storage.h +++ b/paddle/ir/core/builtin_type_storage.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/ir/core/type.h" +#include "paddle/ir/core/type_base.h" #include "paddle/ir/core/utils.h" namespace ir { diff --git a/paddle/ir/core/operation_utils.cc b/paddle/ir/core/operation_utils.cc index d68c037000a7f78ec48cbbf53b955614ea5c4434..f975de0c828077fdd99f4c3382371dd3a4a2dbfd 100644 --- a/paddle/ir/core/operation_utils.cc +++ b/paddle/ir/core/operation_utils.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/ir/core/operation_utils.h" +#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/region.h" namespace ir { diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index 3e4610b0f1dd2d7ab9ccff72eef0a26dd1d3f154..3ab421f945daed1c6fd2d1cd35085c2e7fba02b0 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -14,6 +14,7 @@ #pragma once +#include #include "paddle/ir/core/attribute.h" #include "paddle/ir/core/op_info.h" #include "paddle/ir/core/region.h" diff --git a/paddle/ir/core/type.cc b/paddle/ir/core/type.cc index 8b1451fa76fb72dca2182da98f5614e80cfe3916..16713290d393d72ec8699be11885012b050d290e 100644 --- a/paddle/ir/core/type.cc +++ b/paddle/ir/core/type.cc @@ -14,7 +14,14 @@ #include "paddle/ir/core/type.h" #include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/type_base.h" namespace ir { -IrContext* Type::ir_context() const { return dialect().ir_context(); } +IrContext *Type::ir_context() const { return dialect().ir_context(); } + +TypeId Type::type_id() { return storage_->abstract_type().type_id(); } + +const AbstractType &Type::abstract_type() { return storage_->abstract_type(); } + +Dialect &Type::dialect() const { return storage_->abstract_type().dialect(); } } // namespace ir diff --git a/paddle/ir/core/type.h b/paddle/ir/core/type.h index 7647bf3641a8db4b5b8ca179a97beabcd0694b64..62dcefdf3ba65182272edf2987aa1ef2ba55a257 100644 --- a/paddle/ir/core/type.h +++ b/paddle/ir/core/type.h @@ -17,9 +17,13 @@ #include #include "paddle/ir/core/cast_utils.h" -#include "paddle/ir/core/type_base.h" +#include "paddle/ir/core/type_id.h" namespace ir { +class TypeStorage; +class AbstractType; +class IrContext; +class Dialect; /// /// \brief Unified interface of the Type class. Derivation of all Type classes /// only derives interfaces, not members. For example, DenseTensorType, @@ -53,13 +57,13 @@ class IR_API Type { /// /// \brief Some type attribute acquisition interfaces. /// - TypeId type_id() { return storage_->abstract_type().type_id(); } + TypeId type_id(); - const AbstractType &abstract_type() { return storage_->abstract_type(); } + const AbstractType &abstract_type(); const Storage *storage() const { return storage_; } - Dialect &dialect() const { return storage_->abstract_type().dialect(); } + Dialect &dialect() const; IrContext *ir_context() const; diff --git a/test/cpp/ir/core/ir_exe_test.cc b/test/cpp/ir/core/ir_exe_test.cc index 3c49fa0595edae07ad20912ae0da18f4867be1b1..0c23e49e805ec1e7d39fa5ca59853ba38cdd083d 100644 --- a/test/cpp/ir/core/ir_exe_test.cc +++ b/test/cpp/ir/core/ir_exe_test.cc @@ -44,6 +44,7 @@ #include "paddle/fluid/ir/pass/pd_op_to_kernel_pass.h" #include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_adaptor.h" +#include "paddle/ir/core/attribute.h" #include "paddle/phi/core/kernel_registry.h" PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); @@ -74,6 +75,11 @@ TEST(program_test, program) { true); EXPECT_EQ(block->size(), 4u); + ir::Attribute seed_attr = uniform1.attribute("seed"); + ir::Int32Attribute seed_attr1 = + uniform1.attribute("seed"); + EXPECT_EQ(seed_attr.dyn_cast().data(), seed_attr1.data()); + // Def: B = paddle::dialect::UniformOp(...) paddle::dialect::UniformOp uniform2 = builder.Build(std::vector{2, 2}, diff --git a/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc b/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc index 31f79b70fc0158f3ec85ef0b072f14fce928208c..5deb7ae9b8ef85eb9ea1ec5261e4f3d20234d711 100644 --- a/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc +++ b/test/cpp/ir/kernel_dialect/ir_kernel_dialect_pass_test.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/ir/dialect/kernel_dialect.h" +#include "paddle/fluid/ir/dialect/kernel_op.h" #include "paddle/fluid/ir/dialect/pd_attribute.h" #include "paddle/fluid/ir/dialect/pd_dialect.h" #include "paddle/fluid/ir/dialect/pd_type.h" @@ -34,6 +35,7 @@ #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/program.h" #include "paddle/ir/core/utils.h" +#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" @@ -84,6 +86,24 @@ TEST(program_test, program) { EXPECT_EQ(res1, true); EXPECT_EQ(res2, true); EXPECT_EQ(res3, true); + + EXPECT_EQ(kernel_program->block()->size(), 3u); + EXPECT_EQ(kernel_program->block() + ->front() + ->dyn_cast() + .op_name(), + "pd.full"); + EXPECT_EQ(kernel_program->block() + ->front() + ->dyn_cast() + .kernel_name(), + "full"); + EXPECT_EQ(kernel_program->block() + ->front() + ->dyn_cast() + .kernel_key() + .dtype(), + phi::DataType::FLOAT32); } TEST(dialect_attr, attr) {