// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

nvinfer1::Dims SplitPlugin::getOutputDimensions(
    int index, const nvinfer1::Dims* input_dims, int num_inputs) {
  PADDLE_ENFORCE_EQ(num_inputs, 1);
  PADDLE_ENFORCE_LT(index, this->getNbOutputs());

  nvinfer1::Dims output_dims = input_dims[0];
  output_dims.d[axis_] = output_length_.at(index);
  return output_dims;
}
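// Example for getOutputDimensions (assumed values, for illustration): with
// input dims (8, 32, 32), axis_ == 0 and output_length_ == {2, 6}, index 0
// yields (2, 32, 32) and index 1 yields (6, 32, 32); only the entry at the
// split axis changes.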

int SplitPlugin::initialize() {
  // axis_ indexes Dims::d, so it must be strictly below MAX_DIMS.
  PADDLE_ENFORCE_LT(axis_, nvinfer1::Dims::MAX_DIMS);

  std::vector<int> segment_offsets(1, 0);
  for (int i = 0; i < this->getNbOutputs(); ++i) {
    segment_offsets.push_back(segment_offsets.back() + output_length_[i]);
  }
  segment_offsets_ = segment_offsets;
  nvinfer1::Dims dims = this->getInputDims(0);
  nx_ = 1;
  for (int i = dims.nbDims - 1; i > axis_; --i) {
    nx_ *= dims.d[i];
  }
  ny_ = dims.d[axis_];
  nz_ = 1;
  for (int i = axis_ - 1; i >= 0; --i) {
    nz_ *= dims.d[i];
  }
  return 0;
}
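// Example for initialize (assumed values, for illustration): with input dims
// (6, 4, 5) and axis_ == 0, nx_ = 4 * 5 = 20 (flattened dims after the
// axis), ny_ = 6 (the split axis) and nz_ = 1 (flattened dims before the
// axis), so the input is viewed as an nz_ x ny_ x nx_ volume;
// output_length_ == {2, 4} then gives segment_offsets_ == {0, 2, 6}.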

int SplitPlugin::enqueue(int batchSize, const void* const* inputs,
                         void** outputs, void* workspace, cudaStream_t stream) {
  float const* idata = reinterpret_cast<float const*>(inputs[0]);
  float** odatas = reinterpret_cast<float**>(outputs);

  // Copy each segment of every batch item into its output tensor. The float
  // pointer offsets are in elements; only the byte count passed to
  // cudaMemcpyAsync is scaled by sizeof(float).
  int inputBatchOffset = nx_ * ny_ * nz_;
  for (int i = 0; i < this->getNbOutputs(); ++i) {
    int segment_size = (segment_offsets_[i + 1] - segment_offsets_[i]) * nx_;
    for (int j = 0; j < batchSize; ++j) {
      cudaMemcpyAsync(odatas[i] + j * segment_size,
                      idata + inputBatchOffset * j + segment_offsets_[i] * nx_,
                      segment_size * sizeof(float), cudaMemcpyDeviceToDevice,
                      stream);
    }
  }

  return cudaGetLastError() != cudaSuccess;
}
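
// A minimal sketch, for illustration only, of a fused split kernel that
// could replace the per-segment cudaMemcpyAsync loop in enqueue; it is not
// wired in. It handles one batch item and assumes that `outputs` and
// `seg_offsets` (a device-side copy of segment_offsets_) live in device
// memory; staging those arrays to the device is omitted here.
__global__ void SplitKernelSketch(float const* input, float* const* outputs,
                                  int const* seg_offsets, int nb_outputs,
                                  int nx, int ny, int nz) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= nx * ny * nz) return;
  int x = idx % nx;         // flattened dims after the split axis
  int y = (idx / nx) % ny;  // position along the split axis
  int z = idx / (nx * ny);  // flattened dims before the split axis
  // Linear scan for the segment containing y; nb_outputs is small.
  int seg = 0;
  while (seg + 1 < nb_outputs && y >= seg_offsets[seg + 1]) ++seg;
  int seg_len = seg_offsets[seg + 1] - seg_offsets[seg];
  int local_y = y - seg_offsets[seg];
  outputs[seg][(z * seg_len + local_y) * nx + x] = input[idx];
}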

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle