/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Converts a Paddle `slice` op into TensorRT layers.
 *
 * Lowering paths:
 *  - dynamic shape, with use_oss() && with_ernie() and a 4-D input:
 *    SpecialSlicePluginDynamic, fed with the (optionally transposed) input and
 *    the ERNIE position tensor;
 *  - other dynamic-shape inputs: a native ISliceLayer whose start/size tensors
 *    are computed at runtime from the input's shape tensor, followed by an
 *    IShuffleLayer when `decrease_axis` removes dimensions;
 *  - static shape: a native ISliceLayer with compile-time start/size on
 *    TRT >= 6.0, otherwise the legacy SlicePlugin.
 */
class SliceOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(4) << "convert slice op to tensorrt layer";
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input = engine_->GetITensor(op_desc.Input("Input")[0]);
    auto output_name = op_desc.Output("Out")[0];

    // With quantization info, the input's dynamic range is taken from the
    // op's out_threshold (presumably because slicing only selects elements
    // and so does not change the value range).
    float out_scale = 1;
    if (op_desc.HasAttr("out_threshold")) {
      out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
      engine_->SetTensorDynamicRange(input, out_scale);
    }

    std::vector<int> axes =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
    std::vector<int> starts =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
    std::vector<int> ends =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
    std::vector<int> decrease_axises =
        BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("decrease_axis"));

    auto input_dims = input->getDimensions();
    if (!engine_->with_dynamic_shape()) {
      // In static-shape mode the tensor shape is [C,H,W] without the batch
      // axis, while `axes` count from the Paddle shape [N,C,H,W]. Map a Paddle
      // axis to its extent without shifting input_dims.d in place (an
      // in-place shift writes d[nbDims], which overflows when nbDims equals
      // Dims::MAX_DIMS). Axis 0 gets a fake batch extent of 1.
      auto paddle_dim = [&input_dims](int axis) {
        return axis == 0 ? 1 : input_dims.d[axis - 1];
      };
      // Resolve negative starts/ends against the static extent and clamp
      // ends to it, so the static slice below gets concrete bounds.
      for (size_t i = 0; i < axes.size(); i++) {
        const int dim = paddle_dim(axes[i]);
        if (starts[i] < 0) {
          starts[i] = std::max(starts[i] + dim, 0);
        }
        if (ends[i] < 0) {
          ends[i] = std::max(ends[i] + dim, 0);
        }
        ends[i] = std::min(ends[i], dim);
        PADDLE_ENFORCE_GT(
            ends[i],
            starts[i],
            platform::errors::InvalidArgument(
                "Attr(ends) should be greater than attr(starts) in "
                "slice op. But received ends = %d, starts = %d.",
                ends[i],
                starts[i]));
      }
    }

    nvinfer1::ILayer* layer = nullptr;
    if (engine_->with_dynamic_shape()) {
      if (engine_->use_oss() && engine_->with_ernie() &&
          input_dims.nbDims == 4) {
        std::vector<nvinfer1::ITensor*> plugin_inputs;
        if (engine_->with_interleaved()) {
          // Interleaved layout: permute axes {2, 1, 0, 3} before the plugin.
          auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
          nvinfer1::Permutation transpose_embed{2, 1, 0, 3};
          shuffler_slice->setSecondTranspose(transpose_embed);
          engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
                                         out_scale);
          shuffler_slice->setName(
              ("SpecialSlice_interleaved: transpose: (Output: " + output_name +
               ")")
                  .c_str());
          plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
        } else {
          plugin_inputs.emplace_back(input);
        }
        std::string pos_name;
        if (engine_->Has("ernie_pos_name")) {
          pos_name = engine_->Get<std::string>("ernie_pos_name");
        } else {
          // hard code for compatibility
          pos_name = engine_->network()->getInput(2)->getName();
        }
        plugin_inputs.emplace_back(
            engine_->GetITensor(pos_name));  // cu_seqlens, eval_placeholder_2

        plugin::SpecialSlicePluginDynamic* plugin =
            new plugin::SpecialSlicePluginDynamic();
        layer = engine_->AddDynamicPlugin(
            plugin_inputs.data(), plugin_inputs.size(), plugin);
      } else {
        // Native ISliceLayer. Static start/size dims are placeholders; the
        // real start and size are supplied as runtime tensors via setInput.
        auto nchw_input_dims = input->getDimensions();
        nvinfer1::Dims trt_start_dims;
        trt_start_dims.nbDims = nchw_input_dims.nbDims;
        memset(trt_start_dims.d, 0, sizeof(int32_t) * nchw_input_dims.nbDims);
        nvinfer1::Dims trt_size_dims = trt_start_dims;
        nvinfer1::Dims trt_end_dims = trt_start_dims;
        nvinfer1::Dims trt_step_dims = trt_start_dims;
        for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1;

        // input : [N,C,H,W]; in dynamic-shape mode `axes` index it directly.
        bool has_neg_indices = false;
        for (size_t i = 0; i < axes.size(); i++) {
          int trt_axis = axes[i];
          trt_start_dims.d[trt_axis] = starts[i];
          trt_end_dims.d[trt_axis] = ends[i];
          if (starts[i] < 0 || ends[i] < 0) has_neg_indices = true;
        }
        auto* shape_tensor = Shape(input);
        auto* start_tensor = Add1DConstantLayer(trt_start_dims);
        if (has_neg_indices) {
          start_tensor = FixNegIndices(shape_tensor, start_tensor);
        }

        // End per axis: the runtime extent for untouched axes, the constant
        // `ends[i]` for non-negative ends, and extent + ends[i] for negative
        // ends.
        std::vector<nvinfer1::ITensor*> end_vec_tensor;
        for (int i = 0; i < trt_end_dims.nbDims; i++) {
          end_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, i));
        }

        for (size_t i = 0; i < axes.size(); i++) {
          int trt_axis = axes[i];
          if (ends[i] >= 0) {
            end_vec_tensor[trt_axis] = Add1DConstantLayer(ends[i]);
          } else {
            end_vec_tensor[trt_axis] =
                Sum(end_vec_tensor[trt_axis], Add1DConstantLayer(ends[i]));
          }
        }

