未验证 提交 1ad502df 编写于 作者: L liuruyan 提交者: GitHub

Add ShapeDialect dict & SymbolicDimOp with UT. (#56156)

* Add ShapeDialect dict & SymbolicDimOp without UT.

* add unittest and fix Update_xxx_Func.

* change std::string to const std::string & and remove phi dependency.
上级 a97b507e
...@@ -38,6 +38,7 @@ add_subdirectory(core) ...@@ -38,6 +38,7 @@ add_subdirectory(core)
add_subdirectory(pass) add_subdirectory(pass)
add_subdirectory(pattern_rewrite) add_subdirectory(pattern_rewrite)
add_subdirectory(builtin_transforms) add_subdirectory(builtin_transforms)
add_subdirectory(dialect)
if(WIN32) if(WIN32)
if(WITH_SHARED_IR) if(WITH_SHARED_IR)
......
file(GLOB SHAPE_SRCS "*.cc")
ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/shape/shape_dialect.h"
#include "paddle/ir/dialect/shape/shape_op.h"
namespace ir {
namespace dialect {
ShapeDialect::ShapeDialect(IrContext *context)
: Dialect(name(), context, TypeId::get<ShapeDialect>()) {
initialize();
}
void ShapeDialect::initialize() { RegisterOps<SymbolicDim>(); }
} // namespace dialect
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::ShapeDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/dialect.h"
namespace ir {
namespace dialect {
///
/// \brief Shape Dialect:
///
class IR_API ShapeDialect : public ir::Dialect {
public:
explicit ShapeDialect(ir::IrContext *context);
///
/// \brief Each Dialect needs to provide a name function to return the name of
/// the Dialect.
///
/// \return The name of this Dialect.
///
static const char *name() { return "shape"; }
private:
void initialize();
};
} // namespace dialect
} // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::ShapeDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/shape/shape_op.h"
#include "paddle/ir/core/builtin_attribute.h"
namespace ir {
namespace dialect {
const char *SymbolicDim::attributes_name[attributes_num] = {"knownNegativeOne",
"knownNonNegative",
"knownNonSizeOne",
"knownNonSizeZero",
"sym_name",
"value"}; // NOLINT
void SymbolicDim::Build(
Builder &builder,
OperationArgument &argument,
const std::string &sym_name,
int64_t value, // TODO(zhangbo) value = ShapedType::kDynamic
bool knownNonNegative,
bool knownNegativeOne,
bool knownNonSizeOne,
bool knownNonSizeZero) {
ir::Attribute attr_sym_name =
ir::StrAttribute::get(ir::IrContext::Instance(), sym_name);
argument.AddAttribute("sym_name", attr_sym_name);
ir::Attribute attr_value =
ir::Int64Attribute::get(ir::IrContext::Instance(), value);
argument.AddAttribute("value", attr_value);
ir::Attribute attr_knownNonNegative =
ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonNegative);
argument.AddAttribute("knownNonNegative", attr_knownNonNegative);
ir::Attribute attr_knownNegativeOne =
ir::BoolAttribute::get(ir::IrContext::Instance(), knownNegativeOne);
argument.AddAttribute("knownNegativeOne", attr_knownNegativeOne);
ir::Attribute attr_knownNonSizeOne =
ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonSizeOne);
argument.AddAttribute("knownNonSizeOne", attr_knownNonSizeOne);
ir::Attribute attr_knownNonSizeZero =
ir::BoolAttribute::get(ir::IrContext::Instance(), knownNonSizeZero);
argument.AddAttribute("knownNonSizeZero", attr_knownNonSizeZero);
}
std::string SymbolicDim::getSymName() {
return attribute<ir::StrAttribute>("sym_name").AsString();
}
int64_t SymbolicDim::getValue() {
return attribute<ir::Int64Attribute>("value").data();
}
bool SymbolicDim::getKnownNonNegative() {
return attribute<ir::BoolAttribute>("knownNonNegative").data();
}
bool SymbolicDim::getKnownNegativeOne() {
return attribute<ir::BoolAttribute>("knownNegativeOne").data();
}
bool SymbolicDim::getKnownNonSizeOne() {
return attribute<ir::BoolAttribute>("knownNonSizeOne").data();
}
bool SymbolicDim::getKnownNonSizeZero() {
return attribute<ir::BoolAttribute>("knownNonSizeZero").data();
}
void SymbolicDim::updateSymName(std::string attrValue) {
operation()->set_attribute(
"sym_name", ir::StrAttribute::get(ir::IrContext::Instance(), attrValue));
}
void SymbolicDim::updateValue(int64_t attrValue) {
operation()->set_attribute(
"value", ir::Int64Attribute::get(ir::IrContext::Instance(), attrValue));
}
void SymbolicDim::updateKnownNonNegative(bool attrValue) {
operation()->set_attribute(
"knownNonNegative",
ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue));
}
void SymbolicDim::updateKnownNegativeOne(bool attrValue) {
operation()->set_attribute(
"knownNegativeOne",
ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue));
}
void SymbolicDim::updateKnownNonSizeOne(bool attrValue) {
operation()->set_attribute(
"knownNonSizeOne",
ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue));
}
void SymbolicDim::updateKnownNonSizeZero(bool attrValue) {
operation()->set_attribute(
"knownNonSizeZero",
ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue));
}
} // namespace dialect
} // namespace ir
IR_DEFINE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/op_base.h"
namespace ir {
namespace dialect {
class IR_API SymbolicDim : public Op<SymbolicDim> {
public:
using Op::Op;
static const char *name() { return "shape.SymbolicDim"; }
static constexpr uint32_t attributes_num = 6;
static const char *attributes_name[attributes_num];
static void Build(
Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::string &sym_name,
int64_t value = -100000, // TODO(zhangbo): value = ShapedType::kDynamic
bool knownNonNegative = false,
bool knownNegativeOne = false,
bool knownNonSizeOne = false,
bool knownNonSizeZero = false);
std::string getSymName();
int64_t getValue();
bool getKnownNonNegative();
bool getKnownNegativeOne();
bool getKnownNonSizeOne();
bool getKnownNonSizeZero();
void updateSymName(std::string attrValue);
void updateValue(int64_t attrValue);
void updateKnownNonNegative(bool attrValue);
void updateKnownNegativeOne(bool attrValue);
void updateKnownNonSizeOne(bool attrValue);
void updateKnownNonSizeZero(bool attrValue);
void Verify() {}
};
} // namespace dialect
} // namespace ir
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::dialect::SymbolicDim);
...@@ -3,3 +3,4 @@ add_subdirectory(pass) ...@@ -3,3 +3,4 @@ add_subdirectory(pass)
add_subdirectory(pattern_rewrite) add_subdirectory(pattern_rewrite)
add_subdirectory(kernel_dialect) add_subdirectory(kernel_dialect)
add_subdirectory(cinn) add_subdirectory(cinn)
add_subdirectory(shape_dialect)
cc_test_old(assist_struct_test SRCS assist_struct_test.cc DEPS ir gtest)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <map>
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/dialect/shape/shape_dialect.h"
#include "paddle/ir/dialect/shape/shape_op.h"
TEST(assist_struct_test, symbolic_dim) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::dialect::SymbolicDim sym_dim = builder.Build<ir::dialect::SymbolicDim>(
"S0", 10, false, false, false, false);
EXPECT_EQ(sym_dim.getValue(), 10);
EXPECT_EQ(sym_dim.getSymName(), "S0");
EXPECT_FALSE(sym_dim.getKnownNegativeOne());
EXPECT_FALSE(sym_dim.getKnownNonSizeOne());
EXPECT_FALSE(sym_dim.getKnownNonSizeZero());
EXPECT_FALSE(sym_dim.getKnownNonNegative());
sym_dim.updateValue(20);
sym_dim.updateSymName("S1");
sym_dim.updateKnownNegativeOne(true);
sym_dim.updateKnownNonSizeOne(true);
sym_dim.updateKnownNonSizeZero(true);
sym_dim.updateKnownNonNegative(true);
EXPECT_EQ(sym_dim.getValue(), 20);
EXPECT_EQ(sym_dim.getSymName(), "S1");
EXPECT_TRUE(sym_dim.getKnownNegativeOne());
EXPECT_TRUE(sym_dim.getKnownNonSizeOne());
EXPECT_TRUE(sym_dim.getKnownNonSizeZero());
EXPECT_TRUE(sym_dim.getKnownNonNegative());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册