提交 2ca3fe5d 编写于 作者: Z zlsh80826

multihead att plugin

上级 954ebda1
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
namespace paddle { namespace paddle {
...@@ -30,7 +31,6 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -30,7 +31,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
// Declare inputs // Declare inputs
// Shouble be a 5 dims tensor. // Shouble be a 5 dims tensor.
auto* input = engine_->GetITensor(op_desc.Input("Input").front()); auto* input = engine_->GetITensor(op_desc.Input("Input").front());
auto* input_bias_qk = engine_->GetITensor(op_desc.Input("BiasQK").front());
// fc weights and fc bias // fc weights and fc bias
auto weight_name = op_desc.Input("W").front(); auto weight_name = op_desc.Input("W").front();
...@@ -65,14 +65,124 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -65,14 +65,124 @@ class MultiheadMatMulOpConverter : public OpConverter {
} }
} }
}; };
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number"));
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#ifdef USE_NVINFER_PLUGIN
int head_size = hidden / head_number;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
auto transpose_weight_v2 = [](const float* src, float* dst, int N,
int H) {
const int HNH = H * N * H;
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int hnh = 0; hnh < HNH; ++hnh) {
dst[n * 3 * HNH + i * HNH + hnh] =
src[i * N * HNH + n * HNH + hnh];
}
}
}
};
// [3, N, H] -> [N, 3, H]
auto transpose_bias_v2 = [](const float* src, float* dst, int N, int H) {
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number,
head_size);
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
std::vector<float> bias_data_tmp;
bias_data_tmp.reserve(bias_t->numel());
memcpy(bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float));
transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number,
head_size);
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
nvinfer1::Permutation permutation{1, 0, 2, 3, 4};
auto trans_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
trans_layer->setFirstTranspose(permutation);
auto* fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *trans_layer->getOutput(0), n, weight, bias);
auto pos_tensor = engine_->GetITensor("eval_placeholder_2");
plugin::CastIntPluginDynamic* cast_plugin =
new plugin::CastIntPluginDynamic();
auto cast_layer = engine_->AddPluginV2(&pos_tensor, 1, cast_plugin);
auto casted_pos_tensor = cast_layer->getOutput(0);
auto reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *casted_pos_tensor);
nvinfer1::Dims2 reshape_dim(0, 0);
nvinfer1::Permutation perm{1, 0, 2};
reshape_layer->setFirstTranspose(perm);
reshape_layer->setReshapeDimensions(reshape_dim);
auto reduce_layer =
TRT_ENGINE_ADD_LAYER(engine_, Reduce, *reshape_layer->getOutput(0),
nvinfer1::ReduceOperation::kMAX, 1, false);
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "1");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
bool has_mask = true;
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32,
1}, // no bool type
};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomQKVToContextPluginDynamic", pluginPtr);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(fc_layer->getOutput(0));
plugin_inputs.push_back(reduce_layer->getOutput(0));
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *pluginObj);
assert(plugin_layer != nullptr);
auto trans_r_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
assert(trans_r_layer != nullptr);
trans_r_layer->setFirstTranspose(permutation);
layer = trans_r_layer;
#else
// transpose weight_data from m * n to n * m // transpose weight_data from m * n to n * m
tranpose_weight(weight_data_tmp.data(), weight_data, m, n); auto* input_bias_qk =
engine_->GetITensor(op_desc.Input("BiasQK").front());
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
static_cast<size_t>(weight_t->numel())}; static_cast<size_t>(weight_t->numel())};
weight.dims.assign({n, m}); weight.dims.assign({n, m});
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data), static_cast<void*>(bias_data),
static_cast<size_t>(bias_t->numel())}; static_cast<size_t>(bias_t->numel())};
...@@ -81,20 +191,18 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -81,20 +191,18 @@ class MultiheadMatMulOpConverter : public OpConverter {
weight.get(), bias.get()); weight.get(), bias.get());
auto* fc_out = fc_layer->getOutput(0); auto* fc_out = fc_layer->getOutput(0);
// add qkv to context // add qkv to context
int head_number = BOOST_GET_CONST(int, op_desc.GetAttr("head_number"));
int head_size = all_head_size / head_number; int head_size = all_head_size / head_number;
float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(fc_out); plugin_inputs.push_back(fc_out);
plugin_inputs.push_back(input_bias_qk); plugin_inputs.push_back(input_bias_qk);
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
bool ban_fp16 = engine_->disable_trt_plugin_fp16(); bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::DynamicPluginTensorRT* plugin = plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden, head_number, head_size, new plugin::QkvToContextPluginDynamic(hidden, head_number, head_size,
scale, ban_fp16); scale, ban_fp16);
layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin); layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
#endif
} else { } else {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which " "You are running the Ernie(Bert) model in static shape mode, which "
......
...@@ -173,6 +173,8 @@ class OpConverter { ...@@ -173,6 +173,8 @@ class OpConverter {
"optim_input_shape should be same.")); "optim_input_shape should be same."));
} }
} }
std::cerr << "Declare input: " << input << std::endl;
if (input.find("stack_0.tmp_0") != std::string::npos) continue;
engine->DeclareInput( engine->DeclareInput(
input, FluidDataType2TRT( input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()), var->Proto()->type().lod_tensor().tensor().data_type()),
......
...@@ -23,8 +23,9 @@ class SliceOpConverter : public OpConverter { ...@@ -23,8 +23,9 @@ class SliceOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
// This OP is implemented by trt dynamic shpae plugin. // This OP is implemented by trt dynamic shpae plugin.
// Dynamic shape plugin requires TRT version greater than 6.0. // Dynamic shape plugin requires TRT version greater than 6.0.
std::cerr << "slice op converter\n" << std::endl;
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert slice op to tensorrt layer"; VLOG(4) << "convert slice op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
...@@ -42,7 +43,7 @@ class SliceOpConverter : public OpConverter { ...@@ -42,7 +43,7 @@ class SliceOpConverter : public OpConverter {
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
bool ban_fp16 = engine_->disable_trt_plugin_fp16(); bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SlicePluginDynamic* plugin = plugin::SlicePluginDynamic* plugin =
new plugin::SlicePluginDynamic(starts, ends, ends, ban_fp16); new plugin::SlicePluginDynamic(starts, ends, axes, ban_fp16);
layer = engine_->AddPluginV2(&input, 1, plugin); layer = engine_->AddPluginV2(&input, 1, plugin);
} else { } else {
PADDLE_THROW(platform::errors::Fatal( PADDLE_THROW(platform::errors::Fatal(
......
...@@ -183,6 +183,8 @@ class TRTConvertValidation { ...@@ -183,6 +183,8 @@ class TRTConvertValidation {
std::vector<void*> buffers(num_bindings); std::vector<void*> buffers(num_bindings);
for (const std::string& name : input_output_names) { for (const std::string& name : input_output_names) {
// std::cerr << "Binding name: " << name << std::endl;
if (name.find("stack_0.tmp_0") != std::string::npos) continue;
auto* var = scope_.FindVar(name); auto* var = scope_.FindVar(name);
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
const int bind_index = engine_->engine()->getBindingIndex(name.c_str()); const int bind_index = engine_->engine()->getBindingIndex(name.c_str());
......
...@@ -71,6 +71,7 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -71,6 +71,7 @@ void TensorRTEngine::FreezeNetwork() {
// build engine. // build engine.
infer_builder_->setMaxBatchSize(max_batch_); infer_builder_->setMaxBatchSize(max_batch_);
infer_builder_->setMaxWorkspaceSize(max_workspace_); infer_builder_->setMaxWorkspaceSize(max_workspace_);
infer_builder_config_->setMaxWorkspaceSize(max_workspace_);
bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf); bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
#if IS_TRT_VERSION_GE(5000) #if IS_TRT_VERSION_GE(5000)
if (enable_fp16) { if (enable_fp16) {
......
...@@ -85,6 +85,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -85,6 +85,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"gelu", "gelu",
"layer_norm", "layer_norm",
"scale", "scale",
"slice",
}; };
}; };
......
...@@ -2,6 +2,7 @@ nv_library(tensorrt_plugin ...@@ -2,6 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
cast_int_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor) DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/cast_int_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
nvinfer1::DimsExprs CastIntPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
assert(output_index == 0);
return inputs[0];
}
bool CastIntPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
const nvinfer1::PluginTensorDesc& in = in_out[pos];
return (in.type == nvinfer1::DataType::kINT32);
}
nvinfer1::DataType CastIntPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0, platform::errors::InvalidArgument(
"The Cast Int only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[index];
}
__global__ void castIntKernel(const int64_t* input, int32_t* output,
size_t num_elements) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_elements) return;
output[idx] = input[idx] + 1;
}
int CastIntPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs, void* workspace,
cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
auto output_dims = output_desc[0].dims;
size_t num_elements = ProductDim(input_dims);
size_t out_num_elements = ProductDim(output_dims);
assert(input_type ==
nvinfer1::DataType::kINT32); // although the input is int64_t
assert(num_elements == out_num_elements);
const size_t num_threads = 256;
castIntKernel<<<num_elements / num_threads + 1, num_threads>>>(
static_cast<const int64_t*>(inputs[0]), static_cast<int32_t*>(outputs[0]),
num_elements);
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class CastIntPluginDynamic : public DynamicPluginTensorRT {
public:
CastIntPluginDynamic() {}
CastIntPluginDynamic(void const* serial_data, size_t serial_length) {}
~CastIntPluginDynamic() {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new CastIntPluginDynamic();
}
const char* getPluginType() const override { return "cast_int_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
size_t getSerializationSize() const override { return 0; }
void serialize(void* buffer) const override {}
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs, int nb_outputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const override;
void destroy() override { delete this; }
};
class CastIntPluginV2Creator : public nvinfer1::IPluginCreator {
public:
CastIntPluginV2Creator() {}
const char* getPluginName() const override { return "cast_int_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new CastIntPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(CastIntPluginV2Creator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -221,11 +221,14 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -221,11 +221,14 @@ class TensorRTEngineOp : public framework::OperatorBase {
num_inputs += 1; num_inputs += 1;
} }
const int num_bindings = num_inputs + Outputs("Ys").size(); const int num_bindings = num_inputs + Outputs("Ys").size();
// std::cerr << "num bindings: " << num_bindings << std::endl;
std::vector<void *> buffers(num_bindings); std::vector<void *> buffers(num_bindings);
// Bind input tensor to TRT. // Bind input tensor to TRT.
for (const auto &x : Inputs("Xs")) { for (const auto &x : Inputs("Xs")) {
if (param_names_.count(x)) continue; if (param_names_.count(x)) continue;
// std::cerr << "runTRT name: " << x << std::endl;
if (x.find("stack_0.tmp_0") != std::string::npos) continue;
// convert input and copy to TRT engine's buffer // convert input and copy to TRT engine's buffer
auto &t = auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x); inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册