未验证 提交 f239d7d1 编写于 作者: W winter-wang 提交者: GitHub

[IR] add type&attribtue api for builder. (#54965)

上级 5a6cd05f
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/region.h"
#include "paddle/ir/core/value.h"
......@@ -38,5 +40,44 @@ Operation *Builder::Insert(Operation *op) {
}
return op;
}
UInt8Type Builder::uint8_type() { return UInt8Type::get(context_); }
Int8Type Builder::int8_type() { return Int8Type::get(context_); }
VectorType Builder::vec_type(const std::vector<Type> &value) {
return VectorType::get(context_, value);
}
BFloat16Type Builder::bfloat16_type() { return BFloat16Type::get(context_); }
Float32Type Builder::float32_type() { return Float32Type::get(context_); }
Float64Type Builder::float64_type() { return Float64Type::get(context_); }
Int16Type Builder::int16_type() { return Int16Type::get(context_); }
BoolType Builder::bool_type() { return BoolType::get(context_); }
Complex64Type Builder::complex64_type() { return Complex64Type::get(context_); }
Complex128Type Builder::complex128_type() {
return Complex128Type::get(context_);
}
StrAttribute Builder::str_attr(const std::string &value) {
return StrAttribute::get(context_, value);
}
BoolAttribute Builder::bool_attr(bool value) {
return BoolAttribute::get(context_, value);
}
FloatAttribute Builder::float_attr(float value) {
return FloatAttribute::get(context_, value);
}
DoubleAttribute Builder::double_attr(double value) {
return DoubleAttribute::get(context_, value);
}
Int32Attribute Builder::int32_attr(int32_t value) {
return Int32Attribute::get(context_, value);
}
Int64Attribute Builder::int64_attr(int64_t value) {
return Int64Attribute::get(context_, value);
}
ArrayAttribute Builder::array_attr(const std::vector<Attribute> &value) {
return ArrayAttribute::get(context_, value);
}
PointerAttribute Builder::pointer_attr(void *value) {
return PointerAttribute::get(context_, value);
}
} // namespace ir
......@@ -20,6 +20,25 @@
#include "paddle/ir/core/operation.h"
namespace ir {
class Type;
class UInt8Type;
class Int8Type;
class VectorType;
class BFloat16Type;
class Float32Type;
class Float64Type;
class Int16Type;
class BoolType;
class Complex64Type;
class Complex128Type;
class StrAttribute;
class BoolAttribute;
class FloatAttribute;
class DoubleAttribute;
class Int32Attribute;
class Int64Attribute;
class ArrayAttribute;
class PointerAttribute;
///
/// \brief Unified interface of the Attribute class. Derivation of all Attribute
......@@ -90,6 +109,26 @@ class Builder {
return op->dyn_cast<OpTy>();
}
IR_API UInt8Type uint8_type();
IR_API Int8Type int8_type();
IR_API VectorType vec_type(const std::vector<Type> &);
IR_API BFloat16Type bfloat16_type();
IR_API Float32Type float32_type();
IR_API Float64Type float64_type();
IR_API Int16Type int16_type();
IR_API BoolType bool_type();
IR_API Complex64Type complex64_type();
IR_API Complex128Type complex128_type();
IR_API StrAttribute str_attr(const std::string &value);
IR_API BoolAttribute bool_attr(bool value);
IR_API FloatAttribute float_attr(float value);
IR_API DoubleAttribute double_attr(double value);
IR_API Int32Attribute int32_attr(int32_t value);
IR_API Int64Attribute int64_attr(int64_t value);
IR_API ArrayAttribute array_attr(const std::vector<Attribute> &value);
IR_API PointerAttribute pointer_attr(void *value);
private:
Operation *Insert(Operation *op);
......
......@@ -47,6 +47,9 @@ class IR_API IrContext {
///
static IrContext *Instance();
IrContext();
~IrContext();
///
/// \brief Get an instance of IrContextImpl, a private member of IrContext.
/// For the specific definition of IrContextImpl, see ir_context.cc.
......@@ -184,8 +187,6 @@ class IR_API IrContext {
void operator=(const IrContext &) = delete;
private:
IrContext();
~IrContext();
IrContextImpl *impl_;
};
......
......@@ -3,6 +3,7 @@ cc_test_old(ir_attribute_test SRCS ir_attribute_test.cc DEPS ir gtest)
cc_test_old(ir_value_test SRCS ir_value_test.cc DEPS ir gtest)
cc_test_old(ir_op_test SRCS ir_op_test.cc DEPS ir gtest)
cc_test_old(ir_region_test SRCS ir_region_test.cc DEPS ir gtest)
cc_test_old(ir_builder_test SRCS ir_builder_test.cc DEPS ir gtest)
cc_test_old(
ir_program_test
SRCS
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <map>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
TEST(builder_test, type_api) {
ir::IrContext ctx;
ir::Builder builder(&ctx);
EXPECT_EQ(ir::UInt8Type::get(&ctx), builder.uint8_type());
EXPECT_EQ(ir::Int8Type::get(&ctx), builder.int8_type());
EXPECT_EQ(ir::VectorType::get(&ctx, std::vector<ir::Type>()),
builder.vec_type({}));
EXPECT_EQ(ir::BFloat16Type::get(&ctx), builder.bfloat16_type());
EXPECT_EQ(ir::Float32Type::get(&ctx), builder.float32_type());
EXPECT_EQ(ir::Float64Type::get(&ctx), builder.float64_type());
EXPECT_EQ(ir::Int16Type::get(&ctx), builder.int16_type());
EXPECT_EQ(ir::BoolType::get(&ctx), builder.bool_type());
EXPECT_EQ(ir::Complex64Type::get(&ctx), builder.complex64_type());
EXPECT_EQ(ir::Complex128Type::get(&ctx), builder.complex128_type());
}
TEST(builder_test, attribute_api) {
ir::IrContext ctx;
ir::Builder builder(&ctx);
EXPECT_EQ(ir::StrAttribute::get(&ctx, "test"), builder.str_attr("test"));
EXPECT_EQ(ir::BoolAttribute::get(&ctx, true), builder.bool_attr(true));
EXPECT_EQ(ir::FloatAttribute::get(&ctx, 0.2f), builder.float_attr(0.2f));
EXPECT_EQ(ir::DoubleAttribute::get(&ctx, 2.0), builder.double_attr(2.0));
EXPECT_EQ(ir::Int32Attribute::get(&ctx, 2), builder.int32_attr(2));
EXPECT_EQ(ir::Int64Attribute::get(&ctx, 2), builder.int64_attr(2));
EXPECT_EQ(ir::ArrayAttribute::get(&ctx, std::vector<ir::Attribute>()),
builder.array_attr({}));
EXPECT_EQ(ir::PointerAttribute::get(&ctx, nullptr),
builder.pointer_attr(nullptr));
}
......@@ -32,8 +32,8 @@ TEST(region, erase_op_test) {
ir::Builder builder = ir::Builder(ctx, program.block());
// (3) Def a = ConstantOp("2.0"); b = ConstantOp("2.0");
ir::FloatAttribute fp_attr = ir::FloatAttribute::get(ctx, 2.0f);
ir::Float32Type fp32_type = ir::Float32Type::get(ctx);
ir::FloatAttribute fp_attr = builder.float_attr(2.0f);
ir::Float32Type fp32_type = builder.float32_type();
ir::OpResult a = builder.Build<ir::ConstantOp>(fp_attr, fp32_type)->result(0);
ir::OpResult b = builder.Build<ir::ConstantOp>(fp_attr, fp32_type)->result(0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册