未验证 提交 195d6d0f 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP error] Add attrs type check (#53030)

* [CustomOP error] Add attrs type check

* fix global variable order bug

* include unordered_set

* fix ParseAttrStr compile error
上级 2395cbe5
...@@ -149,7 +149,7 @@ static void RunKernelFunc( ...@@ -149,7 +149,7 @@ static void RunKernelFunc(
} }
for (auto& attr_str : attrs) { for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str); auto attr_name_and_type = paddle::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0]; auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1]; auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") { if (attr_type_str == "bool") {
...@@ -464,7 +464,7 @@ static void RunInferShapeFunc( ...@@ -464,7 +464,7 @@ static void RunInferShapeFunc(
std::vector<paddle::any> custom_attrs; std::vector<paddle::any> custom_attrs;
for (auto& attr_str : attrs) { for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str); auto attr_name_and_type = paddle::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0]; auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1]; auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") { if (attr_type_str == "bool") {
...@@ -491,13 +491,13 @@ static void RunInferShapeFunc( ...@@ -491,13 +491,13 @@ static void RunInferShapeFunc(
custom_attrs.emplace_back( custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<std::string>>(attr_name)); ctx->Attrs().Get<std::vector<std::string>>(attr_name));
} else { } else {
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. " "Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, " "Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, " "`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<std::string>`, " "`std::vector<float>`, `std::vector<int64_t>`, "
"Please check whether the attribute data type and " "`std::vector<std::string>`, Please check whether the attribute data "
"data type string are matched.", "type and data type string are matched.",
attr_type_str)); attr_type_str));
} }
} }
...@@ -872,7 +872,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker { ...@@ -872,7 +872,7 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
} }
} }
for (auto& attr : attrs_) { for (auto& attr : attrs_) {
auto attr_name_and_type = detail::ParseAttrStr(attr); auto attr_name_and_type = paddle::ParseAttrStr(attr);
auto attr_name = attr_name_and_type[0]; auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1]; auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") { if (attr_type_str == "bool") {
......
...@@ -81,25 +81,6 @@ inline static bool IsMemberOf(const std::vector<std::string>& vec, ...@@ -81,25 +81,6 @@ inline static bool IsMemberOf(const std::vector<std::string>& vec,
return std::find(vec.cbegin(), vec.cend(), name) != vec.cend(); return std::find(vec.cbegin(), vec.cend(), name) != vec.cend();
} }
static std::vector<std::string> ParseAttrStr(const std::string& attr) {
auto split_pos = attr.find_first_of(":");
PADDLE_ENFORCE_NE(split_pos,
std::string::npos,
platform::errors::InvalidArgument(
"Invalid attribute string format. Attribute string "
"format is `<name>:<type>`."));
std::vector<std::string> rlt;
// 1. name
rlt.emplace_back(string::trim_spaces(attr.substr(0, split_pos)));
// 2. type
rlt.emplace_back(string::trim_spaces(attr.substr(split_pos + 1)));
VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
return rlt;
}
} // namespace detail } // namespace detail
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -567,8 +567,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ...@@ -567,8 +567,7 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
int attr_start_idx = 1 + inputs.size(); int attr_start_idx = 1 + inputs.size();
for (size_t i = 0; i < attrs.size(); ++i) { for (size_t i = 0; i < attrs.size(); ++i) {
const auto& attr = attrs.at(i); const auto& attr = attrs.at(i);
std::vector<std::string> attr_name_and_type = std::vector<std::string> attr_name_and_type = paddle::ParseAttrStr(attr);
paddle::framework::detail::ParseAttrStr(attr);
auto attr_type_str = attr_name_and_type[1]; auto attr_type_str = attr_name_and_type[1];
VLOG(7) << "Custom operator add attrs " << attr_name_and_type[0] VLOG(7) << "Custom operator add attrs " << attr_name_and_type[0]
<< " to CustomOpKernelContext. Attribute type = " << " to CustomOpKernelContext. Attribute type = "
......
...@@ -97,6 +97,8 @@ inline std::string Optional(const std::string& t_name) { ...@@ -97,6 +97,8 @@ inline std::string Optional(const std::string& t_name) {
return result; return result;
} }
std::vector<std::string> ParseAttrStr(const std::string& attr);
PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst); PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst);
////////////////////// Kernel Context //////////////////////// ////////////////////// Kernel Context ////////////////////////
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include <vector> #include <vector>
#include "glog/logging.h" #include "glog/logging.h"
...@@ -24,6 +25,38 @@ limitations under the License. */ ...@@ -24,6 +25,38 @@ limitations under the License. */
namespace paddle { namespace paddle {
// remove leading and tailing spaces
std::string trim_spaces(const std::string& str) {
const char* p = str.c_str();
while (*p != 0 && isspace(*p)) {
p++;
}
size_t len = strlen(p);
while (len > 0 && isspace(p[len - 1])) {
len--;
}
return std::string(p, len);
}
std::vector<std::string> ParseAttrStr(const std::string& attr) {
auto split_pos = attr.find_first_of(":");
PADDLE_ENFORCE_NE(split_pos,
std::string::npos,
phi::errors::InvalidArgument(
"Invalid attribute string format. Attribute string "
"format is `<name>:<type>`."));
std::vector<std::string> rlt;
// 1. name
rlt.emplace_back(trim_spaces(attr.substr(0, split_pos)));
// 2. type
rlt.emplace_back(trim_spaces(attr.substr(split_pos + 1)));
VLOG(3) << "attr name: " << rlt[0] << ", attr type str: " << rlt[1];
return rlt;
}
PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) { PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) {
if (!src.initialized() || !dst->defined()) { if (!src.initialized() || !dst->defined()) {
VLOG(3) << "Custom operator assigns non-initialized tensor, this only " VLOG(3) << "Custom operator assigns non-initialized tensor, this only "
...@@ -346,6 +379,30 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs( ...@@ -346,6 +379,30 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
} }
OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) { OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
const std::unordered_set<std::string> custom_attrs_type(
{"bool",
"int",
"float",
"int64_t",
"std::string",
"std::vector<int>",
"std::vector<float>",
"std::vector<int64_t>",
"std::vector<std::string>"});
for (const auto& attr : attrs) {
auto attr_type_str = ParseAttrStr(attr)[1];
if (custom_attrs_type.find(attr_type_str) == custom_attrs_type.end()) {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, "
"Please check whether the attribute data type and "
"data type string are matched.",
attr_type_str));
}
}
info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs)); info_ptr_->Attrs(std::forward<std::vector<std::string>>(attrs));
return *this; return *this;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册