未验证 提交 2f0b4ad0 编写于 作者: K kangguangli 提交者: GitHub

[IR] Refactor scalar attribute (#54340)

* refactor scalar attribute

* fix op build

* fix merge conflicts

* fix coverage ci
上级 f20d3fc2
......@@ -249,6 +249,15 @@ def to_phi_and_fluid_op_name(op_item):
return phi_name, fluid_name
scalar_type_maps = {
'int': 'ir::Int32_tAttribute',
'int64_t': 'ir::Int64_tAttribute',
'float': 'ir::FloatAttribute',
'dobule': 'ir::DoubleAttribute',
'bool': 'ir::BoolAttribute',
}
# =====================================
# Parse Op Compat From Yaml
# =====================================
......@@ -298,10 +307,10 @@ class OpInfoParser:
self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['paddle::dialect::ScalarAttribute', 'int'],
'Scalar(int64_t)': ['paddle::dialect::ScalarAttribute', 'int64_t'],
'Scalar(float)': ['paddle::dialect::ScalarAttribute', 'float'],
'Scalar(dobule)': ['paddle::dialect::ScalarAttribute', 'dobule'],
'Scalar(int)': ['ir::Int32_tAttribute', 'int'],
'Scalar(int64_t)': ['ir::Int64_tAttribute', 'int64_t'],
'Scalar(float)': ['ir::FloatAttribute', 'float'],
'Scalar(dobule)': ['ir::DoubleAttribute', 'dobule'],
'Scalar[]': [
'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'std::vector<Scalar>',
......@@ -707,7 +716,7 @@ def GenBuildAttributes(
):
INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr}));
"""
SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr}));
SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = TransToIrAttribute({attr}, ir::IrContext::Instance());
"""
STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr});
"""
......@@ -742,7 +751,6 @@ def GenBuildAttributes(
+ ".size()",
create_attribute=SCALAR_STR_TEMPLATE.format(
attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_non_mutable_attribute_name_list[idx] + "[i]",
),
)
......@@ -773,7 +781,6 @@ def GenBuildAttributes(
):
attr_str += SCALAR_STR_TEMPLATE.format(
attr_name=op_non_mutable_attribute_name_list[idx],
op_attribute_type=op_non_mutable_attribute_type_list[idx],
attr=op_non_mutable_attribute_name_list[idx],
)
else:
......@@ -828,7 +835,7 @@ def GenBuildOutputs(
}}
"""
CREATE_INTARRAY_MUTABLE_ATTRIBUE_TEMPLATE = """ std::vector<int64_t> {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{dtype}>(); (void){name};\n"""
CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE = """ {dtype} {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<{ir_type}>().data(); (void){name};\n"""
CREATE_STRING_MUTABLE_ATTRIBUE_TEMPLATE = """ std::string {name} = {name}_.owner()->dyn_cast<ir::ConstantOp>().value().dyn_cast<ir::StrAttribute>().data(); (void){name};\n"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
......@@ -869,7 +876,9 @@ def GenBuildOutputs(
# scalar
elif attr_dtype[0] == "paddle::dialect::ScalarAttribute":
build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1]
name=op_mutable_attribute_name_list[idx],
dtype=attr_dtype[1],
ir_type=scalar_type_maps[attr_dtype[1]],
)
# string
elif attr_dtype[0] == "ir::StrAttribute":
......
......@@ -18,10 +18,6 @@ namespace paddle {
namespace dialect {
phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); }
paddle::experimental::Scalar ScalarAttribute::data() const {
return storage()->GetAsKey();
}
phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); }
phi::Place PlaceAttribute::data() const { return storage()->GetAsKey(); }
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/ir/dialect/pd_attribute_storage.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
namespace paddle {
namespace dialect {
......@@ -37,13 +38,13 @@ class ScalarAttribute : public ir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ScalarAttribute, ScalarAttributeStorage);
bool operator<(const ScalarAttribute &right) const {
return storage() < right.storage();
static bool classof(ir::Attribute val) {
return (val.type_id() == ir::BoolAttribute::type_id()) ||
(val.type_id() == ir::FloatAttribute::type_id()) ||
(val.type_id() == ir::DoubleAttribute::type_id()) ||
(val.type_id() == ir::Int32_tAttribute::type_id()) ||
(val.type_id() == ir::Int64_tAttribute::type_id());
}
paddle::experimental::Scalar data() const;
};
class DataTypeAttribute : public ir::Attribute {
......
......@@ -20,7 +20,6 @@
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
namespace paddle {
namespace dialect {
......@@ -54,28 +53,6 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage {
phi::IntArray data_;
};
struct ScalarAttributeStorage : public ir::AttributeStorage {
using ParamKey = paddle::experimental::Scalar;
explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; }
static ScalarAttributeStorage *Construct(ParamKey key) {
return new ScalarAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return ir::hash_combine(std::hash<std::string>()(key.ToString()),
std::hash<bool>()(key.FromTensor()));
}
bool operator==(const ParamKey &key) const { return data_ == key; }
ParamKey GetAsKey() const { return ParamKey(data_); }
private:
paddle::experimental::Scalar data_;
};
struct DataTypeAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::DataType;
......
......@@ -92,7 +92,6 @@ void PaddleDialect::initialize() {
RegisterTypes<paddle::dialect::DenseTensorType>();
RegisterAttributes<paddle::dialect::IntArrayAttribute,
paddle::dialect::ScalarAttribute,
paddle::dialect::DataTypeAttribute,
paddle::dialect::PlaceAttribute,
paddle::dialect::DataLayoutAttribute>();
......
......@@ -17,14 +17,16 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/ir/dialect/pd_type_storage.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace dialect {
// TODO(zhangbo): The builtin type needs to cover all data types of
// phi::DataType.
inline phi::DataType TransToPhiDataType(ir::Type dtype) {
static inline phi::DataType TransToPhiDataType(ir::Type dtype) {
if (dtype.isa<ir::Float16Type>()) {
return phi::DataType::FLOAT16;
} else if (dtype.isa<ir::Float32Type>()) {
......@@ -44,8 +46,8 @@ inline phi::DataType TransToPhiDataType(ir::Type dtype) {
}
}
inline ir::Type TransToIrDataType(phi::DataType dtype,
ir::IrContext *ctx = nullptr) {
static inline ir::Type TransToIrDataType(phi::DataType dtype,
ir::IrContext *ctx = nullptr) {
if (ctx == nullptr) {
ctx = ir::IrContext::Instance();
}
......@@ -70,6 +72,30 @@ inline ir::Type TransToIrDataType(phi::DataType dtype,
}
}
static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
ir::IrContext *ctx = nullptr) {
if (ctx == nullptr) {
ctx = ir::IrContext::Instance();
}
switch (scalar.dtype()) {
case phi::DataType::FLOAT32:
return ir::FloatAttribute::get(ctx, scalar.to<float>());
case phi::DataType::FLOAT64:
return ir::DoubleAttribute::get(ctx, scalar.to<double>());
case phi::DataType::INT32:
return ir::Int32_tAttribute::get(ctx, scalar.to<int32_t>());
case phi::DataType::INT64:
return ir::Int64_tAttribute::get(ctx, scalar.to<int64_t>());
case phi::DataType::BOOL:
return ir::BoolAttribute::get(ctx, scalar.to<bool>());
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported phi data type `%s` when casting it into "
"ir attribute.",
scalar.dtype()));
}
}
struct OpInputInfo {
std::string name;
std::string type_name;
......
......@@ -18,6 +18,7 @@
#include <vector>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/layout.h"
......@@ -62,7 +63,7 @@ class AttributeVisitor {
virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) {
VLOG(10) << "translating scalar";
return paddle::dialect::ScalarAttribute::get(ctx, scalar);
IR_THROW("not support translating paddle::experimental::Scalar");
}
virtual ir::Attribute operator()(const std::vector<std::string>& strs) {
......@@ -128,12 +129,8 @@ class AttributeVisitor {
virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v));
}
return ir::ArrayAttribute::get(ctx, attrs);
IR_THROW(
"not support translating std::vector<paddle::experimental::Scalar>");
}
virtual ir::Attribute operator()(const paddle::blank& blank) {
......@@ -164,22 +161,6 @@ class IntArrayAttributeVisitor : public AttributeVisitor {
}
};
class ScalarAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
ir::Attribute operator()(int i) override {
VLOG(10) << "translating int to Scalar";
phi::Scalar data(i);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
ir::Attribute operator()(float f) override {
VLOG(10) << "translating float to Scalar";
phi::Scalar data(f);
return paddle::dialect::ScalarAttribute::get(ctx, data);
}
};
class DataTypeAttributeVisitor : public AttributeVisitor {
public:
using AttributeVisitor::AttributeVisitor;
......@@ -205,8 +186,6 @@ AttributeTranslator::AttributeTranslator() {
general_visitor = new AttributeVisitor();
special_visitors["paddle::dialect::IntArrayAttribute"] =
new IntArrayAttributeVisitor();
special_visitors["paddle::dialect::ScalarAttribute"] =
new ScalarAttributeVisitor();
special_visitors["paddle::dialect::DataTypeAttribute"] =
new DataTypeAttributeVisitor();
special_visitors["paddle::dialect::PlaceAttribute"] =
......
......@@ -32,6 +32,15 @@ cc_test_old(
phi
gtest)
cc_test_old(
scalar_attribute_test
SRCS
scalar_attribute_test.cc
DEPS
pd_dialect
new_ir
gtest)
file(
DOWNLOAD
https://paddle-ci.gz.bcebos.com/ir_translator_test/restnet50_main.prog
......
......@@ -74,10 +74,8 @@ TEST(program_test, program) {
ctx, std::vector<int64_t>({2, 2}));
ir::Attribute data_type =
paddle::dialect::DataTypeAttribute::get(ctx, phi::DataType::FLOAT32);
ir::Attribute min =
paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(0.0));
ir::Attribute max =
paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(1.0));
ir::Attribute min = ir::FloatAttribute::get(ctx, 0.0);
ir::Attribute max = ir::FloatAttribute::get(ctx, 1.0);
ir::Attribute seed = ir::Int32_tAttribute::get(ctx, 2);
ir::Attribute uni_place = paddle::dialect::PlaceAttribute::get(
ctx, phi::Place(phi::AllocationType::CPU));
......
......@@ -143,9 +143,6 @@ void build_context(ir::Operation* op,
} else if (type_name == "paddle::dialect::DataTypeAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::DataTypeAttribute>().data());
} else if (type_name == "paddle::dialect::ScalarAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<paddle::dialect::ScalarAttribute>().data());
} else if (type_name == "ir::Int32_tAttribute") {
ctx->EmplaceBackAttr(
attr_map[t].dyn_cast<ir::Int32_tAttribute>().data());
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
using ScalarAttribute = paddle::dialect::ScalarAttribute;
TEST(ScalarTest, base) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Attribute bool_scalar = ir::BoolAttribute::get(ctx, false);
EXPECT_TRUE(bool_scalar.isa<ScalarAttribute>());
EXPECT_TRUE(bool_scalar.isa<ir::BoolAttribute>());
ir::BoolAttribute pure_bool = bool_scalar.dyn_cast<ir::BoolAttribute>();
EXPECT_TRUE(pure_bool.isa<ScalarAttribute>());
ScalarAttribute scalar_from_bool = bool_scalar.dyn_cast<ScalarAttribute>();
EXPECT_TRUE(scalar_from_bool.isa<ir::BoolAttribute>());
EXPECT_NO_THROW(scalar_from_bool.dyn_cast<ir::BoolAttribute>());
}
TEST(ScalarTest, test_classof) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Attribute bool_scalar = ir::BoolAttribute::get(ctx, false);
EXPECT_TRUE(bool_scalar.isa<ScalarAttribute>());
ir::Attribute float_scalar = ir::FloatAttribute::get(ctx, 1.0f);
EXPECT_TRUE(float_scalar.isa<ScalarAttribute>());
ir::Attribute double_scalar = ir::DoubleAttribute::get(ctx, 1.0);
EXPECT_TRUE(double_scalar.isa<ScalarAttribute>());
ir::Attribute int32_scalar = ir::Int32_tAttribute::get(ctx, 1);
EXPECT_TRUE(int32_scalar.isa<ScalarAttribute>());
ir::Attribute int64_scalar = ir::Int64_tAttribute::get(ctx, 1l);
EXPECT_TRUE(int64_scalar.isa<ScalarAttribute>());
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册