split_op_plugin.h 1.7 KB
Newer Older
N
nhzlx committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62

#pragma once

#include <vector>

#include <thrust/device_vector.h>

#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/// TensorRT plugin registered under the type name "split".
///
/// Configuration consists of an axis and a list of per-output lengths; the
/// number of plugin outputs equals the number of lengths.
// NOTE(review): the actual splitting logic lives in getOutputDimensions(),
// initialize() and enqueue(), which are defined out of line — the exact
// semantics of the lengths along `axis_` should be confirmed there.
class SplitPlugin : public PluginTensorRT {
  // Axis to split along; set via setAxis() or a constructor and serialized
  // with the plugin.
  int axis_ = 0;
  // One entry per output; getNbOutputs() returns its size. The member keeps
  // its historical spelling because out-of-line definitions of this class
  // may refer to it by this name.
  std::vector<int> output_lenght_;
  // Not read or written anywhere in this header; presumably computed by
  // initialize() (defined out of line) — TODO confirm.
  int nx_ = 0;
  int ny_ = 0;
  int nz_ = 0;
  // Device-side buffer not touched in this header; presumably filled with
  // per-output offsets by initialize() — TODO confirm.
  thrust::device_vector<int> d_segment_offsets_;

 protected:
  // Number of bytes serialize() writes: base-class state, then axis_, then
  // the output lengths. Must stay in sync with serialize() and with the
  // deserializing constructor.
  virtual size_t getSerializationSize() override {
    return serialized_size(axis_) + serialized_size(output_lenght_) +
           getBaseSerializationSize();
  }

  // Writes base-class state, then axis_, then the output lengths, in that
  // order. `buffer` is passed by address, presumably so each helper can
  // advance it past the field it wrote.
  virtual void serialize(void *buffer) override {
    serializeBase(buffer);
    serialize_value(&buffer, axis_);
    serialize_value(&buffer, output_lenght_);
  }

 public:
  // Empty plugin (axis 0, no outputs); configure it afterwards with
  // setAxis() and setOutputLengths().
  SplitPlugin() {}

  // Plugin that splits along `axis` into `output_lengths.size()` outputs.
  // Also used by clone() to copy the configuration.
  SplitPlugin(int axis, std::vector<int> const& output_lengths)
      : axis_(axis), output_lenght_(output_lengths) {}

  // Rebuilds a plugin from a buffer produced by serialize(); fields are read
  // back in exactly the order serialize() wrote them.
  SplitPlugin(void const* serialData, size_t serialLength) {
    deserializeBase(serialData, serialLength);
    deserialize_value(&serialData, &serialLength, &axis_);
    deserialize_value(&serialData, &serialLength, &output_lenght_);
  }

  // Returns a new heap-allocated plugin carrying the same axis and output
  // lengths; the caller owns the returned object. Other members of this
  // class (nx_/ny_/nz_, d_segment_offsets_) start out default-initialized in
  // the copy.
  SplitPlugin* clone() const override {
    return new SplitPlugin(axis_, output_lenght_);
  }

  virtual const char* getPluginType() const override { return "split"; }

  // One output per configured length. The explicit cast makes the
  // size_t -> int conversion visible instead of an implicit narrowing.
  virtual int getNbOutputs() const override {
    return static_cast<int>(output_lenght_.size());
  }

  // Defined out of line.
  virtual nvinfer1::Dims getOutputDimensions(int index,
                                             const nvinfer1::Dims *inputs, int nbInputDims) override;
  // Defined out of line.
  virtual int initialize() override;
  // Defined out of line.
  virtual int enqueue(int batchSize,
                      const void *const *inputs, void **outputs,
                      void *workspace, cudaStream_t stream) override;

  // Sets the axis to split along.
  void setAxis(int axis) {
    axis_ = axis;
  }

  // Replaces the per-output lengths; this also changes getNbOutputs().
  void setOutputLengths(const std::vector<int> & output_lengths) {
    output_lenght_ = output_lengths;
  }

};

} // tensorrt
} // inference
} // paddle