/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include #include #include #include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" #include "paddle/utils/any.h" #include "paddle/utils/variant.h" namespace paddle { namespace framework { paddle::any GetAttrValue(const Attribute& attr); Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); template struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} T* operator()(Attribute& attr) const { T* attr_value = nullptr; try { attr_value = &paddle::get(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type %s, its type is %s.", attr_name_, paddle::platform::demangle(typeid(T).name()), paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; // special handle bool // FIXME(yuyang18): Currently we cast bool into int in python binding. It is // hard to change the logic there. In another way, we should correct handle // if the user set `some_flag=1`. // // FIX ME anytime if there is a better solution. template <> struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} bool* operator()(Attribute& attr) const { if (attr.type() == typeid(int)) { // NOLINT int val = PADDLE_GET_CONST(int, attr); attr = static_cast(val); } else if (attr.type() == typeid(float)) { // NOLINT float val = PADDLE_GET_CONST(float, attr); attr = static_cast(val); } bool* attr_value = nullptr; try { attr_value = &paddle::get(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type bool, its type is %s.", attr_name_, paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; template <> struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} int64_t* operator()(Attribute& attr) const { if (attr.type() == typeid(int)) { // NOLINT int val = PADDLE_GET_CONST(int, attr); attr = static_cast(val); } else if (attr.type() == typeid(float)) { // NOLINT int val = PADDLE_GET_CONST(float, attr); attr = static_cast(val); } int64_t* attr_value = nullptr; try { attr_value = &paddle::get(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type int64_t, its type is %s.", attr_name_, paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; template <> struct ExtractAttribute> { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} std::vector* operator()(Attribute& attr) const { if (attr.type() == typeid(std::vector)) { // NOLINT std::vector val = PADDLE_GET_CONST(std::vector, attr); std::vector vec(val.begin(), val.end()); attr = vec; } else if (attr.type() == typeid(std::vector)) { // NOLINT std::vector val = PADDLE_GET_CONST(std::vector, attr); std::vector vec(val.begin(), val.end()); attr = vec; } std::vector* attr_value = nullptr; try { attr_value = &paddle::get>(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type std::vector, its type is " "%s.", attr_name_, paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; template <> struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} float* operator()(Attribute& attr) const { if (attr.type() == typeid(int)) { // NOLINT int val = PADDLE_GET_CONST(int, attr); attr = static_cast(val); } else if (attr.type() == typeid(int64_t)) { // NOLINT int64_t val = PADDLE_GET_CONST(int64_t, attr); attr = static_cast(val); } float* attr_value = nullptr; try { attr_value = &paddle::get(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type float, its type is %s.", attr_name_, paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; template <> struct ExtractAttribute { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} double* operator()(Attribute& attr) const { if (attr.type() == typeid(int)) { // NOLINT int val = PADDLE_GET_CONST(int, attr); attr = static_cast(val); } else if (attr.type() == typeid(int64_t)) { // NOLINT int64_t val = PADDLE_GET_CONST(int64_t, attr); attr = static_cast(val); } else if (attr.type() == typeid(float)) { // NOLINT int64_t val = PADDLE_GET_CONST(float, attr); attr = static_cast(val); } double* attr_value = nullptr; try { attr_value = &paddle::get(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type double, its type is %s.", attr_name_, paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; template <> struct ExtractAttribute> { explicit ExtractAttribute(const std::string& attr_name) : attr_name_(attr_name) {} std::vector* operator()(Attribute& attr) const { if (attr.type() == typeid(std::vector)) { // NOLINT std::vector val = PADDLE_GET_CONST(std::vector, attr); std::vector vec(val.begin(), val.end()); attr = vec; } else if (attr.type() == typeid(std::vector)) { // NOLINT std::vector val = PADDLE_GET_CONST(std::vector, attr); std::vector vec(val.begin(), val.end()); attr = vec; } std::vector* attr_value = nullptr; try { attr_value = &paddle::get>(attr); } catch (paddle::bad_variant_access const& bad_get) { PADDLE_THROW(platform::errors::InvalidArgument( "Cannot get attribute (%s) by type std::vector, its type is " "%s.", attr_name_, paddle::platform::demangle(attr.type().name()))); } return attr_value; } const std::string& attr_name_; }; template inline proto::AttrType AttrTypeID() { Attribute tmp = T(); return static_cast(tmp.index() - 1); } inline proto::AttrType AttrTypeID(const Attribute& attr) { return static_cast(attr.index() - 1); } inline bool IsAttrVar(const Attribute& attr) { return AttrTypeID(attr) == proto::AttrType::VAR; } inline bool IsAttrVars(const Attribute& attr) { return AttrTypeID(attr) == proto::AttrType::VARS; } inline bool HasAttrVar(const Attribute& attr) { return IsAttrVar(attr) || IsAttrVars(attr); } inline AttributeMap FilterAttrVar(const AttributeMap& attrs) { AttributeMap attrs_var; for (auto& attr : attrs) { if (HasAttrVar(attr.second)) { attrs_var.emplace(attr); } } return attrs_var; } class AttrReader { public: explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs), default_attrs_(nullptr) {} AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs) : attrs_(attrs), default_attrs_(&default_attrs) {} template inline const T& Get(const std::string& name) const { auto it = attrs_.find(name); bool found = it != attrs_.end(); if (!found) { if (default_attrs_ != nullptr) { it = default_attrs_->find(name); found = it != default_attrs_->end(); } } PADDLE_ENFORCE_EQ(found, true, platform::errors::NotFound( "Attribute (%s) should be in AttributeMap.", name)); Attribute& attr = const_cast(it->second); ExtractAttribute extract_attr(name); T* attr_value = extract_attr(attr); return *attr_value; } const Attribute* GetAttr(const std::string& name) const { auto it = attrs_.find(name); bool found = it != attrs_.end(); if (!found) { if (default_attrs_ != nullptr) { it = default_attrs_->find(name); found = it != default_attrs_->end(); } } if (found) { return &it->second; } return nullptr; } private: const AttributeMap& attrs_; const AttributeMap* default_attrs_; }; } // namespace framework } // namespace paddle