// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h" namespace paddle { namespace dialect { const std::unordered_set LegacyOpList = { "pd.load_combine", "pd.c_concat", "pd.c_broadcast_"}; enum class AttrType { UNDEFINED = 0, BOOL, INT32, INT64, FLOAT, DOUBLE, ARRAY, INT_ARRAY, SCALAR, DATA_TYPE, DATA_LAYOUT, PLACE, STRING, NUM_ATTR_TYPES, }; static inline AttrType GetAttributeType(const ir::Attribute& attr) { if (attr.isa()) { return AttrType::BOOL; } else if (attr.isa()) { return AttrType::FLOAT; } else if (attr.isa()) { return AttrType::DOUBLE; } else if (attr.isa()) { return AttrType::INT32; } else if (attr.isa()) { return AttrType::INT64; } else if (attr.isa()) { return AttrType::ARRAY; } else if (attr.isa()) { return AttrType::STRING; } else if (attr.isa()) { return AttrType::INT_ARRAY; } else if (attr.isa()) { return AttrType::DATA_TYPE; } else if (attr.isa()) { return AttrType::PLACE; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir Attribute type when casting it into " "AttrType.")); } } static std::unordered_map> kAttrCastMap = { {AttrType::BOOL, [](const ir::Attribute& attr) { return VariantType{attr.dyn_cast().data()}; }}, {AttrType::FLOAT, [](const ir::Attribute& attr) { return VariantType{attr.dyn_cast().data()}; }}, {AttrType::DOUBLE, [](const ir::Attribute& attr) { return VariantType{attr.dyn_cast().data()}; }}, {AttrType::INT32, [](const ir::Attribute& attr) { return VariantType{attr.dyn_cast().data()}; }}, {AttrType::INT64, [](const ir::Attribute& attr) { return VariantType{attr.dyn_cast().data()}; }}, {AttrType::INT_ARRAY, [](const ir::Attribute& attr) { return VariantType{ attr.dyn_cast() .data() .GetData()}; }}, {AttrType::STRING, [](const ir::Attribute& attr) { return VariantType{attr.dyn_cast().AsString()}; }}, {AttrType::DATA_TYPE, [](const ir::Attribute& attr) { return VariantType{ attr.dyn_cast().data()}; }}, {AttrType::PLACE, [](const ir::Attribute& attr) { return VariantType{ attr.dyn_cast().data()}; }}, {AttrType::ARRAY, [](const ir::Attribute& attr) { auto attr_vec = attr.dyn_cast().AsVector(); if (attr_vec.size() == 0) { return VariantType{std::vector()}; } AttrType element_type = GetAttributeType(attr_vec[0]); if (element_type == AttrType::BOOL) { std::vector vec_bools; for (auto vec_element : attr_vec) { vec_bools.push_back( vec_element.dyn_cast().data()); } return VariantType{vec_bools}; } else if (element_type == AttrType::INT32) { std::vector vec_int32; for (auto vec_element : attr_vec) { vec_int32.push_back( vec_element.dyn_cast().data()); } return VariantType{vec_int32}; } else if (element_type == AttrType::INT64) { std::vector vec_int64; for (auto vec_element : attr_vec) { vec_int64.push_back( vec_element.dyn_cast().data()); } return VariantType{vec_int64}; } else if (element_type == AttrType::FLOAT) { std::vector vec_float; for (auto vec_element : attr_vec) { vec_float.push_back( vec_element.dyn_cast().data()); } return VariantType{vec_float}; } else if (element_type == AttrType::DOUBLE) { std::vector vec_double; for (auto vec_element : attr_vec) { vec_double.push_back( vec_element.dyn_cast().data()); } return VariantType{vec_double}; } else { PADDLE_THROW(phi::errors::Unimplemented( "Unsupported ir Attribute type when casting it into " "vector.")); } }}, }; VariantType GetAttributeData(const ir::Attribute& attr) { AttrType attr_type = GetAttributeType(attr); return kAttrCastMap[attr_type](attr); } bool IsLegacyOp(const std::string& name) { return LegacyOpList.count(name); } } // namespace dialect } // namespace paddle