From 195d6d0f142207d5b85188b9add57b3a2bfe17f8 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 20 Apr 2023 13:56:44 +0800 Subject: [PATCH] [CustomOP error] Add attrs type check (#53030) * [CustomOP error] Add attrs type check * fix global variable order bug * include unordered_set * fix ParseAttrStr compile error --- paddle/fluid/framework/custom_operator.cc | 14 ++--- .../fluid/framework/custom_operator_utils.h | 19 ------- paddle/fluid/pybind/eager_functions.cc | 3 +- paddle/phi/api/ext/op_meta_info.h | 2 + paddle/phi/api/lib/op_meta_info.cc | 57 +++++++++++++++++++ 5 files changed, 67 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 8435e825531..18ae5627633 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -149,7 +149,7 @@ static void RunKernelFunc( } for (auto& attr_str : attrs) { - auto attr_name_and_type = detail::ParseAttrStr(attr_str); + auto attr_name_and_type = paddle::ParseAttrStr(attr_str); auto attr_name = attr_name_and_type[0]; auto attr_type_str = attr_name_and_type[1]; if (attr_type_str == "bool") { @@ -464,7 +464,7 @@ static void RunInferShapeFunc( std::vector custom_attrs; for (auto& attr_str : attrs) { - auto attr_name_and_type = detail::ParseAttrStr(attr_str); + auto attr_name_and_type = paddle::ParseAttrStr(attr_str); auto attr_name = attr_name_and_type[0]; auto attr_type_str = attr_name_and_type[1]; if (attr_type_str == "bool") { @@ -491,13 +491,13 @@ static void RunInferShapeFunc( custom_attrs.emplace_back( ctx->Attrs().Get>(attr_name)); } else { - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "Unsupported `%s` type value as custom attribute now. " "Supported data types include `bool`, `int`, `float`, " "`int64_t`, `std::string`, `std::vector`, " - "`std::vector`, `std::vector`, " - "Please check whether the attribute data type and " - "data type string are matched.", + "`std::vector`, `std::vector`, " + "`std::vector`, Please check whether the attribute data " + "type and data type string are matched.", attr_type_str)); } } @@ -872,7 +872,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker { } } for (auto& attr : attrs_) { - auto attr_name_and_type = detail::ParseAttrStr(attr); + auto attr_name_and_type = paddle::ParseAttrStr(attr); auto attr_name = attr_name_and_type[0]; auto attr_type_str = attr_name_and_type[1]; if (attr_type_str == "bool") { diff --git a/paddle/fluid/framework/custom_operator_utils.h b/paddle/fluid/framework/custom_operator_utils.h index 678e0f5db31..ec00e8b9d0d 100644 --- a/paddle/fluid/framework/custom_operator_utils.h +++ b/paddle/fluid/framework/custom_operator_utils.h @@ -81,25 +81,6 @@ inline static bool IsMemberOf(const std::vector& vec, return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); } -static std::vector ParseAttrStr(const std::string& attr) { - auto split_pos = attr.find_first_of(":"); - PADDLE_ENFORCE_NE(split_pos, - std::string::npos, - platform::errors::InvalidArgument( - "Invalid attribute string format. Attribute string " - "format is `:`.")); - - std::vector rlt; - // 1. name - rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos))); - // 2. type - rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1))); - - VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1]; - - return rlt; -} - } // namespace detail } // namespace framework } // namespace paddle diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 848fa1fe742..d146cfdeb7f 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -567,8 +567,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, int attr_start_idx = 1 + inputs.size(); for (size_t i = 0; i < attrs.size(); ++i) { const auto& attr = attrs.at(i); - std::vector attr_name_and_type = - paddle::framework::detail::ParseAttrStr(attr); + std::vector attr_name_and_type = paddle::ParseAttrStr(attr); auto attr_type_str = attr_name_and_type[1]; VLOG(7) << "Custom operator add attrs " << attr_name_and_type[0] << " to CustomOpKernelContext. Attribute type = " diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index 07a47ed1df6..15aca9c4677 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -97,6 +97,8 @@ inline std::string Optional(const std::string& t_name) { return result; } +std::vector ParseAttrStr(const std::string& attr); + PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst); ////////////////////// Kernel Context //////////////////////// diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 0af2c96521c..27c18fd604f 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include #include #include "glog/logging.h" @@ -24,6 +25,38 @@ limitations under the License. */ namespace paddle { +// remove leading and tailing spaces +std::string trim_spaces(const std::string& str) { + const char* p = str.c_str(); + while (*p != 0 && isspace(*p)) { + p++; + } + size_t len = strlen(p); + while (len > 0 && isspace(p[len - 1])) { + len--; + } + return std::string(p, len); +} + +std::vector ParseAttrStr(const std::string& attr) { + auto split_pos = attr.find_first_of(":"); + PADDLE_ENFORCE_NE(split_pos, + std::string::npos, + phi::errors::InvalidArgument( + "Invalid attribute string format. Attribute string " + "format is `:`.")); + + std::vector rlt; + // 1. name + rlt.emplace_back(trim_spaces(attr.substr(0, split_pos))); + // 2. type + rlt.emplace_back(trim_spaces(attr.substr(split_pos + 1))); + + VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1]; + + return rlt; +} + PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { if (!src.initialized() || !dst->defined()) { VLOG(3) << "Custom operator assigns non-initialized tensor, this only " @@ -346,6 +379,30 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs( } OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector&& attrs) { + const std::unordered_set custom_attrs_type( + {"bool", + "int", + "float", + "int64_t", + "std::string", + "std::vector", + "std::vector", + "std::vector", + "std::vector"}); + for (const auto& attr : attrs) { + auto attr_type_str = ParseAttrStr(attr)[1]; + if (custom_attrs_type.find(attr_type_str) == custom_attrs_type.end()) { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported `%s` type value as custom attribute now. " + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector`, " + "`std::vector`, " + "Please check whether the attribute data type and " + "data type string are matched.", + attr_type_str)); + } + } info_ptr_->Attrs(std::forward>(attrs)); return *this; } -- GitLab