From 2594935aef879fa4fcd5ef78cde38205d3c7d815 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 17 Aug 2022 16:50:27 +0800 Subject: [PATCH] [OpAttr]Add SupportTensor for OpMaker with whitelist mechanism (#45084) * [OpAttr]Add SupportTensor for OpMaker * fix typo * fix code style * add SupportTensor for concat op * add unittest for register Tensor * add shape checker and split attribute --- paddle/fluid/framework/attribute.h | 246 ------------ paddle/fluid/framework/attribute_checker.h | 353 ++++++++++++++++++ paddle/fluid/framework/framework.proto | 1 + paddle/fluid/framework/op_info.h | 1 + paddle/fluid/framework/op_proto_maker.h | 1 + paddle/fluid/framework/operator.cc | 2 +- paddle/fluid/operators/concat_op.cc | 3 +- paddle/fluid/operators/dropout_op.cc | 3 +- paddle/fluid/operators/tile_op.cc | 3 +- .../tests/unittests/test_attribute_var.py | 33 ++ 10 files changed, 396 insertions(+), 250 deletions(-) create mode 100644 paddle/fluid/framework/attribute_checker.h diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 4d3ba2a1820..b4a939f822b 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -292,251 +292,5 @@ class AttrReader { const AttributeMap* default_attrs_; }; -// check whether a value(attribute) fit a certain limit -template -class GreaterThanChecker { - public: - explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} - void operator()(const T& value) const { - PADDLE_ENFORCE_GT( - value, - lower_bound_, - platform::errors::OutOfRange("Check for attribute value greater than " - "a certain value failed.")); - } - - private: - T lower_bound_; -}; - -template -class EqualGreaterThanChecker { - public: - explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} - void operator()(const T& value) const { - PADDLE_ENFORCE_GE( - value, - lower_bound_, - platform::errors::OutOfRange("Check for attribute valur equal or " - "greater than a certain value failed.")); - } - - private: - T lower_bound_; -}; - -// we can provide users more common Checker, like 'LessThanChecker', -// 'BetweenChecker'... - -template -class DefaultValueSetter { - public: - explicit DefaultValueSetter(T default_value) - : default_value_(default_value) {} - const T& operator()() const { return default_value_; } - - private: - T default_value_; -}; - -template -class EnumInContainer { - public: - explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} - void operator()(const T& val) const { - PADDLE_ENFORCE_NE( - container_.find(val), - container_.end(), - platform::errors::NotFound("Value %s is not in enum container %s.", - val, - ContainerDebugString())); - } - - private: - std::string ContainerDebugString() const { - std::ostringstream sout; - sout << "["; - size_t cnt = 0; - for (auto& v : container_) { - sout << v; - ++cnt; - if (cnt != container_.size()) { - sout << " ,"; - } - } - sout << "]"; - return sout.str(); - } - - std::unordered_set container_; -}; - -// check whether a certain attribute fit its limits -// an attribute can have more than one limits -template -class TypedAttrChecker { - typedef std::function DefaultValueChecker; - typedef std::function ValueChecker; - - public: - explicit TypedAttrChecker(const std::string& attr_name, - proto::OpProto_Attr* attr) - : attr_name_(attr_name), attr_(attr) {} - - TypedAttrChecker& AsExtra() { - attr_->set_extra(true); - return *this; - } - - TypedAttrChecker& AsQuant() { - attr_->set_quant(true); - return *this; - } - - TypedAttrChecker& InEnum(const std::unordered_set& range) { - value_checkers_.push_back(EnumInContainer(range)); - return *this; - } - - TypedAttrChecker& GreaterThan(const T& lower_bound) { - value_checkers_.push_back(GreaterThanChecker(lower_bound)); - return *this; - } - - TypedAttrChecker& EqualGreaterThan(const T& lower_bound) { - value_checkers_.push_back(EqualGreaterThanChecker(lower_bound)); - return *this; - } - - // we can add more common limits, like LessThan(), Between()... - - TypedAttrChecker& SetDefault(const T& default_value) { - PADDLE_ENFORCE_EQ( - default_value_setter_.empty(), - true, - platform::errors::AlreadyExists("Attribute (%s) has a default value " - "and cannot be set repeatedly.", - attr_name_)); - default_value_setter_.push_back(DefaultValueSetter(default_value)); - return *this; - } - - // allow users provide their own checker - TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) { - value_checkers_.push_back(checker); - return *this; - } - - void operator()(AttributeMap* attr_map, - bool get_default_value_only = false, - bool only_check_exist_value = false) const { - if (get_default_value_only) { - if (!default_value_setter_.empty()) { - attr_map->emplace(attr_name_, default_value_setter_[0]()); - } - return; - } - // If attribute is VarDesc(s), we should verify it's dtype and shape. - auto it = attr_map->find(attr_name_); - if (it != attr_map->end() && HasAttrVar(it->second)) { - VLOG(1) << "Found Attribute " << attr_name_ - << " with Variable, skip attr_checker."; - return; - } - - if (only_check_exist_value) { - if (it != attr_map->end()) { - ExtractAttribute extract_attr(attr_name_); - T* attr_value = extract_attr(it->second); - for (const auto& checker : value_checkers_) { - checker(*attr_value); - } - } - } else { - if (it == attr_map->end()) { - // user do not set this attr - PADDLE_ENFORCE_EQ( - default_value_setter_.empty(), - false, - platform::errors::InvalidArgument( - "Attribute (%s) is not set correctly.", attr_name_)); - // default_value_setter_ has no more than one element - auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]()); - it = tmp.first; - } - ExtractAttribute extract_attr(attr_name_); - T* attr_value = extract_attr(it->second); - for (const auto& checker : value_checkers_) { - checker(*attr_value); - } - } - } - - private: - std::string attr_name_; - proto::OpProto_Attr* attr_; - std::vector value_checkers_; - std::vector default_value_setter_; -}; - -// check whether op's all attributes fit their own limits -class OpAttrChecker { - typedef std::function AttrChecker; - - public: - template - TypedAttrChecker& AddAttrChecker(const std::string& attr_name, - proto::OpProto_Attr* attr) { - attr_checkers_.push_back(TypedAttrChecker(attr_name, attr)); - AttrChecker& checker = attr_checkers_.back(); - return *(checker.target>()); - } - - void Check(AttributeMap* attr_map, - bool explicit_only = false, - bool only_check_exist_value = false) const { - auto checker_num = attr_checkers_.size(); - if (explicit_only) checker_num = explicit_checker_num_; - for (size_t i = 0; i < checker_num; ++i) { - attr_checkers_[i](attr_map, false, only_check_exist_value); - } - } - - AttributeMap GetDefaultAttrsMap() const { - AttributeMap default_values_map; - for (const auto& checker : attr_checkers_) { - checker(&default_values_map, true, false); - } - return default_values_map; - } - - void RecordExplicitCheckerNum() { - explicit_checker_num_ = attr_checkers_.size(); - } - - void InitDefaultAttributeMap() { - for (const auto& checker : attr_checkers_) { - checker(&default_attrs_, true, false); - } - } - - const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; } - - private: - std::vector attr_checkers_; - - AttributeMap default_attrs_; - - // in order to improve the efficiency of dynamic graph mode, - // we divede the attribute into explicit type and implicit type. - // for explicit attribute, we mean the attribute added in the customized - // op makers, usually it's defined in the overloaded Make method. - // for implicit attribute, we mean the attribute added outside of the Make - // method like "op_role", "op_role_var", and they are useless in dynamic - // graph - // mode - size_t explicit_checker_num_; -}; - } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/attribute_checker.h b/paddle/fluid/framework/attribute_checker.h new file mode 100644 index 00000000000..f3650dc085d --- /dev/null +++ b/paddle/fluid/framework/attribute_checker.h @@ -0,0 +1,353 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/var_desc.h" + +namespace paddle { +namespace framework { +// check whether a value(attribute) fit a certain limit +template +class GreaterThanChecker { + public: + explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + void operator()(const T& value) const { + PADDLE_ENFORCE_GT( + value, + lower_bound_, + platform::errors::OutOfRange("Check for attribute value greater than " + "a certain value failed.")); + } + + private: + T lower_bound_; +}; + +template +class EqualGreaterThanChecker { + public: + explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + void operator()(const T& value) const { + PADDLE_ENFORCE_GE( + value, + lower_bound_, + platform::errors::OutOfRange("Check for attribute valur equal or " + "greater than a certain value failed.")); + } + + private: + T lower_bound_; +}; + +template +class TypedAttrVarInfoChecker { + public: + TypedAttrVarInfoChecker() = default; + + void operator()(const Attribute& attr) const { + if (IsAttrVar(attr)) { + auto* var_desc = PADDLE_GET_CONST(VarDesc*, attr); + check(var_desc); + } else if (IsAttrVars(attr)) { + auto var_descs = PADDLE_GET_CONST(std::vector, attr); + check(var_descs); + } + } + + void check(const VarDesc* var_desc) const { + PADDLE_ENFORCE_NOT_NULL( + var_desc, + platform::errors::InvalidArgument( + "Required Attribute with Variable type shall not be nullptr.")); + auto shape = var_desc->GetShape(); + PADDLE_ENFORCE_EQ(shape.size(), + 1U, + platform::errors::InvalidArgument( + "Required shape rank of Attribute(%s) == 1, " + "but received rank == %s", + var_desc->Name(), + shape.size())); + + auto& expected_type = typeid(T); + auto dtype = var_desc->GetDataType(); + // attribute is a IntArray + if (expected_type == typeid(std::vector) || + expected_type == typeid(std::vector)) { + bool is_int = (dtype == proto::VarType::Type::VarType_Type_INT32 || + dtype == proto::VarType::Type::VarType_Type_INT64); + PADDLE_ENFORCE_EQ(is_int, + true, + platform::errors::InvalidArgument( + "Required dtype of Attribute(%s) shall be " + "int32|int64, but recevied %s.", + var_desc->Name(), + dtype)); + } + } + + void check(const std::vector& var_descs) const { + for (auto& var_desc : var_descs) { + PADDLE_ENFORCE_NOT_NULL( + var_desc, + platform::errors::InvalidArgument( + "Required Attribute with Variable type shall not be nullptr.")); + auto shape = var_desc->GetShape(); + PADDLE_ENFORCE_EQ(shape.size(), + 1U, + platform::errors::InvalidArgument( + "Required shape rank of Attribute(%s) == 1, " + "but received rank == %s", + var_desc->Name(), + shape.size())); + PADDLE_ENFORCE_EQ(shape[0] == 1U || shape[0] == -1, + true, + platform::errors::InvalidArgument( + "Required shape[0] of Attribute(%s) == 1 or -1, " + "but received shape[0] == %s", + var_desc->Name(), + shape[0])); + } + } +}; + +// we can provide users more common Checker, like 'LessThanChecker', +// 'BetweenChecker'... + +template +class DefaultValueSetter { + public: + explicit DefaultValueSetter(T default_value) + : default_value_(default_value) {} + const T& operator()() const { return default_value_; } + + private: + T default_value_; +}; + +template +class EnumInContainer { + public: + explicit EnumInContainer(const std::unordered_set& c) : container_(c) {} + void operator()(const T& val) const { + PADDLE_ENFORCE_NE( + container_.find(val), + container_.end(), + platform::errors::NotFound("Value %s is not in enum container %s.", + val, + ContainerDebugString())); + } + + private: + std::string ContainerDebugString() const { + std::ostringstream sout; + sout << "["; + size_t cnt = 0; + for (auto& v : container_) { + sout << v; + ++cnt; + if (cnt != container_.size()) { + sout << " ,"; + } + } + sout << "]"; + return sout.str(); + } + + std::unordered_set container_; +}; + +// check whether a certain attribute fit its limits +// an attribute can have more than one limits +template +class TypedAttrChecker { + typedef std::function DefaultValueChecker; + typedef std::function ValueChecker; + + public: + explicit TypedAttrChecker(const std::string& attr_name, + proto::OpProto_Attr* attr) + : attr_name_(attr_name), attr_(attr) {} + + TypedAttrChecker& AsExtra() { + attr_->set_extra(true); + return *this; + } + + TypedAttrChecker& AsQuant() { + attr_->set_quant(true); + return *this; + } + + TypedAttrChecker& SupportTensor() { + attr_->set_support_tensor(true); + return *this; + } + + TypedAttrChecker& InEnum(const std::unordered_set& range) { + value_checkers_.push_back(EnumInContainer(range)); + return *this; + } + + TypedAttrChecker& GreaterThan(const T& lower_bound) { + value_checkers_.push_back(GreaterThanChecker(lower_bound)); + return *this; + } + + TypedAttrChecker& EqualGreaterThan(const T& lower_bound) { + value_checkers_.push_back(EqualGreaterThanChecker(lower_bound)); + return *this; + } + + // we can add more common limits, like LessThan(), Between()... + + TypedAttrChecker& SetDefault(const T& default_value) { + PADDLE_ENFORCE_EQ( + default_value_setter_.empty(), + true, + platform::errors::AlreadyExists("Attribute (%s) has a default value " + "and cannot be set repeatedly.", + attr_name_)); + default_value_setter_.push_back(DefaultValueSetter(default_value)); + return *this; + } + + // allow users provide their own checker + TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) { + value_checkers_.push_back(checker); + return *this; + } + + void operator()(AttributeMap* attr_map, + bool get_default_value_only = false, + bool only_check_exist_value = false) const { + if (get_default_value_only) { + if (!default_value_setter_.empty()) { + attr_map->emplace(attr_name_, default_value_setter_[0]()); + } + return; + } + // If attribute is VarDesc(s), we should verify it's supported in OpMaker + auto it = attr_map->find(attr_name_); + if (it != attr_map->end() && HasAttrVar(it->second)) { + PADDLE_ENFORCE_EQ(attr_->support_tensor(), + true, + platform::errors::InvalidArgument( + "Found Attribute('%s') with type(Variable), but it " + "doesn't support Tensor type.", + attr_name_)); + + VLOG(1) << "Found Attribute " << attr_name_ << " with type(Variable)."; + var_info_checker_(it->second); + return; + } + + if (only_check_exist_value) { + if (it != attr_map->end()) { + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(it->second); + for (const auto& checker : value_checkers_) { + checker(*attr_value); + } + } + } else { + if (it == attr_map->end()) { + // user do not set this attr + PADDLE_ENFORCE_EQ( + default_value_setter_.empty(), + false, + platform::errors::InvalidArgument( + "Attribute (%s) is not set correctly.", attr_name_)); + // default_value_setter_ has no more than one element + auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]()); + it = tmp.first; + } + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(it->second); + for (const auto& checker : value_checkers_) { + checker(*attr_value); + } + } + } + + private: + std::string attr_name_; + proto::OpProto_Attr* attr_; + TypedAttrVarInfoChecker var_info_checker_; + std::vector value_checkers_; + std::vector default_value_setter_; +}; + +// check whether op's all attributes fit their own limits +class OpAttrChecker { + typedef std::function AttrChecker; + + public: + template + TypedAttrChecker& AddAttrChecker(const std::string& attr_name, + proto::OpProto_Attr* attr) { + attr_checkers_.push_back(TypedAttrChecker(attr_name, attr)); + AttrChecker& checker = attr_checkers_.back(); + return *(checker.target>()); + } + + void Check(AttributeMap* attr_map, + bool explicit_only = false, + bool only_check_exist_value = false) const { + auto checker_num = attr_checkers_.size(); + if (explicit_only) checker_num = explicit_checker_num_; + for (size_t i = 0; i < checker_num; ++i) { + attr_checkers_[i](attr_map, false, only_check_exist_value); + } + } + + AttributeMap GetDefaultAttrsMap() const { + AttributeMap default_values_map; + for (const auto& checker : attr_checkers_) { + checker(&default_values_map, true, false); + } + return default_values_map; + } + + void RecordExplicitCheckerNum() { + explicit_checker_num_ = attr_checkers_.size(); + } + + void InitDefaultAttributeMap() { + for (const auto& checker : attr_checkers_) { + checker(&default_attrs_, true, false); + } + } + + const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; } + + private: + std::vector attr_checkers_; + + AttributeMap default_attrs_; + + // in order to improve the efficiency of dynamic graph mode, + // we divede the attribute into explicit type and implicit type. + // for explicit attribute, we mean the attribute added in the customized + // op makers, usually it's defined in the overloaded Make method. + // for implicit attribute, we mean the attribute added outside of the Make + // method like "op_role", "op_role_var", and they are useless in dynamic + // graph + // mode + size_t explicit_checker_num_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index b58b643cdff..61a495a59a9 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -102,6 +102,7 @@ message OpProto { optional bool generated = 4 [ default = false ]; optional bool extra = 5 [ default = false ]; optional bool quant = 6 [ default = false ]; + optional bool support_tensor = 7 [ default = false]; } required string type = 1; diff --git a/paddle/fluid/framework/op_info.h b/paddle/fluid/framework/op_info.h index c2f64cfaea2..5a40c4acc7e 100644 --- a/paddle/fluid/framework/op_info.h +++ b/paddle/fluid/framework/op_info.h @@ -20,6 +20,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/attribute_checker.h" #include "paddle/fluid/framework/no_need_buffer_vars_inference.h" #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/framework/op_proto_maker.h b/paddle/fluid/framework/op_proto_maker.h index 9cea78c92c6..3440f049ef7 100644 --- a/paddle/fluid/framework/op_proto_maker.h +++ b/paddle/fluid/framework/op_proto_maker.h @@ -17,6 +17,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/framework/attribute.h" +#include "paddle/fluid/framework/attribute_checker.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 4fb7b0e018d..35865252629 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -447,7 +447,7 @@ OperatorBase::OperatorBase(const std::string& type, GenerateTemporaryNames(); CheckAllInputOutputSet(); } - // In OperatorBase level, all attribute with VarDesc type will be considered + // In OperatorBase level, all attributes with VarDesc type will be considered // as Input. for (auto& attr : FilterAttrVar(attrs)) { VLOG(3) << "found Attribute with Variable type: " << attr.first; diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index d4cbed66638..75dbb9b0379 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -91,7 +91,8 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker { "The axis could also be negative numbers. Negative axis is " "interpreted as counting from the end of the rank." "i.e., axis + rank(X) th dimension.") - .SetDefault(0); + .SetDefault(0) + .SupportTensor(); AddInput("AxisTensor", "(Tensor) The axis along which the input tensors will be " "concatenated. " diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 879abd1f631..84784c3c603 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -71,7 +71,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { true, platform::errors::InvalidArgument( "'dropout_prob' must be between 0.0 and 1.0.")); - }); + }) + .SupportTensor(); AddAttr("is_test", "(bool, default false) Set to true for inference only, false " "for training. Some layers may run faster when this is true.") diff --git a/paddle/fluid/operators/tile_op.cc b/paddle/fluid/operators/tile_op.cc index 1d5b57a8a3d..8cf13291540 100644 --- a/paddle/fluid/operators/tile_op.cc +++ b/paddle/fluid/operators/tile_op.cc @@ -74,7 +74,8 @@ class TileOpMaker : public framework::OpProtoAndCheckerMaker { "the corresponding value given by Attr(repeat_times)."); AddAttr>("repeat_times", "The number of repeat times for each dimension.") - .SetDefault({}); + .SetDefault({}) + .SupportTensor(); AddComment(R"DOC( Tile operator repeats the input by given times number. You should set times number for each dimension by providing attribute 'repeat_times'. The rank of X diff --git a/python/paddle/fluid/tests/unittests/test_attribute_var.py b/python/paddle/fluid/tests/unittests/test_attribute_var.py index 950d33a9bbc..5d0316edfa4 100644 --- a/python/paddle/fluid/tests/unittests/test_attribute_var.py +++ b/python/paddle/fluid/tests/unittests/test_attribute_var.py @@ -18,6 +18,7 @@ import tempfile import paddle import paddle.inference as paddle_infer from paddle.fluid.framework import program_guard, Program +from paddle.fluid.framework import OpProtoHolder import numpy as np paddle.enable_static() @@ -154,5 +155,37 @@ class TestTileTensor(UnittestBase): self.assertEqual(infer_out.shape, (6, 6, 10)) +class TestRegiterSupportTensorInOpMaker(unittest.TestCase): + + def setUp(self): + self.all_protos = OpProtoHolder.instance() + self.support_tensor_attrs = { + 'dropout': ['dropout_prob'], + 'tile': ['repeat_times'], + 'concat': ['axis'] + } + # Just add a op example to test not support tensor + self.not_support_tensor_attrs = {'svd': ['full_matrices']} + + def test_support_tensor(self): + # All Attribute tagged with .SupportTensor() in OpMaker will return True + for op_type, attr_names in self.support_tensor_attrs.items(): + for attr_name in attr_names: + self.assertTrue(self.is_support_tensor_attr(op_type, attr_name)) + + # All Attribute not tagged with .SupportTensor() in OpMaker will return False + for op_type, attr_names in self.not_support_tensor_attrs.items(): + for attr_name in attr_names: + self.assertFalse(self.is_support_tensor_attr( + op_type, attr_name)) + + def is_support_tensor_attr(self, op_type, attr_name): + proto = self.all_protos.get_op_proto(op_type) + for attr in proto.attrs: + if attr.name == attr_name: + return attr.support_tensor + raise RuntimeError("Not found attribute : ", attr_name) + + if __name__ == '__main__': unittest.main() -- GitLab