// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/translator/attribute_translator.h" #include #include #include "paddle/fluid/dialect/pd_attribute.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/layout.h" #include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/variant.h" namespace paddle { namespace translator { class AttributeVisitor { public: ir::IrContext* ctx; AttributeVisitor() { ctx = ir::IrContext::Instance(); } ~AttributeVisitor() {} public: virtual ir::Attribute operator()(int i) { VLOG(10) << "translating int"; return ir::Int32_tAttribute::get(ctx, i); } virtual ir::Attribute operator()(float f) { VLOG(10) << "translating float"; return ir::FloatAttribute::get(ctx, f); } virtual ir::Attribute operator()(bool b) { VLOG(10) << "translating bool"; return ir::BoolAttribute::get(ctx, b); } virtual ir::Attribute operator()(double d) { VLOG(10) << "translating double"; return ir::DoubleAttribute::get(ctx, d); } virtual ir::Attribute operator()(std::string str) { VLOG(10) << "translating string"; return ir::StrAttribute::get(ctx, str); } virtual ir::Attribute operator()(const paddle::experimental::Scalar& scalar) { VLOG(10) << "translating scalar"; return paddle::dialect::ScalarAttribute::get(ctx, scalar); } virtual ir::Attribute operator()(const std::vector& strs) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(strs.size()); for (const auto& v : strs) { attrs.push_back(ir::StrAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()(const std::vector& fs) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(fs.size()); for (const auto& v : fs) { attrs.push_back(ir::FloatAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()(const std::vector& is) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(is.size()); for (const auto& v : is) { attrs.push_back(ir::Int32_tAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()(const std::vector& bs) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(bs.size()); for (const auto& v : bs) { attrs.push_back(ir::BoolAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()(const std::vector& i64s) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(i64s.size()); for (const auto& v : i64s) { attrs.push_back(ir::Int64_tAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()(const std::vector& ds) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(ds.size()); for (const auto& v : ds) { attrs.push_back(ir::DoubleAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()( const std::vector& ss) { VLOG(10) << "translating vector"; std::vector attrs; attrs.reserve(ss.size()); for (const auto& v : ss) { attrs.push_back(paddle::dialect::ScalarAttribute::get(ctx, v)); } return ir::ArrayAttribute::get(ctx, attrs); } virtual ir::Attribute operator()(const paddle::blank& blank) { VLOG(10) << "translating paddle::blank"; return ir::Attribute(nullptr); } template ir::Attribute operator()(T attr) { VLOG(10) << "translating null type"; return ir::Attribute(nullptr); } }; class IntArrayAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; ir::Attribute operator()(const std::vector& is) override { VLOG(10) << "translating vector to IntArray"; phi::IntArray data(is); return paddle::dialect::IntArrayAttribute::get(ctx, data); } ir::Attribute operator()(const std::vector& is) override { VLOG(10) << "translating vector to IntArray"; phi::IntArray data(is); return paddle::dialect::IntArrayAttribute::get(ctx, data); } }; class ScalarAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; ir::Attribute operator()(int i) override { VLOG(10) << "translating int to Scalar"; phi::Scalar data(i); return paddle::dialect::ScalarAttribute::get(ctx, data); } ir::Attribute operator()(float f) override { VLOG(10) << "translating float to Scalar"; phi::Scalar data(f); return paddle::dialect::ScalarAttribute::get(ctx, data); } }; class DataTypeAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; ir::Attribute operator()(int i) override { VLOG(10) << "translating int to DataType: " << i; phi::DataType data = static_cast(i); return paddle::dialect::DataTypeAttribute::get(ctx, data); } }; class PlaceAttributeVisitor : public AttributeVisitor { public: using AttributeVisitor::AttributeVisitor; ir::Attribute operator()(const paddle::blank& blank) override { VLOG(10) << "translating paddle::blank"; phi::Place data(phi::AllocationType::CPU); return paddle::dialect::PlaceAttribute::get(ctx, data); } }; AttributeTranslator::AttributeTranslator() { general_visitor = new AttributeVisitor(); special_visitors["paddle::dialect::IntArrayAttribute"] = new IntArrayAttributeVisitor(); special_visitors["paddle::dialect::ScalarAttribute"] = new ScalarAttributeVisitor(); special_visitors["paddle::dialect::DataTypeAttribute"] = new DataTypeAttributeVisitor(); special_visitors["paddle::dialect::PlaceAttribute"] = new PlaceAttributeVisitor(); } ir::Attribute AttributeTranslator::operator()( const framework::Attribute& attr) { return paddle::visit(*general_visitor, attr); } ir::Attribute AttributeTranslator::operator()( const std::string& target_type, const framework::Attribute& attr) { if (special_visitors.find(target_type) == special_visitors.end()) { VLOG(10) << "[" << target_type << "] not found"; return paddle::visit(*general_visitor, attr); } return paddle::visit(*(special_visitors.at(target_type)), attr); } } // namespace translator } // namespace paddle