提交 8570e418 编写于 作者: Z zlsh80826

implement stack op createPlugin

上级 7be555ee
......@@ -121,7 +121,22 @@ class StackPluginV2Creator : public nvinfer1::IPluginCreator {
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
int axis = -1;
int num_stack = -1;
for (int i = 0; i < fc->nbFields; ++i) {
const std::string name(fc->fields[i].name);
if (name == "axis") {
axis = static_cast<const int*>(fc->fields[i].data)[0];
} else if (name == "num_stack") {
num_stack = static_cast<const int*>(fc->fields[i].data)[0];
} else {
PADDLE_THROW(
platform::errors::Fatal("Meet an unknown plugin field '" + name +
"' when creating stack op plugin."));
}
}
return new StackPluginDynamic(axis, num_stack);
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册