提交 4aec5ec8 编写于 作者: Z zlsh80826

add stack op serialization

上级 862dde5e
...@@ -61,7 +61,6 @@ class StackPluginDynamic : public DynamicPluginTensorRT { ...@@ -61,7 +61,6 @@ class StackPluginDynamic : public DynamicPluginTensorRT {
serialize_size += SerializedSize(axis_); serialize_size += SerializedSize(axis_);
serialize_size += SerializedSize(num_stack_); serialize_size += SerializedSize(num_stack_);
// serialize_size += num_stack_ * sizeof(int64_t);
return serialize_size; return serialize_size;
} }
...@@ -69,7 +68,6 @@ class StackPluginDynamic : public DynamicPluginTensorRT { ...@@ -69,7 +68,6 @@ class StackPluginDynamic : public DynamicPluginTensorRT {
void serialize(void* buffer) const override { void serialize(void* buffer) const override {
SerializeValue(&buffer, axis_); SerializeValue(&buffer, axis_);
SerializeValue(&buffer, num_stack_); SerializeValue(&buffer, num_stack_);
// SerializeCudaPointer(&buffer, in_ptr_gpu_, num_stack_);
} }
nvinfer1::DimsExprs getOutputDimensions( nvinfer1::DimsExprs getOutputDimensions(
...@@ -109,6 +107,45 @@ class StackPluginDynamic : public DynamicPluginTensorRT { ...@@ -109,6 +107,45 @@ class StackPluginDynamic : public DynamicPluginTensorRT {
framework::Tensor in_ptr_tensor_; framework::Tensor in_ptr_tensor_;
int64_t* in_ptr_gpu_; int64_t* in_ptr_gpu_;
}; };
class StackPluginV2Creator : public nvinfer1::IPluginCreator {
public:
StackPluginV2Creator() {}
const char* getPluginName() const override { return "stack_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new StackPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(StackPluginV2Creator);
#endif #endif
} // namespace plugin } // namespace plugin
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册