// CI failed in trt 6015 but success in 7134, may be a trt bug
#if IS_TRT_VERSION_GE(7134)
        auto* size_tensor =
            Sub(Min(Concat(end_vec_tensor), shape_tensor), start_tensor);
#else
        auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor);
#endif

        layer = TRT_ENGINE_ADD_LAYER(engine_,
                                     Slice,
                                     *input,
                                     trt_start_dims,
                                     trt_size_dims,
                                     trt_step_dims);
        layer->setInput(1, *start_tensor);
        layer->setInput(2, *size_tensor);

        // Drop the axes listed in decrease_axis by reshaping to the gathered
        // sizes of the remaining axes (keep one axis if all are dropped).
        if (decrease_axises.size() > 0) {
          std::vector<int32_t> gather_indices;
          for (int i = 0; i < trt_size_dims.nbDims; i++) {
            if (decrease_axises.end() !=
                std::find(decrease_axises.begin(), decrease_axises.end(), i))
              continue;
            gather_indices.push_back(i);
          }
          if (gather_indices.empty())
            gather_indices.push_back(decrease_axises[0]);
          auto real_size_tensor = Gather(size_tensor, gather_indices);
          layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
          layer->setInput(1, *real_size_tensor);
        }
        // NOTE: the previous version then replaced `layer` with a
        // SlicePluginDynamic, which left the layers above dangling in the
        // network and made the decrease_axis reshape dead code.
      }
    } else {
#if IS_TRT_VERSION_GE(6000)
      auto chw_input_dims = input->getDimensions();
      nvinfer1::Dims trt_start_dims;
      trt_start_dims.nbDims = chw_input_dims.nbDims;
      memset(trt_start_dims.d, 0, sizeof(int32_t) * chw_input_dims.nbDims);
      nvinfer1::Dims trt_size_dims = chw_input_dims;
      nvinfer1::Dims trt_step_dims;
      trt_step_dims.nbDims = chw_input_dims.nbDims;
      for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1;

      // input : [C,H,W]; Paddle axis k maps to TRT axis k - 1.
      for (size_t i = 0; i < axes.size(); i++) {
        // The batch axis is implicit in static-shape mode and cannot be
        // sliced; axes[i] - 1 would otherwise index Dims::d[-1].
        PADDLE_ENFORCE_GT(
            axes[i],
            0,
            platform::errors::InvalidArgument(
                "Slice on the batch axis (axis 0) is not supported by the "
                "TensorRT slice converter in static shape mode."));
        int trt_axis = axes[i] - 1;
        trt_start_dims.d[trt_axis] = starts[i];
        trt_size_dims.d[trt_axis] = ends[i] - starts[i];
      }
      layer = TRT_ENGINE_ADD_LAYER(
          engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims);
      nvinfer1::Dims real_trt_size_dims;
      real_trt_size_dims.nbDims = 0;

      // Reshape away the decreased axes (converted to TRT axis numbering);
      // fall back to shape [1] if every axis is removed.
      if (decrease_axises.size() > 0) {
        for (size_t i = 0; i < decrease_axises.size(); i++) {
          decrease_axises[i]--;
        }
        for (int i = 0; i < trt_size_dims.nbDims; i++) {
          if (decrease_axises.end() !=
              std::find(decrease_axises.begin(), decrease_axises.end(), i))
            continue;
          real_trt_size_dims.d[real_trt_size_dims.nbDims] = trt_size_dims.d[i];
          real_trt_size_dims.nbDims++;
        }
        if (real_trt_size_dims.nbDims == 0) {
          real_trt_size_dims.nbDims = 1;
          real_trt_size_dims.d[0] = 1;
        }
        auto reshape_layer =
            TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
        reshape_layer->setReshapeDimensions(real_trt_size_dims);
        layer = static_cast<nvinfer1::ILayer*>(reshape_layer);
      }
#else
      bool with_fp16 =
          engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
      plugin::SlicePlugin* plugin =
          new plugin::SlicePlugin(starts, ends, axes, with_fp16);
      layer = engine_->AddPlugin(&input, 1, plugin);
#endif
    }
    RreplenishLayerAndOutput(layer, "slice", {output_name}, test_mode);
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// Register SliceOpConverter as the TensorRT converter for the Paddle `slice`
// op.
REGISTER_TRT_OP_CONVERTER(slice, SliceOpConverter);