skip_layernorm.cc 6.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

// Converts Paddle's fused `skip_layernorm` op into a TensorRT plugin layer.
//
// The op's tensor inputs "X" and "Y" become the two plugin inputs; its
// persistable "Bias" and "Scale" variables are passed to the plugin as the
// "beta" and "gamma" weights respectively.
class SkipLayerNormOpConverter : public OpConverter {
 public:
  // Builds the TensorRT layer for one `skip_layernorm` op and binds it to the
  // op's "Out" output.
  //
  // Implementation selection:
  //   * use_oss() && with_interleaved(): OSS CustomSkipLayerNormPluginDynamic
  //     version "3" (requires the op to be marked int8).
  //   * use_oss() only:                  OSS CustomSkipLayerNormPluginDynamic
  //     version "2" (takes type_id and hidden dimension "ld").
  //   * otherwise:                       Paddle's SkipLayerNormPluginDynamic
  //     (uses the "epsilon" attribute; fp16 currently forced off).
  //
  // Throws (via PADDLE_THROW / PADDLE_ENFORCE_*) on an unsupported
  // configuration, a missing weight variable, or a plugin that cannot be
  // created or added. Requires TensorRT >= 6.0.
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
    VLOG(4) << "convert fused skip layernorm op to tensorrt layer";
    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
    auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
    std::vector<nvinfer1::ITensor*> inputs{input1, input2};

    // Looks up the persistable variable bound to input `arg_name`, writes its
    // shape to `*dims` and returns its float data as provided by
    // engine_->GetWeightCPUData.
    auto get_persistable_data = [&](const std::string& arg_name,
                                    framework::DDim* dims) -> float* {
      const std::string var_name = op_desc.Input(arg_name).front();
      auto* temp_var = scope.FindVar(var_name);
      PADDLE_ENFORCE_NOT_NULL(
          temp_var,
          platform::errors::NotFound(
              "Variable %s (input %s of skip_layernorm) is not found in scope.",
              var_name,
              arg_name));
      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
      *dims = temp_tensor->dims();
      return engine_->GetWeightCPUData(var_name, temp_tensor);
    };

    framework::DDim bias_dims, scale_dims;
    auto* bias = get_persistable_data("Bias", &bias_dims);
    auto* scale = get_persistable_data("Scale", &scale_dims);
    int bias_size = phi::product(bias_dims);
    int scale_size = phi::product(scale_dims);
    // Only the presence of the "enable_int8" attribute is checked, not its
    // value.
    bool enable_int8 = op_desc.HasAttr("enable_int8");

    nvinfer1::ILayer* layer = nullptr;
    if (engine_->use_oss()) {
      if (engine_->with_interleaved()) {
        VLOG(4) << "fused skip_layernorm op: use_oss and with_interleaved";
        if (!enable_int8) {
          PADDLE_THROW(
              platform::errors::Fatal("use with_interleaved must be int8."));
        }
        const std::vector<nvinfer1::PluginField> fields{
            {"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
            {"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size}};
        layer = AddCustomSkipLayerNormPlugin("3", fields, inputs);
      } else {
        int type = static_cast<int>(engine_->WithFp16()
                                        ? nvinfer1::DataType::kHALF
                                        : nvinfer1::DataType::kFLOAT);
        // Hidden dimension, read from axis 2 of X.
        // NOTE(review): assumes X has rank >= 3 — confirm.
        int ld = input1->getDimensions().d[2];
        PADDLE_ENFORCE_GT(ld,
                          0,
                          platform::errors::InvalidArgument(
                              "in CustomSkipLayerNormPluginDynamic hidden "
                              "dimension should > 0"));
        if (enable_int8) {
          // The int8 path requests the half-precision plugin type.
          type = static_cast<int>(nvinfer1::DataType::kHALF);
        }
        // `type` and `ld` are referenced by address; they stay alive until
        // the plugin has been created below.
        const std::vector<nvinfer1::PluginField> fields{
            {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
            {"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1},
            {"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
            {"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size},
        };
        layer = AddCustomSkipLayerNormPlugin("2", fields, inputs);
      }
    } else {
      float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
      // fp16 is currently forced off for this plugin. The previous rule was:
      //   bool with_fp16 =
      //       engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
      bool with_fp16 = false;
      // NOTE(review): ownership of `plugin` is presumably taken over by
      // engine_->AddDynamicPlugin — confirm.
      plugin::SkipLayerNormPluginDynamic* plugin =
          new plugin::SkipLayerNormPluginDynamic(
              bias, scale, bias_size, scale_size, eps, with_fp16);
      layer = engine_->AddDynamicPlugin(
          inputs.data(), static_cast<int>(inputs.size()), plugin);
    }

    auto output_name = op_desc.Output("Out")[0];
    RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "You are running the TRT Dynamic Shape mode, need to confirm that "
        "your TRT version is no less than 6.0"));
#endif
  }

 private:
#if IS_TRT_VERSION_GE(6000)
  // Creates the registered "CustomSkipLayerNormPluginDynamic" plugin of the
  // given `version` from `fields`, adds it to the engine's network over
  // `inputs` and returns the resulting layer.
  //
  // Throws InvalidArgument if the plugin creator is not registered, if the
  // plugin cannot be created, or if the layer cannot be added.
  nvinfer1::ILayer* AddCustomSkipLayerNormPlugin(
      const char* version,
      const std::vector<nvinfer1::PluginField>& fields,
      const std::vector<nvinfer1::ITensor*>& inputs) {
    auto* creator = GetPluginRegistry()->getPluginCreator(
        "CustomSkipLayerNormPluginDynamic", version);
    PADDLE_ENFORCE_NE(
        creator,
        nullptr,
        platform::errors::InvalidArgument(
            "fail to get creator of CustomSkipLayerNormPluginDynamic"));

    // The collection only has to be valid for the createPlugin() call (the
    // field data it points to is scoped to the caller anyway), so a stack
    // object suffices and there is no heap allocation to leak.
    nvinfer1::PluginFieldCollection field_collection{};
    field_collection.nbFields = static_cast<int>(fields.size());
    field_collection.fields = fields.data();

    auto* plugin_obj = creator->createPlugin("CustomSkipLayerNormPluginDynamic",
                                             &field_collection);
    PADDLE_ENFORCE_NE(
        plugin_obj,
        nullptr,
        platform::errors::InvalidArgument(
            "fail to create CustomSkipLayerNormPluginDynamic plugin"));

    auto* plugin_layer = engine_->network()->addPluginV2(
        inputs.data(), static_cast<int>(inputs.size()), *plugin_obj);
    PADDLE_ENFORCE_NE(
        plugin_layer,
        nullptr,
        platform::errors::InvalidArgument(
            "fail to add CustomSkipLayerNormPluginDynamic layer"));
    return plugin_layer;
  }
#endif
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// Registers SkipLayerNormOpConverter as the TensorRT converter for the
// "skip_layernorm" op type.
REGISTER_TRT_OP_CONVERTER(skip_layernorm, SkipLayerNormOpConverter);