Commit 6e3255a1 authored by zlsh80826

add stack plugin serialize
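The change implements real serialization for StackPluginDynamic: getSerializationSize() reports the bytes needed for the fields that must survive engine serialization (axis_ and num_stack_), serialize() writes them into the buffer TensorRT provides, and the new deserialization constructor reads them back in the same order before rebuilding the GPU scratch tensor via init(). As a minimal standalone sketch of that round trip, the snippet below uses plain memcpy-based helpers; serialize_field / deserialize_field are hypothetical stand-ins for illustration only, not Paddle's actual SerializeValue / DeserializeValue helpers.

    // Sketch of the plugin field (de)serialization round trip.
    // serialize_field / deserialize_field are illustrative stand-ins.
    #include <cassert>
    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Append one trivially-copyable field to the byte stream and advance the cursor.
    template <typename T>
    void serialize_field(char** buffer, const T& value) {
      std::memcpy(*buffer, &value, sizeof(T));
      *buffer += sizeof(T);
    }

    // Read one field back, advance the cursor, and shrink the remaining length.
    template <typename T>
    void deserialize_field(const char** data, size_t* length, T* value) {
      assert(*length >= sizeof(T));
      std::memcpy(value, *data, sizeof(T));
      *data += sizeof(T);
      *length -= sizeof(T);
    }

    int main() {
      // Fields StackPluginDynamic persists across engine serialization.
      const int axis = 1;
      const int num_stack = 4;

      // getSerializationSize(): sum of the serialized field sizes.
      std::vector<char> blob(sizeof(axis) + sizeof(num_stack));

      // serialize(): write the fields into the provided buffer in a fixed order.
      char* write_cursor = blob.data();
      serialize_field(&write_cursor, axis);
      serialize_field(&write_cursor, num_stack);

      // Deserialization constructor: read the fields back in the same order,
      // then rebuild the GPU scratch tensor (init() in the header below).
      const char* read_cursor = blob.data();
      size_t remaining = blob.size();
      int restored_axis = 0;
      int restored_num_stack = 0;
      deserialize_field(&read_cursor, &remaining, &restored_axis);
      deserialize_field(&read_cursor, &remaining, &restored_num_stack);

      assert(restored_axis == axis && restored_num_stack == num_stack);
      return 0;
    }

Note that the commented-out SerializeCudaPointer lines in the header indicate the GPU scratch buffer is deliberately left out of the serialized blob; it is recreated by init() after deserialization.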

Parent bb17cc75
@@ -23,12 +23,7 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 // Dynamic Plugin below.
 #if IS_TRT_VERSION_GE(6000)
-size_t StackPluginDynamic::getSerializationSize() const { return 0; }
-void StackPluginDynamic::serialize(void* buffer) const {}
 nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
     int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
     nvinfer1::IExprBuilder& expr_builder) {
@@ -119,9 +114,9 @@ int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                   sizeof(void*) * out_dims.d[axis_], cudaMemcpyHostToDevice,
                   stream);
-  int num_stacks = out_dims.d[axis_];
+  const int num_stacks = out_dims.d[axis_];
   dim3 num_blocks(num_stacks, lead_unit);
-  int num_threads = 256;
+  const int num_threads = 256;
   auto infer_type = input_desc[0].type;
   if (infer_type == nvinfer1::DataType::kFLOAT) {
@@ -30,25 +30,47 @@ class StackPluginDynamic : public DynamicPluginTensorRT {
  public:
   StackPluginDynamic(int axis, int num_stack)
       : axis_(axis), num_stack_(num_stack) {
-    int device_id;
-    cudaGetDevice(&device_id);
-    in_ptr_tensor_.Resize({num_stack});
-    in_ptr_gpu_ =
-        in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id));
+    init();
   }
+  StackPluginDynamic(void const* serialData, size_t serialLength) {
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &num_stack_);
+    init();
+  }
-  StackPluginDynamic(void const* serialData, size_t serialLength) {}
   ~StackPluginDynamic() {}
   nvinfer1::IPluginV2DynamicExt* clone() const override {
     return new StackPluginDynamic(axis_, num_stack_);
   }
+  void init() {
+    int device_id;
+    cudaGetDevice(&device_id);
+    in_ptr_tensor_.Resize({num_stack_});
+    in_ptr_gpu_ =
+        in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id));
+  }
   const char* getPluginType() const override { return "stack_plugin"; }
   int getNbOutputs() const override { return 1; }
   int initialize() override { return 0; }
-  size_t getSerializationSize() const override;
-  void serialize(void* buffer) const override;
+  size_t getSerializationSize() const override {
+    size_t serialize_size = 0;
+    serialize_size += SerializedSize(axis_);
+    serialize_size += SerializedSize(num_stack_);
+    // serialize_size += num_stack_ * sizeof(int64_t);
+    return serialize_size;
+  }
+  void serialize(void* buffer) const override {
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, num_stack_);
+    // SerializeCudaPointer(&buffer, in_ptr_gpu_, num_stack_);
+  }
   nvinfer1::DimsExprs getOutputDimensions(
       int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,