Unverified commit 2594935a authored by Aurelius84, committed by GitHub

[OpAttr]Add SupportTensor for OpMaker with whitelist mechanism (#45084)

* [OpAttr]Add SupportTensor for OpMaker

* fix typo

* fix code style

* add SupportTensor for concat op

* add unittest for register Tensor

* add shape checker and split attribute
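
For context, the whitelist is opt-in per attribute: only an attribute tagged with .SupportTensor() in its OpMaker may be passed as a Variable (Tensor); any other Variable-typed attribute is now rejected by TypedAttrChecker with an InvalidArgument error. Below is a minimal sketch of how an op maker opts in, modeled on the concat/dropout/tile changes in this commit; MyOpMaker and its attribute names are hypothetical and only for illustration.

// Hedged sketch: "MyOpMaker" and its attribute names are hypothetical;
// see the concat, dropout and tile diffs below for the real usages.
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The input tensor.");
    AddOutput("Out", "(Tensor) The output tensor.");
    // SupportTensor() sets the new support_tensor flag in OpProto, which
    // whitelists this attribute so it may also be fed as a 1-D Tensor
    // (Variable) at runtime instead of a plain value.
    AddAttr<int>("axis", "(int, default 0) The axis to operate on.")
        .SetDefault(0)
        .SupportTensor();
  }
};

}  // namespace operators
}  // namespace paddle

With the tag set, TypedAttrChecker skips the plain value checkers for the Variable case and only runs TypedAttrVarInfoChecker, which validates the VarDesc's rank (must be 1) and, for IntArray-style attributes, its int32/int64 dtype.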
Parent 2105d146
@@ -292,251 +292,5 @@ class AttrReader {
  const AttributeMap* default_attrs_;
};
// check whether a value (attribute) fits a certain limit
template <typename T>
class GreaterThanChecker {
public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(
value,
lower_bound_,
platform::errors::OutOfRange("Check for attribute value greater than "
"a certain value failed."));
}
private:
T lower_bound_;
};
template <typename T>
class EqualGreaterThanChecker {
public:
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const {
PADDLE_ENFORCE_GE(
value,
lower_bound_,
platform::errors::OutOfRange("Check for attribute valur equal or "
"greater than a certain value failed."));
}
private:
T lower_bound_;
};
// we can provide users with more common Checkers, like 'LessThanChecker',
// 'BetweenChecker'...
template <typename T>
class DefaultValueSetter {
public:
explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {}
const T& operator()() const { return default_value_; }
private:
T default_value_;
};
template <typename T>
class EnumInContainer {
public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(const T& val) const {
PADDLE_ENFORCE_NE(
container_.find(val),
container_.end(),
platform::errors::NotFound("Value %s is not in enum container %s.",
val,
ContainerDebugString()));
}
private:
std::string ContainerDebugString() const {
std::ostringstream sout;
sout << "[";
size_t cnt = 0;
for (auto& v : container_) {
sout << v;
++cnt;
if (cnt != container_.size()) {
sout << " ,";
}
}
sout << "]";
return sout.str();
}
std::unordered_set<T> container_;
};
// check whether a certain attribute fits its limits
// an attribute can have more than one limit
template <typename T>
class TypedAttrChecker {
typedef std::function<const T&()> DefaultValueChecker;
typedef std::function<void(const T&)> ValueChecker;
public:
explicit TypedAttrChecker(const std::string& attr_name,
proto::OpProto_Attr* attr)
: attr_name_(attr_name), attr_(attr) {}
TypedAttrChecker& AsExtra() {
attr_->set_extra(true);
return *this;
}
TypedAttrChecker& AsQuant() {
attr_->set_quant(true);
return *this;
}
TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
value_checkers_.push_back(EnumInContainer<T>(range));
return *this;
}
TypedAttrChecker& GreaterThan(const T& lower_bound) {
value_checkers_.push_back(GreaterThanChecker<T>(lower_bound));
return *this;
}
TypedAttrChecker& EqualGreaterThan(const T& lower_bound) {
value_checkers_.push_back(EqualGreaterThanChecker<T>(lower_bound));
return *this;
}
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(),
true,
platform::errors::AlreadyExists("Attribute (%s) has a default value "
"and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
// allow users to provide their own checker
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
value_checkers_.push_back(checker);
return *this;
}
void operator()(AttributeMap* attr_map,
bool get_default_value_only = false,
bool only_check_exist_value = false) const {
if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
return;
}
// If the attribute is a VarDesc (or VarDescs), we should verify its dtype and shape.
auto it = attr_map->find(attr_name_);
if (it != attr_map->end() && HasAttrVar(it->second)) {
VLOG(1) << "Found Attribute " << attr_name_
<< " with Variable, skip attr_checker.";
return;
}
if (only_check_exist_value) {
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
if (it == attr_map->end()) {
// user did not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(),
false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}
private:
std::string attr_name_;
proto::OpProto_Attr* attr_;
std::vector<ValueChecker> value_checkers_;
std::vector<DefaultValueChecker> default_value_setter_;
};
// check whether all of an op's attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;
public:
template <typename T>
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name,
proto::OpProto_Attr* attr) {
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name, attr));
AttrChecker& checker = attr_checkers_.back();
return *(checker.target<TypedAttrChecker<T>>());
}
void Check(AttributeMap* attr_map,
bool explicit_only = false,
bool only_check_exist_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false, only_check_exist_value);
}
}
AttributeMap GetDefaultAttrsMap() const {
AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true, false);
}
return default_values_map;
}
void RecordExplicitCheckerNum() {
explicit_checker_num_ = attr_checkers_.size();
}
void InitDefaultAttributeMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
}
const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }
private:
std::vector<AttrChecker> attr_checkers_;
AttributeMap default_attrs_;
  // In order to improve the efficiency of dynamic graph mode,
  // we divide the attributes into explicit and implicit types.
  // By an explicit attribute, we mean one added in the customized
  // op maker, usually defined in the overloaded Make method.
  // By an implicit attribute, we mean one added outside of the Make
  // method, like "op_role" and "op_role_var"; such attributes are
  // useless in dynamic graph mode.
size_t explicit_checker_num_;
};
}  // namespace framework
}  // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/var_desc.h"
namespace paddle {
namespace framework {
// check whether a value (attribute) fits a certain limit
template <typename T>
class GreaterThanChecker {
public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(
value,
lower_bound_,
platform::errors::OutOfRange("Check for attribute value greater than "
"a certain value failed."));
}
private:
T lower_bound_;
};
template <typename T>
class EqualGreaterThanChecker {
public:
explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(const T& value) const {
PADDLE_ENFORCE_GE(
value,
lower_bound_,
platform::errors::OutOfRange("Check for attribute valur equal or "
"greater than a certain value failed."));
}
private:
T lower_bound_;
};
template <typename T>
class TypedAttrVarInfoChecker {
public:
TypedAttrVarInfoChecker() = default;
void operator()(const Attribute& attr) const {
if (IsAttrVar(attr)) {
auto* var_desc = PADDLE_GET_CONST(VarDesc*, attr);
check(var_desc);
} else if (IsAttrVars(attr)) {
auto var_descs = PADDLE_GET_CONST(std::vector<VarDesc*>, attr);
check(var_descs);
}
}
void check(const VarDesc* var_desc) const {
PADDLE_ENFORCE_NOT_NULL(
var_desc,
platform::errors::InvalidArgument(
"Required Attribute with Variable type shall not be nullptr."));
auto shape = var_desc->GetShape();
PADDLE_ENFORCE_EQ(shape.size(),
1U,
platform::errors::InvalidArgument(
"Required shape rank of Attribute(%s) == 1, "
"but received rank == %s",
var_desc->Name(),
shape.size()));
auto& expected_type = typeid(T);
auto dtype = var_desc->GetDataType();
    // attribute is an IntArray
if (expected_type == typeid(std::vector<int64_t>) ||
expected_type == typeid(std::vector<int>)) {
bool is_int = (dtype == proto::VarType::Type::VarType_Type_INT32 ||
dtype == proto::VarType::Type::VarType_Type_INT64);
PADDLE_ENFORCE_EQ(is_int,
true,
platform::errors::InvalidArgument(
"Required dtype of Attribute(%s) shall be "
"int32|int64, but recevied %s.",
var_desc->Name(),
dtype));
}
}
void check(const std::vector<VarDesc*>& var_descs) const {
for (auto& var_desc : var_descs) {
PADDLE_ENFORCE_NOT_NULL(
var_desc,
platform::errors::InvalidArgument(
"Required Attribute with Variable type shall not be nullptr."));
auto shape = var_desc->GetShape();
PADDLE_ENFORCE_EQ(shape.size(),
1U,
platform::errors::InvalidArgument(
"Required shape rank of Attribute(%s) == 1, "
"but received rank == %s",
var_desc->Name(),
shape.size()));
PADDLE_ENFORCE_EQ(shape[0] == 1U || shape[0] == -1,
true,
platform::errors::InvalidArgument(
"Required shape[0] of Attribute(%s) == 1 or -1, "
"but received shape[0] == %s",
var_desc->Name(),
shape[0]));
}
}
};
// we can provide users with more common Checkers, like 'LessThanChecker',
// 'BetweenChecker'...
template <typename T>
class DefaultValueSetter {
public:
explicit DefaultValueSetter(T default_value)
: default_value_(default_value) {}
const T& operator()() const { return default_value_; }
private:
T default_value_;
};
template <typename T>
class EnumInContainer {
public:
explicit EnumInContainer(const std::unordered_set<T>& c) : container_(c) {}
void operator()(const T& val) const {
PADDLE_ENFORCE_NE(
container_.find(val),
container_.end(),
platform::errors::NotFound("Value %s is not in enum container %s.",
val,
ContainerDebugString()));
}
private:
std::string ContainerDebugString() const {
std::ostringstream sout;
sout << "[";
size_t cnt = 0;
for (auto& v : container_) {
sout << v;
++cnt;
if (cnt != container_.size()) {
sout << " ,";
}
}
sout << "]";
return sout.str();
}
std::unordered_set<T> container_;
};
// check whether a certain attribute fits its limits
// an attribute can have more than one limit
template <typename T>
class TypedAttrChecker {
typedef std::function<const T&()> DefaultValueChecker;
typedef std::function<void(const T&)> ValueChecker;
public:
explicit TypedAttrChecker(const std::string& attr_name,
proto::OpProto_Attr* attr)
: attr_name_(attr_name), attr_(attr) {}
TypedAttrChecker& AsExtra() {
attr_->set_extra(true);
return *this;
}
TypedAttrChecker& AsQuant() {
attr_->set_quant(true);
return *this;
}
TypedAttrChecker& SupportTensor() {
attr_->set_support_tensor(true);
return *this;
}
TypedAttrChecker& InEnum(const std::unordered_set<T>& range) {
value_checkers_.push_back(EnumInContainer<T>(range));
return *this;
}
TypedAttrChecker& GreaterThan(const T& lower_bound) {
value_checkers_.push_back(GreaterThanChecker<T>(lower_bound));
return *this;
}
TypedAttrChecker& EqualGreaterThan(const T& lower_bound) {
value_checkers_.push_back(EqualGreaterThanChecker<T>(lower_bound));
return *this;
}
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(),
true,
platform::errors::AlreadyExists("Attribute (%s) has a default value "
"and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
// allow users to provide their own checker
TypedAttrChecker& AddCustomChecker(const ValueChecker& checker) {
value_checkers_.push_back(checker);
return *this;
}
void operator()(AttributeMap* attr_map,
bool get_default_value_only = false,
bool only_check_exist_value = false) const {
if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
return;
}
    // If the attribute is a VarDesc (or VarDescs), we should verify that it is supported in the OpMaker.
auto it = attr_map->find(attr_name_);
if (it != attr_map->end() && HasAttrVar(it->second)) {
PADDLE_ENFORCE_EQ(attr_->support_tensor(),
true,
platform::errors::InvalidArgument(
"Found Attribute('%s') with type(Variable), but it "
"doesn't support Tensor type.",
attr_name_));
VLOG(1) << "Found Attribute " << attr_name_ << " with type(Variable).";
var_info_checker_(it->second);
return;
}
if (only_check_exist_value) {
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
if (it == attr_map->end()) {
// user did not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(),
false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}
private:
std::string attr_name_;
proto::OpProto_Attr* attr_;
TypedAttrVarInfoChecker<T> var_info_checker_;
std::vector<ValueChecker> value_checkers_;
std::vector<DefaultValueChecker> default_value_setter_;
};
// check whether all of an op's attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;
public:
template <typename T>
TypedAttrChecker<T>& AddAttrChecker(const std::string& attr_name,
proto::OpProto_Attr* attr) {
attr_checkers_.push_back(TypedAttrChecker<T>(attr_name, attr));
AttrChecker& checker = attr_checkers_.back();
return *(checker.target<TypedAttrChecker<T>>());
}
void Check(AttributeMap* attr_map,
bool explicit_only = false,
bool only_check_exist_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false, only_check_exist_value);
}
}
AttributeMap GetDefaultAttrsMap() const {
AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true, false);
}
return default_values_map;
}
void RecordExplicitCheckerNum() {
explicit_checker_num_ = attr_checkers_.size();
}
void InitDefaultAttributeMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
}
const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }
private:
std::vector<AttrChecker> attr_checkers_;
AttributeMap default_attrs_;
  // In order to improve the efficiency of dynamic graph mode,
  // we divide the attributes into explicit and implicit types.
  // By an explicit attribute, we mean one added in the customized
  // op maker, usually defined in the overloaded Make method.
  // By an implicit attribute, we mean one added outside of the Make
  // method, like "op_role" and "op_role_var"; such attributes are
  // useless in dynamic graph mode.
size_t explicit_checker_num_;
};
} // namespace framework
} // namespace paddle
@@ -102,6 +102,7 @@ message OpProto {
    optional bool generated = 4 [ default = false ];
    optional bool extra = 5 [ default = false ];
    optional bool quant = 6 [ default = false ];
    optional bool support_tensor = 7 [ default = false ];
  }
  required string type = 1;
......
@@ -20,6 +20,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/attribute_checker.h"
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/platform/enforce.h"
......
@@ -17,6 +17,7 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/attribute_checker.h"
namespace paddle {
namespace framework {
......
@@ -447,7 +447,7 @@ OperatorBase::OperatorBase(const std::string& type,
    GenerateTemporaryNames();
    CheckAllInputOutputSet();
  }
  // In OperatorBase level, all attributes with VarDesc type will be considered
  // as Input.
  for (auto& attr : FilterAttrVar(attrs)) {
    VLOG(3) << "found Attribute with Variable type: " << attr.first;
......
@@ -91,7 +91,8 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
             "The axis could also be negative numbers. Negative axis is "
             "interpreted as counting from the end of the rank."
             "i.e., axis + rank(X) th dimension.")
        .SetDefault(0)
        .SupportTensor();
    AddInput("AxisTensor",
             "(Tensor) The axis along which the input tensors will be "
             "concatenated. "
......
@@ -71,7 +71,8 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
              true,
              platform::errors::InvalidArgument(
                  "'dropout_prob' must be between 0.0 and 1.0."));
        })
        .SupportTensor();
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
......
@@ -74,7 +74,8 @@ class TileOpMaker : public framework::OpProtoAndCheckerMaker {
             "the corresponding value given by Attr(repeat_times).");
    AddAttr<std::vector<int>>("repeat_times",
                              "The number of repeat times for each dimension.")
        .SetDefault({})
        .SupportTensor();
    AddComment(R"DOC(
Tile operator repeats the input by given times number. You should set times
number for each dimension by providing attribute 'repeat_times'. The rank of X
......
@@ -18,6 +18,7 @@ import tempfile
import paddle
import paddle.inference as paddle_infer
from paddle.fluid.framework import program_guard, Program
from paddle.fluid.framework import OpProtoHolder
import numpy as np
paddle.enable_static()
@@ -154,5 +155,37 @@ class TestTileTensor(UnittestBase):
        self.assertEqual(infer_out.shape, (6, 6, 10))
class TestRegisterSupportTensorInOpMaker(unittest.TestCase):
def setUp(self):
self.all_protos = OpProtoHolder.instance()
self.support_tensor_attrs = {
'dropout': ['dropout_prob'],
'tile': ['repeat_times'],
'concat': ['axis']
}
        # Just add an op example to test an attribute that does not support Tensor
self.not_support_tensor_attrs = {'svd': ['full_matrices']}
def test_support_tensor(self):
        # All attributes tagged with .SupportTensor() in OpMaker should return True
for op_type, attr_names in self.support_tensor_attrs.items():
for attr_name in attr_names:
self.assertTrue(self.is_support_tensor_attr(op_type, attr_name))
        # All attributes not tagged with .SupportTensor() in OpMaker should return False
for op_type, attr_names in self.not_support_tensor_attrs.items():
for attr_name in attr_names:
self.assertFalse(self.is_support_tensor_attr(
op_type, attr_name))
def is_support_tensor_attr(self, op_type, attr_name):
proto = self.all_protos.get_op_proto(op_type)
for attr in proto.attrs:
if attr.name == attr_name:
return attr.support_tensor
raise RuntimeError("Not found attribute : ", attr_name)
if __name__ == '__main__':
    unittest.main()