diff --git a/paddle/ir/CMakeLists.txt b/paddle/ir/CMakeLists.txt index 39e5ff3fda611a1159d04b3b0c0fb618494c5746..581bb3f8a7c584be90a37e305fa2a251b3c2ceb2 100644 --- a/paddle/ir/CMakeLists.txt +++ b/paddle/ir/CMakeLists.txt @@ -38,6 +38,7 @@ add_subdirectory(core) add_subdirectory(pass) add_subdirectory(pattern_rewrite) add_subdirectory(builtin_transforms) +add_subdirectory(dialect) if(WIN32) if(WITH_SHARED_IR) diff --git a/paddle/ir/dialect/CMakeLists.txt b/paddle/ir/dialect/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a87b0abfb2383d1ebf0f867d6f8868437e212665 --- /dev/null +++ b/paddle/ir/dialect/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(shape) diff --git a/paddle/ir/dialect/shape/CMakeLists.txt b/paddle/ir/dialect/shape/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ab8ecdd7eda28cc6b2f2337f391c67b18b74db03 --- /dev/null +++ b/paddle/ir/dialect/shape/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB SHAPE_SRCS "*.cc") +ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core) diff --git a/paddle/ir/dialect/shape/shape_dialect.cc b/paddle/ir/dialect/shape/shape_dialect.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5e3adc3ac0a5ab18052a14cbc0d39167a2aa4c5 --- /dev/null +++ b/paddle/ir/dialect/shape/shape_dialect.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect/shape/shape_dialect.h" +#include "paddle/ir/dialect/shape/shape_op.h" + +namespace ir { +namespace dialect { +ShapeDialect::ShapeDialect(IrContext *context) + : Dialect(name(), context, TypeId::get()) { + initialize(); +} + +void ShapeDialect::initialize() { RegisterOps(); } + +} // namespace dialect +} // namespace ir + +IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::ShapeDialect) diff --git a/paddle/ir/dialect/shape/shape_dialect.h b/paddle/ir/dialect/shape/shape_dialect.h new file mode 100644 index 0000000000000000000000000000000000000000..eb47aa1345f2862b90ab97ceba10ccb609201ce4 --- /dev/null +++ b/paddle/ir/dialect/shape/shape_dialect.h @@ -0,0 +1,42 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/dialect.h" + +namespace ir { +namespace dialect { +/// +/// \brief Shape Dialect: +/// +class IR_API ShapeDialect : public ir::Dialect { + public: + explicit ShapeDialect(ir::IrContext *context); + /// + /// \brief Each Dialect needs to provide a name function to return the name of + /// the Dialect. + /// + /// \return The name of this Dialect. + /// + static const char *name() { return "shape"; } + + private: + void initialize(); +}; + +} // namespace dialect +} // namespace ir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::ShapeDialect) diff --git a/paddle/ir/dialect/shape/shape_op.cc b/paddle/ir/dialect/shape/shape_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7befe847790bf777b3d11d7579ec5bd8fc6e7894 --- /dev/null +++ b/paddle/ir/dialect/shape/shape_op.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect/shape/shape_op.h" +#include "paddle/ir/core/builtin_attribute.h" + +namespace ir { +namespace dialect { + +const char *SymbolicDim::attributes_name[attributes_num] = {"knownNegativeOne", + "knownNonNegative", + "knownNonSizeOne", + "knownNonSizeZero", + "sym_name", + "value"}; // NOLINT + +void SymbolicDim::Build( + Builder &builder, + OperationArgument &argument, + const std::string &sym_name, + int64_t value, // TODO(zhangbo) value = ShapedType::kDynamic + bool knownNonNegative, + bool knownNegativeOne, + bool knownNonSizeOne, + bool knownNonSizeZero) { + ir::Attribute attr_sym_name = + ir::StrAttribute::get(ir::IrContext::Instance(), sym_name); + argument.AddAttribute("sym_name", attr_sym_name); + ir::Attribute attr_value = + ir::Int64Attribute::get(ir::IrContext::Instance(), value); + argument.AddAttribute("value", attr_value); + ir::Attribute attr_knownNonNegative = + ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonNegative); + argument.AddAttribute("knownNonNegative", attr_knownNonNegative); + ir::Attribute attr_knownNegativeOne = + ir::BoolAttribute::get(ir::IrContext::Instance(), knownNegativeOne); + argument.AddAttribute("knownNegativeOne", attr_knownNegativeOne); + ir::Attribute attr_knownNonSizeOne = + ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonSizeOne); + argument.AddAttribute("knownNonSizeOne", attr_knownNonSizeOne); + ir::Attribute attr_knownNonSizeZero = + ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonSizeZero); + argument.AddAttribute("knownNonSizeZero", attr_knownNonSizeZero); +} + +std::string SymbolicDim::getSymName() { + return attribute("sym_name").AsString(); +} +int64_t SymbolicDim::getValue() { + return attribute("value").data(); +} +bool SymbolicDim::getKnownNonNegative() { + return attribute("knownNonNegative").data(); +} +bool SymbolicDim::getKnownNegativeOne() { + return attribute("knownNegativeOne").data(); +} +bool SymbolicDim::getKnownNonSizeOne() { + return attribute("knownNonSizeOne").data(); +} +bool SymbolicDim::getKnownNonSizeZero() { + return attribute("knownNonSizeZero").data(); +} + +void SymbolicDim::updateSymName(std::string attrValue) { + operation()->set_attribute( + "sym_name", ir::StrAttribute::get(ir::IrContext::Instance(), attrValue)); +} +void SymbolicDim::updateValue(int64_t attrValue) { + operation()->set_attribute( + "value", ir::Int64Attribute::get(ir::IrContext::Instance(), attrValue)); +} + +void SymbolicDim::updateKnownNonNegative(bool attrValue) { + operation()->set_attribute( + "knownNonNegative", + ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); +} +void SymbolicDim::updateKnownNegativeOne(bool attrValue) { + operation()->set_attribute( + "knownNegativeOne", + ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); +} +void SymbolicDim::updateKnownNonSizeOne(bool attrValue) { + operation()->set_attribute( + "knownNonSizeOne", + ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); +} +void SymbolicDim::updateKnownNonSizeZero(bool attrValue) { + operation()->set_attribute( + "knownNonSizeZero", + ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue)); +} + +} // namespace dialect +} // namespace ir + +IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim) diff --git a/paddle/ir/dialect/shape/shape_op.h b/paddle/ir/dialect/shape/shape_op.h new file mode 100644 index 0000000000000000000000000000000000000000..48445d4e8cb75f95e7cfc8a5da3bbe8a87cbef18 --- /dev/null +++ b/paddle/ir/dialect/shape/shape_op.h @@ -0,0 +1,60 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/op_base.h" + +namespace ir { +namespace dialect { + +class IR_API SymbolicDim : public Op { + public: + using Op::Op; + static const char *name() { return "shape.SymbolicDim"; } + + static constexpr uint32_t attributes_num = 6; + static const char *attributes_name[attributes_num]; + + static void Build( + Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const std::string &sym_name, + int64_t value = -100000, // TODO(zhangbo): value = ShapedType::kDynamic + bool knownNonNegative = false, + bool knownNegativeOne = false, + bool knownNonSizeOne = false, + bool knownNonSizeZero = false); + std::string getSymName(); + int64_t getValue(); + bool getKnownNonNegative(); + bool getKnownNegativeOne(); + bool getKnownNonSizeOne(); + bool getKnownNonSizeZero(); + + void updateSymName(std::string attrValue); + void updateValue(int64_t attrValue); + + void updateKnownNonNegative(bool attrValue); + void updateKnownNegativeOne(bool attrValue); + void updateKnownNonSizeOne(bool attrValue); + void updateKnownNonSizeZero(bool attrValue); + void Verify() {} +}; + +} // namespace dialect +} // namespace ir + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim); diff --git a/test/cpp/ir/CMakeLists.txt b/test/cpp/ir/CMakeLists.txt index f33f84eab371150c682a0dce6ae8929c886eea70..d2117ad5c24e2e660464712695288ab872fa098c 100644 --- a/test/cpp/ir/CMakeLists.txt +++ b/test/cpp/ir/CMakeLists.txt @@ -3,3 +3,4 @@ add_subdirectory(pass) add_subdirectory(pattern_rewrite) add_subdirectory(kernel_dialect) add_subdirectory(cinn) +add_subdirectory(shape_dialect) diff --git a/test/cpp/ir/shape_dialect/CMakeLists.txt b/test/cpp/ir/shape_dialect/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b959a283b20e5aa472e37b249df7e3e5f6b8a6b --- /dev/null +++ b/test/cpp/ir/shape_dialect/CMakeLists.txt @@ -0,0 +1 @@ +cc_test_old(assist_struct_test SRCS assist_struct_test.cc DEPS ir gtest) diff --git a/test/cpp/ir/shape_dialect/assist_struct_test.cc b/test/cpp/ir/shape_dialect/assist_struct_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..05ecf3734c42102bb718be170e50b986f9146c61 --- /dev/null +++ b/test/cpp/ir/shape_dialect/assist_struct_test.cc @@ -0,0 +1,52 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "paddle/ir/core/block.h" +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/dialect/shape/shape_dialect.h" +#include "paddle/ir/dialect/shape/shape_op.h" + +TEST(assist_struct_test, symbolic_dim) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Program program(ctx); + ctx->GetOrRegisterDialect(); + ir::Builder builder = ir::Builder(ctx, program.block()); + ir::dialect::SymbolicDim sym_dim = builder.Build( + "S0", 10, false, false, false, false); + EXPECT_EQ(sym_dim.getValue(), 10); + EXPECT_EQ(sym_dim.getSymName(), "S0"); + EXPECT_FALSE(sym_dim.getKnownNegativeOne()); + EXPECT_FALSE(sym_dim.getKnownNonSizeOne()); + EXPECT_FALSE(sym_dim.getKnownNonSizeZero()); + EXPECT_FALSE(sym_dim.getKnownNonNegative()); + + sym_dim.updateValue(20); + sym_dim.updateSymName("S1"); + sym_dim.updateKnownNegativeOne(true); + sym_dim.updateKnownNonSizeOne(true); + sym_dim.updateKnownNonSizeZero(true); + sym_dim.updateKnownNonNegative(true); + + EXPECT_EQ(sym_dim.getValue(), 20); + EXPECT_EQ(sym_dim.getSymName(), "S1"); + EXPECT_TRUE(sym_dim.getKnownNegativeOne()); + EXPECT_TRUE(sym_dim.getKnownNonSizeOne()); + EXPECT_TRUE(sym_dim.getKnownNonSizeZero()); + EXPECT_TRUE(sym_dim.getKnownNonNegative()); +}