未验证 提交 70180df5 编写于 作者: W Wilber 提交者: GitHub
上级 116fcada
......@@ -70,9 +70,17 @@ class CustomPluginCreater : public OpConverter {
std::list<std::vector<int>> ints_attrs;
std::list<std::vector<float>> floats_attrs;
for (auto &attr_name : op_attrs_names) {
for (auto &attr_name_and_type : op_attrs_names) {
auto attr_name =
attr_name_and_type.substr(0, attr_name_and_type.find_first_of(":"));
nvinfer1::PluginField plugindata;
plugindata.name = attr_name.c_str();
// NOTE: to avoid string rewrite by iterator, deep copy here
std::vector<char> plugin_attr_name(attr_name.length() + 1, 0);
snprintf(
plugin_attr_name.data(), attr_name.length() + 1, attr_name.c_str());
plugindata.name = plugin_attr_name.data();
if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) {
int_attrs.push_back(PADDLE_GET_CONST(int, attrs.at(attr_name)));
plugindata.data = &int_attrs.back();
......
......@@ -23,13 +23,13 @@ PD_BUILD_OP(custom_op)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({
"float_attr",
"int_attr",
"bool_attr",
"string_attr",
"ints_attr",
"floats_attr",
"bools_attr",
"float_attr: float",
"int_attr: int",
"bool_attr: bool",
"string_attr: std::string",
"ints_attr: std::vector<int>",
"floats_attr: std::vector<float>",
"bools_attr: std::vector<bool>",
});
namespace paddle {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册