diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
index f707a2cad19f2d39e8d38cf54987ad6a3860eced..099f5363319c5cf8e0705b6488d3cfe5f7a34ed1 100644
--- a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
@@ -23,12 +23,7 @@
 namespace inference {
 namespace tensorrt {
 namespace plugin {
-// Dynamic Plugin below.
 #if IS_TRT_VERSION_GE(6000)
-size_t StackPluginDynamic::getSerializationSize() const { return 0; }
-
-void StackPluginDynamic::serialize(void* buffer) const {}
-
 nvinfer1::DimsExprs StackPluginDynamic::getOutputDimensions(
     int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
     nvinfer1::IExprBuilder& expr_builder) {
@@ -119,9 +114,9 @@ int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                   sizeof(void*) * out_dims.d[axis_],
                   cudaMemcpyHostToDevice, stream);
 
-  int num_stacks = out_dims.d[axis_];
+  const int num_stacks = out_dims.d[axis_];
   dim3 num_blocks(num_stacks, lead_unit);
-  int num_threads = 256;
+  const int num_threads = 256;
 
   auto infer_type = input_desc[0].type;
   if (infer_type == nvinfer1::DataType::kFLOAT) {
diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
index 113eda42d35b7494f8ac533299bd0040f1afd68f..eadffbfb1a66d707598b83d612ce6a7b826aa251 100644
--- a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.h
@@ -30,25 +30,47 @@ class StackPluginDynamic : public DynamicPluginTensorRT {
  public:
   StackPluginDynamic(int axis, int num_stack)
       : axis_(axis), num_stack_(num_stack) {
-    int device_id;
-    cudaGetDevice(&device_id);
-    in_ptr_tensor_.Resize({num_stack});
-    in_ptr_gpu_ =
-        in_ptr_tensor_.mutable_data(platform::CUDAPlace(device_id));
+    init();
+  }
+
+  StackPluginDynamic(void const* serialData, size_t serialLength) {
+    DeserializeValue(&serialData, &serialLength, &axis_);
+    DeserializeValue(&serialData, &serialLength, &num_stack_);
+    init();
   }
 
-  StackPluginDynamic(void const* serialData, size_t serialLength) {}
   ~StackPluginDynamic() {}
   nvinfer1::IPluginV2DynamicExt* clone() const override {
     return new StackPluginDynamic(axis_, num_stack_);
   }
 
+  void init() {
+    int device_id;
+    cudaGetDevice(&device_id);
+    in_ptr_tensor_.Resize({num_stack_});
+    in_ptr_gpu_ =
+        in_ptr_tensor_.mutable_data(platform::CUDAPlace(device_id));
+  }
+
   const char* getPluginType() const override { return "stack_plugin"; }
   int getNbOutputs() const override { return 1; }
   int initialize() override { return 0; }
 
-  size_t getSerializationSize() const override;
-  void serialize(void* buffer) const override;
+  size_t getSerializationSize() const override {
+    size_t serialize_size = 0;
+
+    serialize_size += SerializedSize(axis_);
+    serialize_size += SerializedSize(num_stack_);
+    // serialize_size += num_stack_ * sizeof(int64_t);
+
+    return serialize_size;
+  }
+
+  void serialize(void* buffer) const override {
+    SerializeValue(&buffer, axis_);
+    SerializeValue(&buffer, num_stack_);
+    // SerializeCudaPointer(&buffer, in_ptr_gpu_, num_stack_);
+  }
 
   nvinfer1::DimsExprs getOutputDimensions(
       int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,