Unverified commit 430b0099, authored by Zhaolong Xing, committed by GitHub

[Paddle-TRT]: Ernie Dynamic shape support. (#23138)

* add dynamic plugin support.
test=develop

* change emb eltwise layernorm to math function
test=develop

* add emb eltwise layernorm
test=develop

* can run dynamic shape ernie
test=develop

* fix ci
test=develop

* add ut for trt ernie dynamic

test=develop

* refine dynamic shape c++ interface.
test=develop

* fix comments
test=develop

* fix comments
test=develop
Parent d0413e58
......@@ -101,8 +101,10 @@ function(select_nvcc_arch_flags out_variable)
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
set(cuda_arch_bin "60 61")
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
add_definitions("-DSUPPORTS_CUDA_FP16")
set(cuda_arch_bin "70")
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
add_definitions("-DSUPPORTS_CUDA_FP16")
set(cuda_arch_bin "75")
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
......
......@@ -176,10 +176,14 @@ struct Argument {
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
// usually use for trt dynamic shape.
// Usually used for TRT dynamic shape.
// TRT will select the best kernel according to the opt shape.
// Setting disable_trt_plugin_fp16 to true means that the TRT plugin will not
// run in fp16.
DECL_ARGUMENT_FIELD(min_input_shape, MinInputShape, input_shape_t);
DECL_ARGUMENT_FIELD(max_input_shape, MaxInputShape, input_shape_t);
DECL_ARGUMENT_FIELD(optim_input_shape, OptimInputShape, input_shape_t);
DECL_ARGUMENT_FIELD(disable_trt_plugin_fp16, CloseTrtPluginFp16, bool);
DECL_ARGUMENT_FIELD(use_tensorrt, UseTensorRT, bool);
DECL_ARGUMENT_FIELD(tensorrt_max_batch_size, TensorRtMaxBatchSize, int);
......
......@@ -130,6 +130,11 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("optim_input_shape",
new std::map<std::string, std::vector<int>>(
argument->optim_input_shape()));
// Setting disable_trt_plugin_fp16 to true means that the TRT plugin will
// not run in fp16.
pass->Set("disable_trt_plugin_fp16",
new bool(argument->disable_trt_plugin_fp16()));
}
if (pass_name == "ngraph_subgraph_pass") {
pass->Set("program",
......
......@@ -272,7 +272,6 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
// Check trt version for dynamic shape input.
if (min_input_shape.size() > 0 && TRT_VERSION < 6000) {
std::cout << "hello";
LOG_FIRST_N(WARNING, 1) << "You are using the dynamic size input mode of "
"Paddle-TRT, but we found that the version of "
"the TensorRT is less than 6.0, so we use the "
......@@ -284,18 +283,23 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
if (min_input_shape.size() > 0 && TRT_VERSION > 6000) {
LOG_FIRST_N(WARNING, 1)
<< "The Paddle lib links the " << TRT_VERSION / 1000.
<< " version TensorRT, "
<< "The Paddle lib links the " << TRT_VERSION << " version TensorRT, "
<< "make sure the runtime TensorRT you are using is no less than this "
"version, otherwise, there might be Segfault!";
}
// Setting disable_trt_plugin_fp16 to true means that the TRT plugin will not
// run in fp16. Running the plugin in fp16 may affect the output accuracy of
// the model, so disabling plugin fp16 may bring some improvement in accuracy.
bool disable_trt_plugin_fp16 = Get<bool>("disable_trt_plugin_fp16");
tensorrt::TensorRTEngine *trt_engine =
inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
.Create(engine_key + std::to_string(predictor_id),
Get<int>("max_batch_size"), Get<int>("workspace_size"),
precision_mode, calibrator.get(), Get<int>("gpu_device_id"),
min_input_shape, max_input_shape, opt_input_shape);
min_input_shape, max_input_shape, opt_input_shape,
disable_trt_plugin_fp16);
bool need_serialize = (use_static_engine && !load_from_memory);
if (need_serialize) {
......
......@@ -128,6 +128,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(min_input_shape_);
CP_MEMBER(max_input_shape_);
CP_MEMBER(optim_input_shape_);
CP_MEMBER(disable_trt_plugin_fp16_);
CP_MEMBER(use_lite_);
CP_MEMBER(lite_precision_mode_);
......@@ -226,10 +227,7 @@ MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
void AnalysisConfig::EnableTensorRtEngine(
int workspace_size, int max_batch_size, int min_subgraph_size,
AnalysisConfig::Precision precision_mode, bool use_static,
bool use_calib_mode,
std::map<std::string, std::vector<int>> min_input_shape,
std::map<std::string, std::vector<int>> max_input_shape,
std::map<std::string, std::vector<int>> optim_input_shape) {
bool use_calib_mode) {
#ifdef PADDLE_WITH_CUDA
if (!use_gpu()) {
LOG(ERROR) << "To use TensorRT engine, please call EnableGpu() first";
......@@ -243,9 +241,6 @@ void AnalysisConfig::EnableTensorRtEngine(
tensorrt_precision_mode_ = precision_mode;
trt_use_static_engine_ = use_static;
trt_use_calib_mode_ = use_calib_mode;
min_input_shape_ = min_input_shape;
max_input_shape_ = max_input_shape;
optim_input_shape_ = optim_input_shape;
Update();
#else
......@@ -254,6 +249,17 @@ void AnalysisConfig::EnableTensorRtEngine(
#endif
}
void AnalysisConfig::SetTRTDynamicShapeInfo(
std::map<std::string, std::vector<int>> min_input_shape,
std::map<std::string, std::vector<int>> max_input_shape,
std::map<std::string, std::vector<int>> optim_input_shape,
bool disable_trt_plugin_fp16) {
min_input_shape_ = min_input_shape;
max_input_shape_ = max_input_shape;
optim_input_shape_ = optim_input_shape;
disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}
// TODO(Superjomn) refactor this, buggy.
void AnalysisConfig::Update() {
auto info = SerializeInfoCache();
......
......@@ -428,6 +428,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetMinInputShape(config_.min_input_shape_);
argument_.SetMaxInputShape(config_.max_input_shape_);
argument_.SetOptimInputShape(config_.optim_input_shape_);
argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
}
if (config_.lite_engine_enabled()) {
......@@ -951,4 +952,6 @@ USE_TRT_CONVERTER(instance_norm);
USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
USE_TRT_CONVERTER(skip_layernorm);
#endif
......@@ -222,16 +222,29 @@ struct AnalysisConfig {
* @param min_subgraph_size the minimum TensorRT subgraph size needed; if a
* subgraph is smaller than this, it will not be transferred to the TensorRT
* engine.
*/
void EnableTensorRtEngine(
int workspace_size = 1 << 20, int max_batch_size = 1,
int min_subgraph_size = 3, Precision precision = Precision::kFloat32,
bool use_static = false, bool use_calib_mode = true,
std::map<std::string, std::vector<int>> min_input_shape = {},
std::map<std::string, std::vector<int>> max_input_shape = {},
std::map<std::string, std::vector<int>> optim_input_shape = {});
void EnableTensorRtEngine(int workspace_size = 1 << 20,
int max_batch_size = 1, int min_subgraph_size = 3,
Precision precision = Precision::kFloat32,
bool use_static = false,
bool use_calib_mode = true);
/** A boolean state telling whether the TensorRT engine is used.
*/
bool tensorrt_engine_enabled() const { return use_tensorrt_; }
/**
* \brief Set min, max, opt shape for TensorRT Dynamic shape mode.
* @param min_input_shape the min input shape of the subgraph input
* @param max_input_shape the max input shape of the subgraph input
* @param opt_input_shape the opt input shape of the subgraph input
* @param disable_trt_plugin_fp16 setting this variable to true
* means that the TRT plugin will not run in fp16
*/
void SetTRTDynamicShapeInfo(
std::map<std::string, std::vector<int>> min_input_shape,
std::map<std::string, std::vector<int>> max_input_shape,
std::map<std::string, std::vector<int>> optim_input_shape,
bool disable_trt_plugin_fp16 = false);
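A minimal usage sketch of the refined interface above (the model directory "ernie_model_dir" and the input name "read_file_0.tmp_0" are placeholders, not names taken from this change; the shapes are illustrative only):
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"
std::unique_ptr<paddle::PaddlePredictor> BuildDynamicShapePredictor() {
  paddle::AnalysisConfig config;
  config.SetModel("ernie_model_dir");             // placeholder model dir
  config.EnableUseGpu(100 /*MB*/, 0 /*gpu_id*/);
  // Shape information is no longer passed to EnableTensorRtEngine.
  config.EnableTensorRtEngine(1 << 20 /*workspace_size*/, 1 /*max_batch_size*/,
                              5 /*min_subgraph_size*/,
                              paddle::AnalysisConfig::Precision::kHalf,
                              false /*use_static*/, false /*use_calib_mode*/);
  // Placeholder input name; use the real names of the TRT subgraph inputs.
  std::map<std::string, std::vector<int>> min_shape{{"read_file_0.tmp_0", {1, 1, 1}}};
  std::map<std::string, std::vector<int>> max_shape{{"read_file_0.tmp_0", {1, 128, 1}}};
  std::map<std::string, std::vector<int>> opt_shape{{"read_file_0.tmp_0", {1, 64, 1}}};
  config.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape,
                                false /*disable_trt_plugin_fp16*/);
  return paddle::CreatePaddlePredictor(config);
}
Passing false as disable_trt_plugin_fp16 keeps the plugin fp16 path available when the precision is kHalf; pass true to trade some speed for accuracy.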
/**
* \brief Turn on the usage of Lite sub-graph engine.
*/
......@@ -386,6 +399,10 @@ struct AnalysisConfig {
Precision tensorrt_precision_mode_;
bool trt_use_static_engine_;
bool trt_use_calib_mode_;
std::map<std::string, std::vector<int>> min_input_shape_{};
std::map<std::string, std::vector<int>> max_input_shape_{};
std::map<std::string, std::vector<int>> optim_input_shape_{};
bool disable_trt_plugin_fp16_{false};
// memory reuse related.
bool enable_memory_optim_{false};
......@@ -412,9 +429,6 @@ struct AnalysisConfig {
std::string serialized_info_cache_;
mutable std::unique_ptr<PassStrategy> pass_builder_;
std::map<std::string, std::vector<int>> min_input_shape_;
std::map<std::string, std::vector<int>> max_input_shape_;
std::map<std::string, std::vector<int>> optim_input_shape_;
bool use_lite_{false};
std::vector<std::string> lite_passes_filter_;
......
......@@ -77,12 +77,14 @@ const std::vector<std::string> kTRTSubgraphPasses({
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass", //
"multihead_matmul_fuse_pass", //
"conv_bn_fuse_pass", //
"fc_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
"multihead_matmul_fuse_pass_v2", //
"skip_layernorm_fuse_pass", //
"conv_bn_fuse_pass", //
"fc_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
"conv_elementwise_add_act_fuse_pass", //
......
......@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
SRCS mul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc
pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
nv_test(test_op_converter SRCS test_op_converter.cc DEPS
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class EmbEltwiseLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fluid swish op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
auto id_names = op_desc.Input("Ids");
auto emb_names = op_desc.Input("Embs");
PADDLE_ENFORCE_EQ(id_names.size(), emb_names.size(),
platform::errors::InvalidArgument(
"The id and emb size of fused EmbEltwiseLayerNormOp "
"should be same "));
int input_num = id_names.size();
// Declare inputs
std::vector<nvinfer1::ITensor*> input_ids;
for (int i = 0; i < input_num; i++) {
input_ids.push_back(engine_->GetITensor(id_names[i]));
}
std::vector<float*> input_embs;
std::vector<int> emb_sizes;
// get the persistable var's data
auto get_persistable_data = [&](const std::string& var_name,
framework::DDim* dims) -> float* {
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
return temp_data;
};
int hidden = 0;
for (int i = 0; i < input_num; i++) {
framework::DDim emb_dims;
float* emb_data = get_persistable_data(emb_names[i], &emb_dims);
int64_t emb_size = framework::product(emb_dims);
input_embs.push_back(emb_data);
emb_sizes.push_back(emb_size);
PADDLE_ENFORCE_EQ(
emb_dims.size(), 2,
platform::errors::InvalidArgument(
"The fused EmbEltwiseLayerNorm's emb should be 2 dims."));
hidden = emb_dims[1];
}
framework::DDim bias_dims, scale_dims;
auto* bias =
get_persistable_data(op_desc.Input("Bias").front(), &bias_dims);
auto* scale =
get_persistable_data(op_desc.Input("Scale").front(), &scale_dims);
int64_t bias_size = framework::product(bias_dims);
int64_t scale_size = framework::product(scale_dims);
float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::EmbEltwiseLayernormPluginDynamic* plugin =
new plugin::EmbEltwiseLayernormPluginDynamic(input_embs, bias, scale,
emb_sizes, bias_size,
scale_size, hidden, eps);
layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm", {output_name},
test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(fused_embedding_eltwise_layernorm,
EmbEltwiseLayerNormOpConverter);
......@@ -18,32 +18,6 @@ namespace paddle {
namespace inference {
namespace tensorrt {
// Reorder the elements from istrides to ostrides, borrowed from TRT convert in
// tensorflow.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/tensorrt/convert/convert_nodes.cc#L318
template <typename T>
void Reorder2(nvinfer1::DimsHW shape, const T* idata, nvinfer1::DimsHW istrides,
T* odata, nvinfer1::DimsHW ostrides) {
for (int h = 0; h < shape.h(); ++h) {
for (int w = 0; w < shape.w(); ++w) {
odata[h * ostrides.h() + w * ostrides.w()] =
idata[h * istrides.h() + w * istrides.w()];
}
}
}
// indata c * k
// Reorder the data layout from CK to KC.
void ReorderCKtoKC(TensorRTEngine::Weight& iweights, // NOLINT
TensorRTEngine::Weight* oweights) {
int c = iweights.dims[0];
int k = iweights.dims[1];
oweights->dims.assign({k, c});
nvinfer1::DimsHW istrides = {1, k};
nvinfer1::DimsHW ostrides = {c, 1};
Reorder2({k, c}, static_cast<float const*>(iweights.get().values), istrides,
static_cast<float*>(const_cast<void*>(oweights->get().values)),
ostrides);
}
/*
* FC converter convert a MUL op in Fluid to a FC layer in TRT.
*/
......@@ -64,7 +38,6 @@ class FcOpConverter : public OpConverter {
}
// Declare inputs
auto* X = engine_->GetITensor(op_desc.Input(i_name).front());
// Declare weights
auto* Y_v = scope.FindVar(op_desc.Input(w_name).front());
PADDLE_ENFORCE_NOT_NULL(Y_v);
......@@ -101,28 +74,44 @@ class FcOpConverter : public OpConverter {
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL); // a matrix
size_t n_output = Y_t->dims()[1];
std::unique_ptr<framework::Tensor> tmp(new framework::LoDTensor());
tmp->Resize(Y_t->dims());
int m = Y_t->dims()[0];
int n = Y_t->dims()[1];
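// Transpose the weight from (m, n) = (in, out) to (n, m) = (out, in),
// which is the layout TRT's FullyConnected layer expects (see
// weight.dims.assign({n, m}) below).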
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
dst[j * m + i] = src[i * n + j];
}
}
};
auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output,
TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) {
auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs,
n_output, weight.get(), bias.get());
auto output_name = op_desc.Output("Out").front();
if (activation_type == "relu") {
nvinfer1::IActivationLayer* relu_layer =
TRT_ENGINE_ADD_LAYER(engine_, Activation, *(fc_layer->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer, "fc", {output_name}, test_mode);
} else {
RreplenishLayerAndOutput(fc_layer, "fc", {output_name}, test_mode);
}
};
std::vector<float> weight_data_tmp;
weight_data_tmp.reserve(Y_t->numel());
memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float));
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
memcpy(tmp->mutable_data<float>(platform::CPUPlace()), weight_data,
Y_t->dims()[0] * Y_t->dims()[1] * sizeof(float));
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(Y_t->numel())};
TensorRTEngine::Weight tmp_weight(nvinfer1::DataType::kFLOAT,
static_cast<void*>(tmp->data<float>()),
static_cast<size_t>(Y_t->numel()));
weight.dims.assign({Y_t->dims()[0], Y_t->dims()[1]});
tmp_weight.dims = weight.dims;
// The data layout of TRT FC layer's weight is different from fluid's FC,
// need to reorder the elements.
ReorderCKtoKC(weight, &tmp_weight);
// Currently, the framework can only handle one fluid op -> one TRT layer,
// but fc fuses `mul` and `bias` (2 fluid ops), so here is a trick, just
// handle `mul`, leave `add` as another layer.
// DEBUG
weight.dims.assign({n, m});
float* bias_data = nullptr;
int bias_num = 0;
if (with_bias) {
......@@ -136,6 +125,10 @@ class FcOpConverter : public OpConverter {
static_cast<void*>(bias_data),
static_cast<size_t>(bias_num)};
if (engine_->with_dynamic_shape()) {
regist_fc(X, n_output, weight, bias);
return;
}
// In order to handle situations in NLP models (input dims < 3,
// x_num_col_dims != 1, etc.), reshape the input to perform FC correctly.
auto* reshape_itensor = X;
......@@ -192,20 +185,7 @@ class FcOpConverter : public OpConverter {
reshape_layer->setReshapeDimensions(reshape_dim);
reshape_itensor = reshape_layer->getOutput(0);
}
auto* fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *reshape_itensor,
n_output, tmp_weight.get(), bias.get());
engine_->SetWeights(op_desc.Input(w_name).front(), std::move(tmp));
auto output_name = op_desc.Output("Out").front();
if (activation_type == "relu") {
nvinfer1::IActivationLayer* relu_layer =
TRT_ENGINE_ADD_LAYER(engine_, Activation, *(fc_layer->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer, "fc", {output_name}, test_mode);
} else {
RreplenishLayerAndOutput(fc_layer, "fc", {output_name}, test_mode);
}
regist_fc(reshape_itensor, n_output, weight, bias);
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
namespace paddle {
namespace inference {
......@@ -22,187 +23,93 @@ class MultiheadMatMulOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt "
"network structure";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* Q = engine_->GetITensor(op_desc.Input("Q").front());
auto* K = engine_->GetITensor(op_desc.Input("K").front());
auto* V = engine_->GetITensor(op_desc.Input("V").front());
auto* BiasQ = scope.FindVar(op_desc.Input("BiasQ").front());
auto* BiasK = scope.FindVar(op_desc.Input("BiasK").front());
auto* BiasV = scope.FindVar(op_desc.Input("BiasV").front());
auto* BiasQK = engine_->GetITensor(op_desc.Input("BiasQK").front());
PADDLE_ENFORCE_EQ(op_desc.Input("Q").size(), 1,
platform::errors::InvalidArgument(
"size of input Q of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(op_desc.Input("K").size(), 1,
platform::errors::InvalidArgument(
"size of input K of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(op_desc.Input("V").size(), 1,
platform::errors::InvalidArgument(
"size of input V of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(
op_desc.Input("BiasQK").size(), 1,
platform::errors::InvalidArgument(
"size of input BiasQK of multihead_matmul should be 1"));
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1,
platform::errors::InvalidArgument(
"size of output of multihead_matmul should be 1"));
PADDLE_ENFORCE_NOT_NULL(
BiasQ, platform::errors::InvalidArgument(
"param BiasQ of multihead_matmul should not be null"));
PADDLE_ENFORCE_NOT_NULL(
BiasK, platform::errors::InvalidArgument(
"param BiasK of multihead_matmul should not be null"));
PADDLE_ENFORCE_NOT_NULL(
BiasV, platform::errors::InvalidArgument(
"param BiasV of multihead_matmul should not be null"));
PADDLE_ENFORCE_EQ(
BiasQK->getDimensions().nbDims, 3,
platform::errors::InvalidArgument(
"dims size of input BiasQK of multihead_matmul should be 3"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("alpha"), true,
platform::errors::PreconditionNotMet(
"attribute alpha of multihead_matmul should not be empty"));
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("head_number"), true,
platform::errors::PreconditionNotMet(
"attribute head_number of multihead_matmul should not be empty"));
// Declare attributes
const bool transpose_q =
op_desc.HasAttr("transpose_Q")
? boost::get<bool>(op_desc.GetAttr("transpose_Q"))
: false;
const bool transpose_k =
op_desc.HasAttr("transpose_K")
? boost::get<bool>(op_desc.GetAttr("transpose_K"))
: true;
const bool transpose_v =
op_desc.HasAttr("transpose_V")
? boost::get<bool>(op_desc.GetAttr("transpose_V"))
: false;
const float alpha = boost::get<float>(op_desc.GetAttr("alpha"));
const int head_number = boost::get<int>(op_desc.GetAttr("head_number"));
nvinfer1::Dims q_shape = Q->getDimensions();
int seq_len = q_shape.d[0];
int size_per_head = q_shape.d[1] / head_number;
std::string alpha_name = op_desc.Output("Out")[0] + "_alpha";
framework::DDim alpha_dim = framework::make_ddim({1});
std::unique_ptr<framework::LoDTensor> alpha_t(new framework::LoDTensor());
alpha_t->Resize(alpha_dim);
float* alpha_data = alpha_t->mutable_data<float>(platform::CPUPlace());
alpha_data[0] = alpha;
TensorRTEngine::Weight scale{nvinfer1::DataType::kFLOAT,
static_cast<void*>(alpha_data), 1};
TensorRTEngine::Weight shift{nvinfer1::DataType::kFLOAT, nullptr, 0};
TensorRTEngine::Weight power{nvinfer1::DataType::kFLOAT, nullptr, 0};
auto* bias_q_t = BiasQ->GetMutable<framework::LoDTensor>();
auto* bias_k_t = BiasK->GetMutable<framework::LoDTensor>();
auto* bias_v_t = BiasV->GetMutable<framework::LoDTensor>();
float* bias_q_cpu_data = engine_->GetWeightCPUData(
op_desc.Input("BiasQ").front(), bias_q_t, false);
float* bias_k_cpu_data = engine_->GetWeightCPUData(
op_desc.Input("BiasK").front(), bias_k_t, false);
float* bias_v_cpu_data = engine_->GetWeightCPUData(
op_desc.Input("BiasV").front(), bias_v_t, false);
std::unique_ptr<framework::LoDTensor> bias_q_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> bias_k_tensor(
new framework::LoDTensor());
std::unique_ptr<framework::LoDTensor> bias_v_tensor(
new framework::LoDTensor());
bias_q_tensor->Resize(bias_q_t->dims());
bias_k_tensor->Resize(bias_k_t->dims());
bias_v_tensor->Resize(bias_v_t->dims());
platform::CPUPlace cpu_place;
TensorCopySync((*bias_q_t), cpu_place, bias_q_tensor.get());
TensorCopySync((*bias_k_t), cpu_place, bias_k_tensor.get());
TensorCopySync((*bias_v_t), cpu_place, bias_v_tensor.get());
TensorRTEngine::Weight scale_weights_q{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight shift_weights_q{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_q_cpu_data),
bias_q_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights_q{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight scale_weights_k{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight shift_weights_k{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_k_cpu_data),
bias_k_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights_k{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight scale_weights_v{nvinfer1::DataType::kFLOAT, nullptr,
0};
TensorRTEngine::Weight shift_weights_v{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_v_cpu_data),
bias_v_tensor->memory_size() / sizeof(float)};
TensorRTEngine::Weight power_weights_v{nvinfer1::DataType::kFLOAT, nullptr,
0};
auto* q_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *Q, nvinfer1::ScaleMode::kCHANNEL,
shift_weights_q.get(), scale_weights_q.get(), power_weights_q.get());
auto* k_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *K, nvinfer1::ScaleMode::kCHANNEL,
shift_weights_k.get(), scale_weights_k.get(), power_weights_k.get());
auto* v_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *V, nvinfer1::ScaleMode::kCHANNEL,
shift_weights_v.get(), scale_weights_v.get(), power_weights_v.get());
auto* v_transpose_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(v_eltadd_layer->getOutput(0)));
auto* q_transpose_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(q_eltadd_layer->getOutput(0)));
auto* k_transpose_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(k_eltadd_layer->getOutput(0)));
nvinfer1::Dims3 head_reshape_dim(seq_len, head_number, size_per_head);
v_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
v_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
q_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
q_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
k_transpose_reshape_layer->setReshapeDimensions(head_reshape_dim);
k_transpose_reshape_layer->setSecondTranspose({1, 0, 2});
auto* q_scale_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *(q_transpose_reshape_layer->getOutput(0)),
nvinfer1::ScaleMode::kUNIFORM, shift.get(), scale.get(), power.get());
auto* qk_matmul_layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *(q_scale_layer->getOutput(0)), transpose_q,
*(k_transpose_reshape_layer->getOutput(0)), transpose_k);
auto* qk_eltadd_layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *BiasQK, *(qk_matmul_layer->getOutput(0)),
nvinfer1::ElementWiseOperation::kSUM);
auto* softmax_layer = TRT_ENGINE_ADD_LAYER(
engine_, SoftMax, *(qk_eltadd_layer->getOutput(0)));
softmax_layer->setAxes(4);
auto* qkv_matmul_layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *(softmax_layer->getOutput(0)), false,
*(v_transpose_reshape_layer->getOutput(0)), transpose_v);
auto* qkv_transpose_reshape_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *(qkv_matmul_layer->getOutput(0)));
nvinfer1::Dims2 qkv_reshape_dim(seq_len, head_number * size_per_head);
qkv_transpose_reshape_layer->setFirstTranspose({1, 0, 2});
qkv_transpose_reshape_layer->setReshapeDimensions(qkv_reshape_dim);
engine_->SetWeights(alpha_name, std::move(alpha_t));
engine_->SetWeights(op_desc.Input("BiasQ").front(),
std::move(bias_q_tensor));
engine_->SetWeights(op_desc.Input("BiasK").front(),
std::move(bias_k_tensor));
engine_->SetWeights(op_desc.Input("BiasV").front(),
std::move(bias_v_tensor));
auto output_name = op_desc.Output("Out").front();
RreplenishLayerAndOutput(qkv_transpose_reshape_layer, "multihead_matmul",
{output_name}, test_mode);
// Should be a 5-dim tensor.
auto* input = engine_->GetITensor(op_desc.Input("Input").front());
auto* input_bias_qk = engine_->GetITensor(op_desc.Input("BiasQK").front());
// fc weights and fc bias
auto weight_name = op_desc.Input("W").front();
auto bias_name = op_desc.Input("Bias").front();
auto* weight_v = scope.FindVar(weight_name);
auto* weight_t = weight_v->GetMutable<framework::LoDTensor>();
auto* bias_v = scope.FindVar(bias_name);
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* weight_data =
engine_->GetWeightCPUData(weight_name, weight_t, false);
float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false);
std::vector<float> weight_data_tmp;
weight_data_tmp.reserve(weight_t->numel());
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
// (hidden, 3, all_head_size)
auto weight_dims = weight_t->dims();
int hidden = weight_dims[0]; // channels_in
int three = weight_dims[1]; // channels_out
int all_head_size = weight_dims[2]; // channels_out
int m = hidden;
int n = three * all_head_size;
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
dst[j * m + i] = src[i * n + j];
}
}
};
// transpose weight_data from m * n to n * m
tranpose_weight(weight_data_tmp.data(), weight_data, m, n);
TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(weight_t->numel())};
weight.dims.assign({n, m});
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<size_t>(bias_t->numel())};
auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
weight.get(), bias.get());
auto* fc_out = fc_layer->getOutput(0);
// add qkv to context
int head_number = boost::get<int>(op_desc.GetAttr("head_number"));
int head_size = all_head_size / head_number;
float scale = boost::get<float>(op_desc.GetAttr("alpha"));
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.push_back(fc_out);
plugin_inputs.push_back(input_bias_qk);
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden, head_number, head_size,
scale, ban_fp16);
layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which "
"is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name},
test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
}
};
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class SkipLayerNormOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
#if IS_TRT_VERSION_GE(6000)
VLOG(4) << "convert fused skip layernorm op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
std::vector<nvinfer1::ITensor*> inputs;
inputs.push_back(input1);
inputs.push_back(input2);
auto get_persistable_data = [&](const std::string& arg_name,
framework::DDim* dims) -> float* {
std::string var_name = op_desc.Input(arg_name).front();
auto* temp_var = scope.FindVar(var_name);
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
return temp_data;
};
framework::DDim bias_dims, scale_dims;
auto* bias = get_persistable_data("Bias", &bias_dims);
auto* scale = get_persistable_data("Scale", &scale_dims);
float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
int bias_size = framework::product(bias_dims);
int scale_size = framework::product(scale_dims);
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SkipLayerNormPluginDynamic* plugin =
new plugin::SkipLayerNormPluginDynamic(bias, scale, bias_size,
scale_size, eps, ban_fp16);
layer = engine_->AddPluginV2(inputs.data(), 2, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static"
"shape mode, which is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface"
" to set the shape information to run the dynamic shape mode."));
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "skip_layernorm", {output_name}, test_mode);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(skip_layernorm, SkipLayerNormOpConverter);
......@@ -160,6 +160,16 @@ void TensorRTEngine::FreezeNetwork() {
Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true));
}
infer_builder_config_->addOptimizationProfile(optim_profile_.get());
if (WithFp16()) {
infer_builder_config_->setFlag(nvinfer1::BuilderFlag::kFP16);
if (disable_trt_plugin_fp16()) {
LOG(INFO) << "NOTE: In order to achieve higher accuracy, you have "
"disabled the fp16 mode of TRT Plugin,\n"
<< "you can reopen it with "
"'config.SetDynamicShapeInfo(min_shape, max_shape, "
"opt_shape, false /*disable_trt_plugin_fp16*/)'";
}
}
infer_engine_.reset(infer_builder_->buildEngineWithConfig(
*network(), *infer_builder_config_));
#endif
......
......@@ -124,6 +124,7 @@ class TensorRTEngine {
const ShapeMapType min_input_shape = {},
const ShapeMapType max_input_shape = {},
const ShapeMapType optim_input_shape = {},
bool disable_trt_plugin_fp16 = false,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
......@@ -133,6 +134,7 @@ class TensorRTEngine {
min_input_shape_(min_input_shape),
max_input_shape_(max_input_shape),
optim_input_shape_(optim_input_shape),
disable_trt_plugin_fp16_(disable_trt_plugin_fp16),
logger_(logger) {
if (min_input_shape_.size() != 0 && max_input_shape_.size() != 0 &&
optim_input_shape_.size() != 0) {
......@@ -207,6 +209,13 @@ class TensorRTEngine {
void SetRuntimeBatch(size_t batch_size);
int GetRuntimeBatch();
bool WithFp16() {
bool enable_fp16 = (precision_ == AnalysisConfig::Precision::kHalf);
bool support_fp16 = infer_builder_->platformHasFastFp16();
return enable_fp16 && support_fp16;
}
int GetDeviceId() { return device_id_; }
nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
int num_inputs, plugin::PluginTensorRT*);
......@@ -264,9 +273,18 @@ class TensorRTEngine {
ShapeMapType min_input_shape() { return min_input_shape_; }
ShapeMapType max_input_shape() { return max_input_shape_; }
ShapeMapType optim_input_shape() { return optim_input_shape_; }
bool disable_trt_plugin_fp16() { return disable_trt_plugin_fp16_; }
bool with_dynamic_shape() { return with_dynamic_shape_; }
#if IS_TRT_VERSION_GE(6000)
nvinfer1::IPluginV2Layer* AddPluginV2(nvinfer1::ITensor* const* inputs,
int num_inputs,
plugin::DynamicPluginTensorRT* plugin) {
owned_pluginv2_.emplace_back(plugin);
return network()->addPluginV2(inputs, num_inputs, *plugin);
}
#endif
private:
// Each ICudaEngine object is bound to a specific GPU when it is instantiated,
// ensure that the thread is associated with the correct device by calling
......@@ -289,6 +307,7 @@ class TensorRTEngine {
ShapeMapType min_input_shape_;
ShapeMapType max_input_shape_;
ShapeMapType optim_input_shape_;
bool disable_trt_plugin_fp16_{false};
nvinfer1::ILogger& logger_;
// max data size for the buffers.
......@@ -322,6 +341,7 @@ class TensorRTEngine {
#if IS_TRT_VERSION_GE(6000)
infer_ptr<nvinfer1::IBuilderConfig> infer_builder_config_;
std::unique_ptr<nvinfer1::IOptimizationProfile> optim_profile_;
std::vector<std::unique_ptr<plugin::DynamicPluginTensorRT>> owned_pluginv2_;
#endif
std::mutex mutex_;
}; // class TensorRTEngine
......@@ -358,10 +378,12 @@ class TRTEngineManager {
const std::map<std::string, std::vector<int>> min_input_shape = {},
const std::map<std::string, std::vector<int>> max_input_shape = {},
const std::map<std::string, std::vector<int>> optim_input_shape = {},
bool disable_trt_plugin_fp16 = false,
nvinfer1::ILogger& logger = NaiveLogger::Global()) {
auto* p = new TensorRTEngine(max_batch, max_workspace, precision,
calibrator, device_id, min_input_shape,
max_input_shape, optim_input_shape, logger);
auto* p =
new TensorRTEngine(max_batch, max_workspace, precision, calibrator,
device_id, min_input_shape, max_input_shape,
optim_input_shape, disable_trt_plugin_fp16, logger);
engines_[name].reset(p);
return p;
}
......
......@@ -23,6 +23,11 @@ struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller() {
#if IS_TRT_VERSION_GE(5130)
teller_set.insert("relu6");
#endif
#if IS_TRT_VERSION_GE(6000)
teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm");
#endif
}
......@@ -38,9 +43,11 @@ struct SimpleOpTypeSetTeller : public Teller {
private:
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{
{"mul", "conv2d", "pool2d", "relu", "depthwise_conv2d", "softmax",
"batch_norm", "elementwise_add", "leaky_relu", "fc"}};
std::unordered_set<std::string> teller_set{{
"mul", "conv2d", "pool2d",
"relu", "depthwise_conv2d", "softmax",
"batch_norm", "elementwise_add", "leaky_relu",
"fc"};
std::unordered_set<std::string> teller_set{
"mul",
"conv2d",
"pool2d",
......@@ -65,8 +72,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"instance_norm",
"gelu",
"layer_norm",
"multihead_matmul",
}};
};
};
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
......
nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu instance_norm_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor)
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
int EmbEltwiseLayernormPluginDynamic::initialize() {
// resize (not reserve) so that embs_gpu_[i] below refers to a valid element.
embs_gpu_.resize(embs_.size());
for (int i = 0; i < embs_.size(); i++) {
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
cudaMemcpy(embs_gpu_[i], embs_[i], emb_sizes_[i] * sizeof(float),
cudaMemcpyHostToDevice);
}
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
cudaMemcpy(bias_gpu_, bias_, bias_size_ * sizeof(float),
cudaMemcpyHostToDevice);
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
cudaMemcpy(scale_gpu_, scale_, scale_size_ * sizeof(float),
cudaMemcpyHostToDevice);
return 0;
}
size_t EmbEltwiseLayernormPluginDynamic::getSerializationSize() const {
return 0;
}
void EmbEltwiseLayernormPluginDynamic::serialize(void *buffer) const {}
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
PADDLE_ENFORCE_EQ(output_index, 0,
platform::errors::InvalidArgument(
"There is only one output of the EmbEltwiseLayernorm, "
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs, 3,
platform::errors::InvalidArgument(
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
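// Output shape: (batch, seq_len, hidden_size, 1, 1); the first two dims are
// taken from the id input and stay dynamic.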
ret.nbDims = 5;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1];
ret.d[2] = expr_builder.constant(hidden_size_);
ret.d[3] = expr_builder.constant(1);
ret.d[4] = expr_builder.constant(1);
return ret;
}
bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of the EmbEltwiseLayernorm plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &desc = in_out[pos];
if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
return false;
}
if (pos == 0) {
return desc.type == nvinfer1::DataType::kINT32;
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1 || pos == 2) {
return desc.type == nvinfer1::DataType::kINT32 &&
desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
}
if (pos == 3) {
return desc.type == nvinfer1::DataType::kFLOAT;
}
return false;
}
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(
index, 0, platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return nvinfer1::DataType::kFLOAT;
}
int EmbEltwiseLayernormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto id_dims = input_desc[0].dims;
int batch = id_dims.d[0];
int seq_len = id_dims.d[1];
int input_num = embs_.size();
framework::Tensor in_ptr_tensor, emb_ptr_tensor;
int device_id;
cudaGetDevice(&device_id);
in_ptr_tensor.Resize({input_num});
emb_ptr_tensor.Resize({input_num});
int64_t *in_ptr_gpu_d =
in_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
int64_t *emb_ptr_gpu_d =
emb_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
std::vector<int64_t> in_ptr, emb_ptr;
for (int i = 0; i < input_num; i++) {
in_ptr.push_back(reinterpret_cast<uintptr_t>(inputs[i]));
emb_ptr.push_back(reinterpret_cast<uintptr_t>(embs_gpu_[i]));
}
cudaMemcpyAsync(in_ptr_gpu_d, in_ptr.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(emb_ptr_gpu_d, emb_ptr.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, stream);
auto out_type = output_desc[0].type;
const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only only support fp32 input."));
float *output_d = static_cast<float *>(outputs[0]);
operators::math::EmbEltwiseLayerNormFunctor<float> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
eps_, input_num, stream);
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
public:
explicit EmbEltwiseLayernormPluginDynamic(std::vector<float*> input_embs,
float* bias, float* scale,
std::vector<int> emb_sizes,
int bias_size, int scale_size,
int hidden_size, float eps)
: embs_(input_embs),
bias_(bias),
scale_(scale),
emb_sizes_(emb_sizes),
bias_size_(bias_size),
scale_size_(scale_size),
hidden_size_(hidden_size),
eps_(eps) {}
EmbEltwiseLayernormPluginDynamic(void const* serialData,
size_t serialLength) {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new EmbEltwiseLayernormPluginDynamic(
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_,
eps_);
}
const char* getPluginType() const override {
return "fused_embedding_eltwise_layernorm_plugin";
}
int getNbOutputs() const override { return 1; }
int initialize() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
std::vector<float*> embs_;
float* bias_;
float* scale_;
// data on devices
float* bias_gpu_;
float* scale_gpu_;
std::vector<float*> embs_gpu_;
std::vector<int> emb_sizes_;
int bias_size_;
int scale_size_;
int hidden_size_;
float eps_;
};
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
template <typename T>
__global__ void transpose(T *src, T *dst, const int batch_size,
const int seq_len, const int head_num,
const int size_per_head) {
int batch_id = blockIdx.x / (head_num * seq_len);
int seq_id = blockIdx.x % seq_len;
int head_id = (blockIdx.x % (head_num * seq_len)) / seq_len;
dst[batch_id * (head_num * seq_len * size_per_head) +
seq_id * head_num * size_per_head + head_id * size_per_head +
threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
template <typename T>
__global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
// Input: BxSx3xNxH
// Bias: 3xSxB
// Output: 3xBxNxSxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z;
const int N = blockDim.y;
const int S = gridDim.x;
const int B = gridDim.y;
const int NH = N * H;
const int NHS = NH * S;
const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
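// in_offset addresses element [b][s][m][n][:] of the BxSx3xNxH input and
// out_offset addresses element [m][b][n][s][:] of the 3xBxNxSxH output; each
// thread then copies one scalar along the H dimension (i = threadIdx.x).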
const int i = threadIdx.x;
output[out_offset + i] = input[in_offset + i];
}
inline void TransposeQKV(const int batch, const int seq_len,
const int head_size, const int head_num,
const float *input, float *output,
cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 4 == 0 && scratch_size % 4 == 0) {
const int h = head_size / 4;
const float4 *input4 = reinterpret_cast<const float4 *>(input);
float4 *output4 = reinterpret_cast<float4 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 4));
TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input);
float2 *output2 = reinterpret_cast<float2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 2));
TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024));
TransposeQkvKernel<float><<<grid, block, 0, stream>>>(head_size, input,
output);
}
}
#ifdef SUPPORTS_CUDA_FP16
inline void TransposeQKV(const int batch, const int seq_len,
const int head_size, const int head_num,
const half *input, half *output, cudaStream_t stream) {
int scratch_size = batch * head_num * seq_len * seq_len;
const dim3 grid(seq_len, batch, 3);
if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const half2 *input2 = reinterpret_cast<const half2 *>(input);
half2 *output2 = reinterpret_cast<half2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 2));
TransposeQkvKernel<half2><<<grid, block, 0, stream>>>(h, input2, output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024));
TransposeQkvKernel<half><<<grid, block, 0, stream>>>(head_size, input,
output);
}
}
#endif
int QkvToContextPluginDynamic::initialize() { return 0; }
size_t QkvToContextPluginDynamic::getSerializationSize() const { return 0; }
void QkvToContextPluginDynamic::serialize(void *buffer) const {}
nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
// input[0], (B, S, 3 * N * H, 1, 1)
// input[1], (B, head_num, seq_len, seq_len)
// output, (B, seq_len, hidden)
PADDLE_ENFORCE_EQ(output_index, 0,
platform::errors::InvalidArgument(
"There is only one output of the EmbEltwiseLayernorm, "
"so the index should be zero,"
"but it's (%d)",
output_index));
PADDLE_ENFORCE_EQ(
nb_inputs, 2,
platform::errors::InvalidArgument(
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
"it has (%d) inputs",
nb_inputs));
nvinfer1::DimsExprs ret;
ret.nbDims = 5;
ret.d[0] = inputs[0].d[0];
ret.d[1] = inputs[0].d[1];
ret.d[2] = expr_builder.constant(hidden_);
ret.d[3] = expr_builder.constant(1);
ret.d[4] = expr_builder.constant(1);
return ret;
}
bool QkvToContextPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
in_out, platform::errors::InvalidArgument(
"The input of the QkvToContext plugin should not be nullptr."));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
if (ban_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1) {
return in.type == prev.type && in.format == prev.format;
}
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType QkvToContextPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(
index, 0, platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return input_types[0];
}
int QkvToContextPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
int input_num = ProductDim(input_dims);
// input[0], (B, S, 3 * N * H, 1, 1)
int batch = input_dims.d[0];
int seq_len = input_dims.d[1];
framework::Tensor multihead_temp_tensor;
int scratch_size = batch * head_number_ * seq_len * seq_len * 1;
int device_id;
cudaGetDevice(&device_id);
multihead_temp_tensor.Resize({scratch_size + input_num});
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
auto *multihead_temp_data = multihead_temp_tensor.mutable_data<float>(
platform::CUDAPlace(device_id));
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
const float *input0_data = static_cast<const float *>(inputs[0]);
const float *input1_data = static_cast<const float *>(inputs[1]);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
stream);
auto *device_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(device_id)));
const platform::CUDADeviceContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<float> multihead_compute_func;
multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
qkptr, input1_data, tptr, scale_,
static_cast<float>(0.0));
int grid = batch * head_number_ * seq_len;
int block = head_size_;
float *output = static_cast<float *>(outputs[0]);
transpose<float><<<grid, block, 0, stream>>>(tptr, output, batch, seq_len,
head_number_, head_size_);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
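    // Allocate the scratch as int16_t (2 bytes, the same size as half) and
    // reinterpret it as half below.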
auto *multihead_temp_data =
multihead_temp_tensor.mutable_data<int16_t>( // NOLINT
platform::CUDAPlace(device_id));
half *qkptr = reinterpret_cast<half *>(multihead_temp_data);
half *tptr = qkptr + scratch_size;
const half *input0_data = static_cast<const half *>(inputs[0]);
const half *input1_data = static_cast<const half *>(inputs[1]);
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransposeQKV(batch, seq_len, head_size_, head_number_, input0_data, tptr,
stream);
auto *device_ctx = static_cast<platform::CUDADeviceContext *>(
platform::DeviceContextPool::Instance().Get(
platform::CUDAPlace(device_id)));
const platform::CUDADeviceContext &dev_ctx = *device_ctx;
operators::math::MultiHeadGPUComputeFunctor<half> multihead_compute_func;
multihead_compute_func(dev_ctx, batch, seq_len, head_number_, head_size_,
qkptr, input1_data, tptr, half(scale_), half(0.0));
int grid = batch * head_number_ * seq_len;
int block = head_size_;
half *output = static_cast<half *>(outputs[0]);
transpose<half><<<grid, block, 0, stream>>>(tptr, output, batch, seq_len,
head_number_, head_size_);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA architecture you specified should be greater than 600; "
        "FP16 is not supported in this build."));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The QKV TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class QkvToContextPluginDynamic : public DynamicPluginTensorRT {
public:
explicit QkvToContextPluginDynamic(int hidden, int head_number, int head_size,
float scale, bool ban_fp16)
: hidden_(hidden),
head_number_(head_number),
head_size_(head_size),
scale_(scale),
ban_fp16_(ban_fp16) {}
QkvToContextPluginDynamic(void const* serialData, size_t serialLength) {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new QkvToContextPluginDynamic(hidden_, head_number_, head_size_,
scale_, ban_fp16_);
}
const char* getPluginType() const override { return "qkv_to_context_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
int hidden_;
int head_number_;
int head_size_;
float scale_;
bool ban_fp16_;
};
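// A minimal usage sketch (illustrative only; the real call site is the
// corresponding op converter). Assuming `qkv` and `bias_qk` are ITensor*
// already added to the network and hidden == head_number * head_size:
//   auto* plugin = new plugin::QkvToContextPluginDynamic(
//       hidden, head_number, head_size, scale, /*ban_fp16=*/false);
//   std::vector<nvinfer1::ITensor*> plugin_inputs{qkv, bias_qk};
//   auto* layer = network->addPluginV2(plugin_inputs.data(),
//                                      plugin_inputs.size(), *plugin);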
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_runtime.h>
#include <stdio.h>
#include <cassert>
#include <cub/cub.cuh> // NOLINT
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
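// initialize() is called once when the engine is built or deserialized; it
// uploads the LayerNorm bias and scale from host to device so that enqueue()
// can reuse them on every run.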
int SkipLayerNormPluginDynamic::initialize() {
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
cudaMemcpy(bias_gpu_, bias_, bias_size_ * sizeof(float),
cudaMemcpyHostToDevice);
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
cudaMemcpy(scale_gpu_, scale_, scale_size_ * sizeof(float),
cudaMemcpyHostToDevice);
return 0;
}
size_t SkipLayerNormPluginDynamic::getSerializationSize() const { return 0; }
void SkipLayerNormPluginDynamic::serialize(void *buffer) const {}
nvinfer1::DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) {
  PADDLE_ENFORCE_EQ(
      inputs[0].nbDims, 5,
      platform::errors::InvalidArgument(
          "The input dims of the SkipLayerNorm plugin should be 5, but it is "
          "(%d) now.",
          inputs[0].nbDims));
return inputs[0];
}
bool SkipLayerNormPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
  PADDLE_ENFORCE_NOT_NULL(
      in_out,
      platform::errors::InvalidArgument(
          "The input of the skip_layernorm plugin should not be nullptr."));
  PADDLE_ENFORCE_LT(
      pos, nb_inputs + nb_outputs,
      platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                        "sum(%d) of the inputs and outputs.",
                                        pos, nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc &in = in_out[pos];
if (pos == 0) {
#ifdef SUPPORTS_CUDA_FP16
if (ban_fp16_) {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
} else {
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
}
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
if (pos == 1) {
return in.type == prev.type && in.format == prev.format;
}
// output
return in.type == prev.type && in.format == prev.format;
}
nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
  PADDLE_ENFORCE_EQ(index, 0,
                    platform::errors::InvalidArgument(
                        "The SkipLayerNorm plugin only has one output, so the "
                        "index value should be 0, but got %d.",
                        index));
  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
                     input_types[0] == nvinfer1::DataType::kHALF),
                    true, platform::errors::InvalidArgument(
                              "The input type should be half or float."));
return input_types[0];
}
int SkipLayerNormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
size_t num = ProductDim(input_dims);
int hidden = input_dims.d[2];
auto input_type = input_desc[0].type;
if (input_type == nvinfer1::DataType::kFLOAT) {
const float *input1 = static_cast<const float *>(inputs[0]);
const float *input2 = static_cast<const float *>(inputs[1]);
float *output = static_cast<float *>(outputs[0]);
operators::math::SkipLayerNormFunctor<float> skip_layer_norm_func;
skip_layer_norm_func(num, hidden, input1, input2, scale_gpu_, bias_gpu_,
output, eps_, stream);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef SUPPORTS_CUDA_FP16
const half *input1 = static_cast<const half *>(inputs[0]);
const half *input2 = static_cast<const half *>(inputs[1]);
half *output = static_cast<half *>(outputs[0]);
operators::math::SkipLayerNormFunctor<half> skip_layer_norm_func;
skip_layer_norm_func(num, hidden, input1, input2, scale_gpu_, bias_gpu_,
output, static_cast<half>(eps_), stream);
#else
    PADDLE_THROW(platform::errors::Fatal(
        "The CUDA architecture you specified should be greater than 600; "
        "FP16 is not supported in this build."));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"The SkipLayerNorm TRT Plugin's input type should be float or half."));
}
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
public:
explicit SkipLayerNormPluginDynamic(float* bias, float* scale, int bias_size,
int scale_size, const float eps,
bool ban_fp16)
: bias_(bias),
scale_(scale),
bias_size_(bias_size),
scale_size_(scale_size),
eps_(eps),
ban_fp16_(ban_fp16) {}
SkipLayerNormPluginDynamic(void const* serialData, size_t serialLength) {}
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new SkipLayerNormPluginDynamic(bias_, scale_, bias_size_,
scale_size_, eps_, ban_fp16_);
}
const char* getPluginType() const override { return "skip_layernorm_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override;
size_t getSerializationSize() const override;
void serialize(void* buffer) const override;
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* inOut,
int nbInputs, int nbOutputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nbOutputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
const nvinfer1::PluginTensorDesc* outputs,
int nbOutputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const override;
void destroy() override { delete this; }
private:
float* bias_;
float* scale_;
float* bias_gpu_;
float* scale_gpu_;
int bias_size_;
int scale_size_;
float eps_;
bool ban_fp16_;
};
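// Note: bias and scale are raw host pointers captured by the plugin (and
// shared by clone()); the caller has to keep them alive at least until
// initialize() has copied them to the GPU.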
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -16,10 +16,12 @@
#include <NvInfer.h>
#include <cstring>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
......@@ -112,6 +114,72 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
std::vector<nvinfer1::ITensor*> inputs_;
};
#if IS_TRT_VERSION_GE(6000)
class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
public:
DynamicPluginTensorRT() {}
DynamicPluginTensorRT(const void* serialized_data, size_t length) {}
// The Func in IPluginExt or IpluginExtV2
virtual const char* getPluginVersion() const { return "1"; }
virtual const char* getPluginType() const = 0;
int getNbOutputs() const { return 1; }
int initialize() override { return 0; }
  void terminate() override {}
virtual size_t getSerializationSize() const = 0;
virtual void serialize(void* buffer) const = 0;
// The Func in IPluginV2
nvinfer1::IPluginV2DynamicExt* clone() const = 0;
virtual nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) = 0; // NOLINT
virtual bool supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) = 0;
virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) = 0;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const override {
return 0;
}
virtual int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) = 0;
virtual nvinfer1::DataType getOutputDataType(
int index, const nvinfer1::DataType* input_types,
int nb_inputs) const = 0;
void setPluginNamespace(const char* plugin_namespace) override {
name_space_ = plugin_namespace;
}
const char* getPluginNamespace() const override {
return name_space_.c_str();
}
virtual void destroy() = 0;
protected:
void deserializeBase(void const*& serial_data, // NOLINT
size_t& serial_length); // NOLINT
size_t getBaseSerializationSize() const;
void serializeBase(void*& buffer) const; // NOLINT
private:
std::string name_space_{"paddle_trt"};
std::string plugin_base_{"plugin_dynamic"};
};
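// To add a new dynamic-shape plugin, derive from DynamicPluginTensorRT and
// override the pure-virtual hooks above: getPluginType, clone,
// getSerializationSize/serialize, getOutputDimensions,
// supportsFormatCombination, configurePlugin, enqueue, getOutputDataType and
// destroy (see QkvToContextPluginDynamic and SkipLayerNormPluginDynamic for
// two concrete examples).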
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
......
......@@ -349,9 +349,6 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_resnext_test SRCS trt_resnext_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
inference_analysis_test(trt_bert_test SRCS trt_bert_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${BERT_INSTALL_DIR}/model)
inference_analysis_test(trt_fc_prelu_test SRCS trt_fc_prelu_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
......@@ -367,6 +364,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(test_analyzer_capi_gpu SRCS analyzer_capi_gpu_tester.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} paddle_fluid_c
ARGS --infer_model=${TRT_MODEL_INSTALL_DIR}/trt_inference_test_models)
set(TRT_MODEL_QUANT_RESNET_DIR "${INFERENCE_DEMO_INSTALL_DIR}/quant_small_model")
if (NOT EXISTS ${TRT_MODEL_QUANT_RESNET_DIR})
inference_download_and_uncompress(${INFERENCE_DEMO_INSTALL_DIR} ${INFERENCE_URL}/tensorrt_test "quant_small_model.tar.gz")
......@@ -382,6 +380,15 @@ if(WITH_GPU AND TENSORRT_FOUND)
inference_analysis_test(trt_dynamic_shape_test SRCS trt_dynamic_shape_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_DYNAMIC_MODEL})
set(TEST_TRT_ERNIE_MODEL "${TRT_MODEL_INSTALL_DIR}/ernie_test")
if (NOT EXISTS ${TEST_TRT_ERNIE_MODEL})
inference_download_and_uncompress(${TEST_TRT_ERNIE_MODEL} ${INFERENCE_URL}/tensorrt_test "ernie_model_4.tar.gz")
endif()
inference_analysis_test(test_trt_dynamic_shape_ernie SRCS trt_dynamic_shape_ernie_test.cc
EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
ARGS --infer_model=${TEST_TRT_ERNIE_MODEL}/ernie_model_4)
endif()
set(LITE_MODEL_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/lite")
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
TEST(TensorRT, split_converter) {
AnalysisConfig config;
int batch_size = 1;
config.SetModel(FLAGS_infer_model);
config.EnableUseGpu(1200, 0);
config.SwitchUseFeedFetchOps(false);
config.EnableTensorRtEngine(1 << 30, batch_size, 10,
AnalysisConfig::Precision::kFloat32, false,
false);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config);
int64_t i0[128] = {
96, 54, 78, 37, 106, 35, 122, 33, 95, 63, 81, 60, 65, 68, 45, 96,
117, 61, 43, 15, 12, 64, 91, 100, 90, 74, 99, 23, 22, 91, 83, 13,
28, 71, 59, 15, 40, 26, 66, 18, 31, 87, 85, 11, 55, 67, 28, 126,
7, 89, 39, 67, 88, 29, 66, 38, 98, 1, 66, 38, 95, 56, 48, 95,
9, 38, 90, 82, 101, 6, 75, 46, 42, 89, 98, 12, 6, 101, 82, 55,
81, 113, 33, 91, 44, 73, 41, 39, 12, 113, 13, 86, 36, 91, 53, 68,
103, 67, 65, 92, 27, 76, 24, 107, 54, 94, 63, 10, 15, 32, 91, 45,
37, 126, 49, 118, 73, 127, 122, 119, 28, 96, 92, 79, 21, 90, 11, 40};
int64_t i1[128] = {
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74,
75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104,
105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119,
120, 121, 122, 123, 124, 125, 126, 127};
int64_t i2[128] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float i3[128 * 128] = {0.0};
int64_t i4[1] = {0};
auto input_names = predictor->GetInputNames();
auto input_t0 = predictor->GetInputTensor(input_names[0]);
input_t0->Reshape({batch_size, 128, 1});
input_t0->copy_from_cpu(i0);
auto input_t1 = predictor->GetInputTensor(input_names[1]);
input_t1->Reshape({batch_size, 128, 1});
input_t1->copy_from_cpu(i1);
auto input_t2 = predictor->GetInputTensor(input_names[2]);
input_t2->Reshape({batch_size, 128, 1});
input_t2->copy_from_cpu(i2);
auto input_t3 = predictor->GetInputTensor(input_names[3]);
input_t3->Reshape({batch_size, 128, 128});
input_t3->copy_from_cpu(i3);
auto input_t4 = predictor->GetInputTensor(input_names[4]);
input_t4->Reshape({batch_size, 1});
input_t4->copy_from_cpu(i4);
ASSERT_TRUE(predictor->ZeroCopyRun());
}
} // namespace inference
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gflags/gflags.h>
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tests/api/trt_test_helper.h"
namespace paddle {
namespace inference {
void run(const AnalysisConfig& config, std::vector<float>* out_data) {
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int run_batch = 1;
const int run_seq_len = 128;
std::vector<int64_t> tmp_input;
std::vector<float> tmp_four_input;
tmp_input.reserve(run_batch * run_seq_len);
tmp_four_input.reserve(run_batch * run_seq_len);
int64_t i0[run_seq_len] = {
1, 3558, 4, 75, 491, 89, 340, 313, 93, 4, 255, 10, 75, 321,
4095, 1902, 4, 134, 49, 75, 311, 14, 44, 178, 543, 15, 12043, 2,
75, 201, 340, 9, 14, 44, 486, 218, 1140, 279, 12043, 2};
int64_t i1[run_seq_len] = {
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int64_t i2[run_seq_len] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 33, 34, 35, 36, 37, 38, 39};
float i3[run_seq_len] = {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0};
// first input
auto input_t = predictor->GetInputTensor(input_names[0]);
input_t->Reshape({run_batch, run_seq_len, 1});
input_t->copy_from_cpu(i0);
// second input
auto input_t2 = predictor->GetInputTensor(input_names[1]);
input_t2->Reshape({run_batch, run_seq_len, 1});
input_t2->copy_from_cpu(i1);
// third input.
auto input_t3 = predictor->GetInputTensor(input_names[2]);
input_t3->Reshape({run_batch, run_seq_len, 1});
input_t3->copy_from_cpu(i2);
auto input_t4 = predictor->GetInputTensor(input_names[3]);
input_t4->Reshape({run_batch, run_seq_len, 1});
input_t4->copy_from_cpu(i3);
ASSERT_TRUE(predictor->ZeroCopyRun());
auto output_names = predictor->GetOutputNames();
auto output_t = predictor->GetOutputTensor(output_names[0]);
std::vector<int> output_shape = output_t->shape();
int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
std::multiplies<int>());
out_data->resize(out_num);
output_t->copy_to_cpu(out_data->data());
}
void trt_ernie(bool with_fp16, std::vector<float> result) {
AnalysisConfig config;
std::string model_dir = FLAGS_infer_model;
SetConfig(&config, model_dir, true /* use_gpu */);
config.SwitchUseFeedFetchOps(false);
int head_number = 12;
int batch = 1;
int min_seq_len = 1;
int max_seq_len = 128;
int opt_seq_len = 128;
std::vector<int> min_shape = {batch, min_seq_len, 1};
std::vector<int> max_shape = {batch, max_seq_len, 1};
std::vector<int> opt_shape = {batch, opt_seq_len, 1};
// Set the input's min, max, opt shape
std::map<std::string, std::vector<int>> min_input_shape = {
{"read_file_0.tmp_0", min_shape},
{"read_file_0.tmp_1", min_shape},
{"read_file_0.tmp_2", min_shape},
{"stack_0.tmp_0", {batch, head_number, min_seq_len, min_seq_len}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"read_file_0.tmp_0", max_shape},
{"read_file_0.tmp_1", max_shape},
{"read_file_0.tmp_2", max_shape},
{"stack_0.tmp_0", {batch, head_number, max_seq_len, max_seq_len}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"read_file_0.tmp_0", opt_shape},
{"read_file_0.tmp_1", opt_shape},
{"read_file_0.tmp_2", opt_shape},
{"stack_0.tmp_0", {batch, head_number, opt_seq_len, opt_seq_len}}};
auto precision = AnalysisConfig::Precision::kFloat32;
if (with_fp16) {
precision = AnalysisConfig::Precision::kHalf;
}
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, true);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
std::vector<float> out_data;
run(config, &out_data);
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6);
}
}
TEST(AnalysisPredictor, no_fp16) {
std::vector<float> result = {0.597841, 0.219972, 0.182187};
trt_ernie(false, result);
}
TEST(AnalysisPredictor, fp16) {
#ifdef SUPPORTS_CUDA_FP16
std::vector<float> result = {0.598336, 0.219558, 0.182106};
trt_ernie(true, result);
#endif
}
} // namespace inference
} // namespace paddle
......@@ -34,9 +34,11 @@ TEST(AnalysisPredictor, use_gpu) {
{"image", {1, 1, 10, 10}}};
std::map<std::string, std::vector<int>> opt_input_shape = {
{"image", {1, 1, 3, 3}}};
config.EnableTensorRtEngine(
1 << 30, 1, 1, AnalysisConfig::Precision::kFloat32, false, true,
min_input_shape, max_input_shape, opt_input_shape);
config.EnableTensorRtEngine(1 << 30, 1, 1,
AnalysisConfig::Precision::kFloat32, false, true);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
auto predictor = CreatePaddlePredictor(config);
auto input_names = predictor->GetInputNames();
int channels = 1;
......
......@@ -93,7 +93,7 @@ set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_fun
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor)
endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
......
......@@ -20,90 +20,12 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
template <typename T>
using kvp = cub::KeyValuePair<T, T>;
template <typename T>
using cv2 = cub::CubVector<T, 2>;
template <typename T, int TPB>
__device__ inline void LayerNorm(const cv2<T> &thread_data, const int ld,
const int offset, const float *bias,
const float *scale, T *output, float eps) {
using BlockReduce = cub::BlockReduce<cv2<T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());
if (threadIdx.x == 0) {
mu = sum_kv.x;
rsigma = rsqrt(sum_kv.y - mu * mu + eps);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(scale[i]);
const T b(bias[i]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
const T *scale, const T *bias,
const int64_t *embs, T *output,
float eps, int input_num) {
cub::Sum pair_sum;
// blockIdx.x: position in the sequence
// blockIdx.y: batch
// gridDim.x: Seq
// gridDim.y: Batch
extern __shared__ int64_t array_id[];
const T rhidden = T(1.f) / T(hidden);
const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
if (threadIdx.x == 0) {
for (int i = 0; i < input_num; ++i) {
const int64_t *ids_p = reinterpret_cast<const int64_t *>(ids[i]);
array_id[i] = ids_p[seq_pos];
}
}
__syncthreads();
const int64_t out_offset = seq_pos * hidden;
cv2<T> thread_data;
thread_data.x = 0;
thread_data.y = 0;
#pragma unroll
for (int it = threadIdx.x; it < hidden; it += TPB) {
T val = 0;
for (int i = 0; i < input_num; ++i) {
val += reinterpret_cast<const T *>(embs[i])[array_id[i] * hidden + it];
}
output[out_offset + it] = val;
const T rhiddenval = rhidden * val;
cv2<T> temp_data;
temp_data.x = rhiddenval;
temp_data.y = rhiddenval * val;
thread_data = pair_sum(thread_data, temp_data);
}
LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
}
template <typename DeviceContext, typename T>
class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
public:
......@@ -154,13 +76,11 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
auto *output_d = out->mutable_data<T>(context.GetPlace());
float eps = context.Attr<float>("epsilon");
const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
int shared_bytes = input_num * sizeof(int64_t);
EmbEltwiseLayernormKernel<
T, tpb><<<grid, block, shared_bytes, device_ctx.stream()>>>(
hidden, in_ids_d, scale_d, bias_d, in_embs_d, output_d, eps, input_num);
math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden, in_ids_d, scale_d,
bias_d, in_embs_d, output_d, eps, input_num,
device_ctx.stream());
}
};
......
......@@ -18,271 +18,12 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
#else
val += __shfl_xor(val, mask, warpSize);
#endif
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val, mask);
if (lane == 0) shared[wid] = val;
__syncthreads();
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = warpReduceSum<T>(val, mask);
return val;
}
template <typename T>
__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
val = max(val, __shfl_xor(val, mask, warpSize));
#endif
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceMax(val, mask);
if (lane == 0) shared[wid] = val;
__syncthreads();
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
val = warpReduceMax(val, mask);
return val;
}
template <typename T>
__global__ void add_QKV(const T *Q, const T *K, const T *V, T *q_buf_,
T *k_buf_, T *v_buf_, const T *bias_q, const T *bias_k,
const T *bias_v, int batch_size, int seq_len,
int head_num, int size_per_head) {
const T *data_ptr_q, *data_ptr_k, *data_ptr_v;
const T *bias_ptr_q, *bias_ptr_k, *bias_ptr_v;
int m = batch_size * seq_len;
int n = head_num * size_per_head;
int row_offset = (blockIdx.x % m) * n;
data_ptr_q = Q + row_offset;
data_ptr_k = K + row_offset;
data_ptr_v = V + row_offset;
// bias ptr
bias_ptr_q = bias_q;
bias_ptr_k = bias_k;
bias_ptr_v = bias_v;
int batch_id = (blockIdx.x % m) / seq_len;
int head_id = threadIdx.x / size_per_head;
int id_in_head = threadIdx.x % size_per_head;
int word_start_id = (blockIdx.x) % seq_len;
#if __CUDA_ARCH__ >= 350
T tmp_q = __ldg(&data_ptr_q[threadIdx.x]) + __ldg(&bias_ptr_q[threadIdx.x]);
T tmp_k = __ldg(&data_ptr_k[threadIdx.x]) + __ldg(&bias_ptr_k[threadIdx.x]);
T tmp_v = __ldg(&data_ptr_v[threadIdx.x]) + __ldg(&bias_ptr_v[threadIdx.x]);
#else
T tmp_q = data_ptr_q[threadIdx.x] + bias_ptr_q[threadIdx.x];
T tmp_k = data_ptr_k[threadIdx.x] + bias_ptr_k[threadIdx.x];
T tmp_v = data_ptr_v[threadIdx.x] + bias_ptr_v[threadIdx.x];
#endif
int target_id = batch_id * (seq_len * head_num * size_per_head) +
head_id * seq_len * size_per_head +
word_start_id * size_per_head + id_in_head;
q_buf_[target_id] = tmp_q;
k_buf_[target_id] = tmp_k;
v_buf_[target_id] = tmp_v;
}
// Keep to compare performance
template <typename T>
__global__ void add_QKV_V2(const T *Q, const T *K, const T *V, T *q_buf_,
T *k_buf_, T *v_buf_, const T *bias_Q,
const T *bias_K, const T *bias_V, int batch_size,
int seq_len, int head_num, int size_per_head,
const int word_per_block) {
const T *data_ptr;
T *buf_ptr;
const T *bias_ptr;
int m = batch_size * seq_len;
int n = head_num * size_per_head;
int qkv_id = blockIdx.x * word_per_block / m;
int row_offset = (blockIdx.x * word_per_block % m) * n;
if (qkv_id == 0) {
data_ptr = Q + row_offset;
buf_ptr = q_buf_;
bias_ptr = bias_Q;
} else if (qkv_id == 1) {
data_ptr = K + row_offset;
buf_ptr = k_buf_;
bias_ptr = bias_K;
} else {
data_ptr = V + row_offset;
buf_ptr = v_buf_;
bias_ptr = bias_V;
}
int batch_id = (blockIdx.x * word_per_block % m) / seq_len;
int head_id = threadIdx.x / size_per_head;
int id_in_head = threadIdx.x % size_per_head;
int word_start_id = (blockIdx.x * word_per_block) % seq_len;
#if __CUDA_ARCH__ >= 350
T bias = __ldg(&bias_ptr[threadIdx.x]);
#else
T bias = bias_ptr[threadIdx.x];
#endif
for (int i = word_start_id; i < word_start_id + word_per_block; ++i) {
T tmp = data_ptr[threadIdx.x] + bias;
int target_id = batch_id * (seq_len * head_num * size_per_head) +
head_id * seq_len * size_per_head + i * size_per_head +
id_in_head;
buf_ptr[target_id] = tmp;
data_ptr += n;
}
}
template <typename T>
__global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
const int batch_size,
const int head_num,
const int seq_len,
const unsigned mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % 32 == 0);
__shared__ float s_sum, s_max;
float qk = threadIdx.x < seq_len
? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset]))
: 0.0f;
float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
float max_val = blockReduceMax<float>(tmp, mask);
if (threadIdx.x == 0) s_max = max_val;
__syncthreads();
float qk_tmp =
threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x == 0) {
s_sum = sum_val + 1e-6f;
}
__syncthreads();
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
}
// For verify result
template <typename T>
__global__ void elt_qk_add(const T *bias_qk, T *qk_buf, int head_num,
int seq_len, int size_per_head, int batch_size) {
int m = batch_size * head_num * seq_len;
int row_id = blockIdx.x % m;
int dst_id = row_id * seq_len + threadIdx.x;
const T *bias_ptr = bias_qk;
#if __CUDA_ARCH__ >= 350
int tmp_bias = __ldg(&bias_ptr[dst_id]);
#else
int tmp_bias = bias_ptr[dst_id];
#endif
qk_buf[dst_id] += tmp_bias;
}
// Compute Q*K->softmax->eltadd
template <typename T>
void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
int seq_len, int size_per_head, int batch_size,
bool q_trans, bool k_trans, T *q_buf_, T *k_buf_,
T *qk_buf_, const T *bias_qk, T alpha, T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
auto stream = context.stream();
blas.BatchedGEMM(transA, transB, seq_len, seq_len, size_per_head, alpha,
q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
seq_len * size_per_head, seq_len * size_per_head);
int grid = batch_size * head_num * seq_len;
int block = seq_len;
// Align block to 32, also limit seq_len to max block size.
PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
"seq_len should <= 1024, "
"but received seq_len is:%d",
seq_len));
if (seq_len <= 32)
block = 32;
else if (seq_len > 32 && seq_len <= 64)
block = 64;
else if (seq_len > 64 && seq_len <= 128)
block = 128;
else if (seq_len > 128 && seq_len <= 256)
block = 256;
else if (seq_len > 256 && seq_len <= 512)
block = 512;
else
block = 1024;
softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
template <typename T>
__global__ void transpose(T *src, T *dst, const int batch_size,
const int seq_len, const int head_num,
......@@ -295,25 +36,6 @@ __global__ void transpose(T *src, T *dst, const int batch_size,
threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
}
// Compute QK*V->transpose
template <typename T>
void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num,
int seq_len, int size_per_head, int batch_size,
bool qk_trans, bool v_trans, T *v_buf_, const T *qk_buf_,
T *dst, T alpha, T beta) {
int m = batch_size * seq_len;
int k = head_num * size_per_head;
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
auto stream = context.stream();
CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;
blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha,
qk_buf_, v_buf_, beta, dst, batch_size * head_num,
seq_len * seq_len, seq_len * size_per_head);
}
template <typename T>
inline __device__ T add_func(T a, T b);
......@@ -341,8 +63,8 @@ __device__ float4 add_func<float4>(float4 a, float4 b) {
}
template <typename T>
__global__ void transpose_qkv_kernel(const int H, const T *input, const T *bias,
T *output) {
__global__ void TransposeQkvKernel(const int H, const T *input, const T *bias,
T *output) {
// Input: BxSx3xNxH
// Bias: 3xSxB
// Output: 3xBxNxSxH
......@@ -385,8 +107,8 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 4));
transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
output4);
TransposeQkvKernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
output4);
} else if (head_size % 2 == 0 && scratch_size % 2 == 0) {
const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input);
......@@ -398,8 +120,8 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 2));
transpose_qkv_kernel<float2><<<grid, block, 0, stream>>>(h, input2, bias2,
output2);
TransposeQkvKernel<float2><<<grid, block, 0, stream>>>(h, input2, bias2,
output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
......@@ -407,30 +129,11 @@ void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024));
transpose_qkv_kernel<float><<<grid, block, 0, stream>>>(head_size, input,
bias, output);
TransposeQkvKernel<float><<<grid, block, 0, stream>>>(head_size, input,
bias, output);
}
}
template <typename T>
void MultiHeadGPUComputeV2(const platform::CUDADeviceContext &dev_ctx,
int batch, int seq_len, int head_num, int head_size,
T *qkptr, const T *bias_qk_ptr, T *tptr, T alpha,
T beta) {
auto stream = dev_ctx.stream();
const int tsize = batch * head_num * seq_len * head_size;
T *qptr = tptr;
T *kptr = qptr + tsize;
T *vptr = kptr + tsize;
// batch gemm stride, softmaxwithscale.
MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, head_size, batch, false, true,
qptr, kptr, qkptr, bias_qk_ptr, alpha, beta);
// batch gemm stride, transpose.
MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, head_size, batch, false,
false, vptr, qkptr, tptr, T(1.0), beta);
}
template <typename DeviceContext, typename T>
class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public:
......@@ -502,8 +205,9 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
bias_d, tptr, stream);
MultiHeadGPUComputeV2<T>(device_ctx, batch, seq_len, head_number, head_size,
qkptr, bias_qk_d, tptr, scale, T(0.0));
math::MultiHeadGPUComputeFunctor<T> multihead_compute_func;
multihead_compute_func(device_ctx, batch, seq_len, head_number, head_size,
qkptr, bias_qk_d, tptr, scale, T(0.0));
int grid = batch * head_number * seq_len;
int block = head_size;
......
......@@ -63,6 +63,7 @@ math_library(matrix_bit_code)
math_library(unpooling)
math_library(vol2col)
math_library(prelu)
math_library(bert_encoder_functor)
math_library(tree2col DEPS math_function)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, int TPB>
__device__ inline void LayerNormSmall(T val, const kvp<T> &thread_data,
const int ld, const int idx,
const float *bias, const float *scale,
T *output, T eps) {
using BlockReduce = cub::BlockReduce<kvp<T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = rsqrt(sum_kv.value - mu * mu + eps);
}
__syncthreads();
if (threadIdx.x < ld) {
const T g(scale[threadIdx.x]);
const T b(bias[threadIdx.x]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
template <typename T, int TPB>
__device__ inline void LayerNorm(const kvp<T> &thread_data, const int ld,
const int offset, const float *bias,
const float *scale, T *output, T eps) {
using BlockReduce = cub::BlockReduce<kvp<T>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T mu; // mean
__shared__ T rsigma; // 1 / std.dev.
const auto sum_kv = BlockReduce(temp_storage).Reduce(thread_data, cub::Sum());
if (threadIdx.x == 0) {
mu = sum_kv.key;
rsigma = rsqrt(sum_kv.value - mu * mu + eps);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const T val = output[idx];
const T g(scale[i]);
const T b(bias[i]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden, const int64_t *ids,
const float *scale, const float *bias,
const int64_t *embs, T *output,
float eps, int input_num) {
cub::Sum pair_sum;
// blockIdx.x: position in the sequence
// blockIdx.y: batch
// gridDim.x: Seq
// gridDim.y: Batch
extern __shared__ int64_t array_id[];
const T rhidden = T(1.f) / T(hidden);
const int64_t seq_pos = blockIdx.y + blockIdx.x * gridDim.y;
if (threadIdx.x == 0) {
for (int i = 0; i < input_num; ++i) {
const int64_t *ids_p = reinterpret_cast<const int64_t *>(ids[i]);
array_id[i] = ids_p[seq_pos];
}
}
__syncthreads();
const int64_t out_offset = seq_pos * hidden;
kvp<T> thread_data(0, 0);
#pragma unroll
for (int it = threadIdx.x; it < hidden; it += TPB) {
T val = 0;
for (int i = 0; i < input_num; ++i) {
val += reinterpret_cast<const T *>(embs[i])[array_id[i] * hidden + it];
}
output[out_offset + it] = val;
const T rhiddenval = rhidden * val;
thread_data = pair_sum(thread_data, kvp<T>(rhiddenval, rhiddenval * val));
}
LayerNorm<T, TPB>(thread_data, hidden, out_offset, bias, scale, output, eps);
}
template <typename T>
void EmbEltwiseLayerNormFunctor<T>::operator()(
int batch, int seq_len, int hidden, const int64_t *ids, const float *scale,
const float *bias, const int64_t *embs, T *output, float eps, int input_num,
cudaStream_t stream) {
const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
int shared_bytes = input_num * sizeof(int64_t);
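  // Launch one block per (sequence position, batch) pair; each block of tpb
  // threads sums the input_num embedding rows over `hidden` and applies
  // LayerNorm in place. The dynamic shared memory holds the input_num
  // looked-up ids for that position.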
EmbEltwiseLayernormKernel<T, tpb><<<grid, block, shared_bytes, stream>>>(
hidden, ids, scale, bias, embs, output, eps, input_num);
}
template class EmbEltwiseLayerNormFunctor<float>;
#ifdef SUPPORTS_CUDA_FP16
template class EmbEltwiseLayerNormFunctor<half>;
#endif
template <typename T>
__global__ void SoftmaxKernelWithEltadd(T *qk_buf_, const T *bias_qk_,
const int batch_size,
const int head_num, const int seq_len,
const unsigned mask) {
int qk_offset = blockIdx.x * seq_len;
assert(blockDim.x % 32 == 0);
__shared__ float s_sum, s_max;
float qk = threadIdx.x < seq_len
? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
bias_qk_[threadIdx.x + qk_offset]))
: 0.0f;
float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
float max_val = blockReduceMax<float>(tmp, mask);
if (threadIdx.x == 0) s_max = max_val;
__syncthreads();
float qk_tmp =
threadIdx.x < seq_len ? __expf(static_cast<float>(tmp - s_max)) : 0.0f;
float sum_val = blockReduceSum<float>(qk_tmp, mask);
if (threadIdx.x == 0) {
s_sum = sum_val + 1e-6f;
}
__syncthreads();
if (threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / s_sum);
}
template <typename T>
inline void MatMulWithHeadQK(const platform::CUDADeviceContext &context,
int head_num, int seq_len, int size_per_head,
int batch_size, bool q_trans, bool k_trans,
T *q_buf_, T *k_buf_, T *qk_buf_, const T *bias_qk,
T alpha, T beta) {
CBLAS_TRANSPOSE transA = !q_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !k_trans ? CblasNoTrans : CblasTrans;
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas =
operators::math::GetBlas<platform::CUDADeviceContext, run_type>(context);
auto stream = context.stream();
blas.BatchedGEMM(
transA, transB, seq_len, seq_len, size_per_head,
static_cast<run_type>(alpha), reinterpret_cast<run_type *>(q_buf_),
reinterpret_cast<run_type *>(k_buf_), static_cast<run_type>(beta),
reinterpret_cast<run_type *>(qk_buf_), batch_size * head_num,
seq_len * size_per_head, seq_len * size_per_head);
int grid = batch_size * head_num * seq_len;
int block = seq_len;
// Align block to 32, also limit seq_len to max block size.
  PADDLE_ENFORCE_LE(seq_len, 1024,
                    platform::errors::InvalidArgument(
                        "seq_len should be <= 1024, but the received seq_len "
                        "is %d.",
                        seq_len));
if (seq_len <= 32)
block = 32;
else if (seq_len > 32 && seq_len <= 64)
block = 64;
else if (seq_len > 64 && seq_len <= 128)
block = 128;
else if (seq_len > 128 && seq_len <= 256)
block = 256;
else if (seq_len > 256 && seq_len <= 512)
block = 512;
else
block = 1024;
SoftmaxKernelWithEltadd<T><<<grid, block, 0, stream>>>(
qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
}
template <typename T>
inline void MatMulWithHeadQKV(const platform::CUDADeviceContext &context,
int head_num, int seq_len, int size_per_head,
int batch_size, bool qk_trans, bool v_trans,
T *v_buf_, const T *qk_buf_, T *dst, T alpha,
T beta) {
int m = batch_size * seq_len;
int k = head_num * size_per_head;
typedef typename CUDATypeTraits<T>::TYPE run_type;
auto blas =
operators::math::GetBlas<platform::CUDADeviceContext, run_type>(context);
auto stream = context.stream();
CBLAS_TRANSPOSE transA = !qk_trans ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !v_trans ? CblasNoTrans : CblasTrans;
blas.BatchedGEMM(
transA, transB, seq_len, size_per_head, seq_len,
static_cast<run_type>(alpha), reinterpret_cast<const run_type *>(qk_buf_),
reinterpret_cast<run_type *>(v_buf_), static_cast<run_type>(beta),
reinterpret_cast<run_type *>(dst), batch_size * head_num,
seq_len * seq_len, seq_len * size_per_head);
}
template <typename T>
void MultiHeadGPUComputeFunctor<T>::operator()(
const platform::CUDADeviceContext &dev_ctx, int batch, int seq_len,
int head_num, int head_size, T *qkptr, const T *bias_qk_ptr, T *tptr,
T alpha, T beta) {
auto stream = dev_ctx.stream();
const int tsize = batch * head_num * seq_len * head_size;
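  // tptr holds the transposed Q, K and V contiguously (each tsize elements);
  // qkptr is the batch x head_num x seq_len x seq_len score scratch. The final
  // context is written back into tptr.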
T *qptr = tptr;
T *kptr = qptr + tsize;
T *vptr = kptr + tsize;
// batch gemm stride, softmaxwithscale.
MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, head_size, batch, false, true,
qptr, kptr, qkptr, bias_qk_ptr, alpha, beta);
// batch gemm stride, transpose.
MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, head_size, batch, false,
false, vptr, qkptr, tptr, T(1.0), beta);
}
template class MultiHeadGPUComputeFunctor<float>;
#ifdef SUPPORTS_CUDA_FP16
template class MultiHeadGPUComputeFunctor<half>;
#endif
template <typename T, unsigned TPB>
__global__ void SkipLayerNormSmallKernel(int num, int hidden, const T *input1,
const T *input2, T *output,
const float *scale, const float *bias,
float eps) {
const T rld = T(1) / T(hidden);
const int offset = blockIdx.x * hidden;
cub::Sum pair_sum;
kvp<T> thread_data(0, 0);
const int idx = offset + threadIdx.x;
T val = 0;
if (threadIdx.x < hidden) {
val = input1[idx] + input2[idx];
const T rldval = rld * val;
thread_data = pair_sum(thread_data, kvp<T>(rldval, rldval * val));
}
LayerNormSmall<T, TPB>(val, thread_data, hidden, idx, bias, scale, output,
eps);
}
template <typename T, unsigned TPB>
__global__ void SkipLayerNormKernel(int num, int hidden, const T *input1,
const T *input2, T *output,
const float *scale, const float *bias,
float eps) {
const T rld = T(1) / T(hidden);
const int offset = blockIdx.x * hidden;
cub::Sum pair_sum;
kvp<T> thread_data(0, 0);
for (int it = threadIdx.x; it < hidden; it += TPB) {
const int idx = offset + it;
const T val = input1[idx] + input2[idx];
const T rldval = rld * val;
thread_data = pair_sum(thread_data, kvp<T>(rldval, rldval * val));
output[idx] = val;
}
LayerNorm<T, TPB>(thread_data, hidden, offset, bias, scale, output, eps);
}
template <typename T>
void SkipLayerNormFunctor<T>::operator()(const int num, const int hidden,
const T *input1, const T *input2,
const float *scale, const float *bias,
T *output, T eps,
cudaStream_t stream) {
int block = num / hidden;
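  // Despite the name, `block` is the grid size: one thread block per row of
  // `hidden` elements (num / hidden rows in total). The branches below pick
  // the specialized small kernel whenever `hidden` fits within a single
  // block's threads; otherwise a 256-thread strided kernel is used.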
if (hidden <= 32) {
const int threads = 32;
SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
num, hidden, input1, input2, output, scale, bias, eps);
} else if (hidden <= 128) {
const int threads = 128;
SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
num, hidden, input1, input2, output, scale, bias, eps);
} else if (hidden == 384) {
const int threads = 384;
SkipLayerNormSmallKernel<T, threads><<<block, threads, 0, stream>>>(
num, hidden, input1, input2, output, scale, bias, eps);
} else {
const int threads = 256;
SkipLayerNormKernel<T, threads><<<block, threads, 0, stream>>>(
num, hidden, input1, input2, output, scale, bias, eps);
}
}
template class SkipLayerNormFunctor<float>;
#ifdef SUPPORTS_CUDA_FP16
template class SkipLayerNormFunctor<half>;
#endif
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <cub/cub.cuh> // NOLINT
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CUDATypeTraits;
#ifdef SUPPORTS_CUDA_FP16
template <>
struct CUDATypeTraits<half> {
typedef platform::float16 TYPE;
};
#endif
template <>
struct CUDATypeTraits<float> {
typedef float TYPE;
};
#ifdef PADDLE_WITH_CUDA
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// in_var emb in_var emb
// | | | |
// lookup_table lookup_table
// | |
// lkt_var lkt_var
// \ /
// elementwise_add
// |
// elt_out_var
//
template <typename T>
class EmbEltwiseLayerNormFunctor {
public:
void operator()(int batch, int seq_len, int hidden, const int64_t *ids,
const float *scale, const float *bias, const int64_t *embs,
T *output, float eps, int input_num, cudaStream_t stream);
};
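// Note: `ids` and `embs` are device arrays of pointers stored as int64_t; the
// kernel reinterprets ids[i] as an int64 id tensor and embs[i] as the matching
// embedding table of width `hidden`.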
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// | |
// matmul
// |
// eltwise_add
// |
// softmax /
// \ /
// matmul
// |
template <typename T>
class MultiHeadGPUComputeFunctor {
public:
void operator()(const platform::CUDADeviceContext &dev_ctx, int batch,
int seq_len, int head_num, int head_size, T *qkptr,
const T *bias_qk_ptr, T *tptr, T alpha, T beta);
};
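// On entry tptr must already contain the transposed Q, K and V
// (3 x B x N x S x H, e.g. produced by TransQKVWithBias or the TRT plugin's
// TransposeQKV helper); the attention output is written back into tptr and
// qkptr is pure scratch.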
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// | |
// other_op1 other_op2
// | |
// |------elementwise_add
// |
// layer_norm
// |
// other_op3
// |
template <typename T>
class SkipLayerNormFunctor {
public:
void operator()(const int num, const int hidden, const T *input1,
const T *input2, const float *scale, const float *bias,
T *output, T eps, cudaStream_t stream);
};
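// `num` is the total element count (e.g. batch * seq_len * hidden) and
// `hidden` the row width; scale and bias are device pointers of length
// `hidden`.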
#endif
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda_fp16.h>
#include <algorithm>
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__device__ __forceinline__ T FromFloat(float a);
template <typename T>
__device__ __forceinline__ float ToFloat(T a);
template <typename T>
__device__ __forceinline__ T exp_func(T a);
template <typename T>
struct KeyValuePair;
template <typename T>
using kvp = KeyValuePair<T>;
// from_float
template <>
__device__ __forceinline__ float FromFloat<float>(float a) {
return a;
}
#ifdef SUPPORTS_CUDA_FP16
template <>
__device__ __forceinline__ half FromFloat<half>(float a) {
return __float2half(a);
}
#endif
// to_float
template <>
__device__ __forceinline__ float ToFloat<float>(float a) {
return a;
}
#ifdef SUPPORTS_CUDA_FP16
template <>
__device__ __forceinline__ float ToFloat<half>(half a) {
return __half2float(a);
}
#endif
template <>
__device__ __forceinline__ float exp_func<float>(float a) {
return expf(a);
}
#ifdef SUPPORTS_CUDA_FP16
template <>
__device__ __forceinline__ half exp_func<half>(half a) {
#if __CUDA_ARCH__ >= 600
return hexp(a);
#else
return FromFloat<half>(expf(ToFloat<half>(a)));
#endif
}
#endif
template <>
struct KeyValuePair<float> {
__device__ __forceinline__ KeyValuePair() {}
__device__ __forceinline__ KeyValuePair(float k, float v)
: key(k), value(v) {}
__device__ __forceinline__ KeyValuePair(const KeyValuePair &a) {
key = a.key;
value = a.value;
}
float key;
float value;
__device__ __forceinline__ KeyValuePair
operator+(const KeyValuePair &a) const {
KeyValuePair tmp;
tmp.key = key + a.key;
tmp.value = value + a.value;
return tmp;
}
};
#ifdef SUPPORTS_CUDA_FP16
template <>
struct KeyValuePair<half> {
__device__ __forceinline__ KeyValuePair() {}
__device__ __forceinline__ KeyValuePair(half k, half v) : key(k), value(v) {}
__device__ __forceinline__ KeyValuePair(const KeyValuePair &a) {
key = a.key;
value = a.value;
}
half key;
half value;
__device__ __forceinline__ KeyValuePair
operator+(const KeyValuePair &a) const {
const half2 a2 = __halves2half2(key, value);
const half2 b2 = __halves2half2(a.key, a.value);
const half2 res = __hadd2(a2, b2);
return KeyValuePair(res.x, res.y);
}
};
#endif
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
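/* Calculate the sum of all elements in a warp */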
template <typename T>
__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
#else
val += __shfl_xor(val, mask, warpSize);
#endif
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val, mask);
if (lane == 0) shared[wid] = val;
__syncthreads();
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (threadIdx.x < block_span) ? shared[lane] : static_cast<T>(0.0f);
val = warpReduceSum<T>(val, mask);
return val;
}
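/* Calculate the maximum of all elements in a warp */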
template <typename T>
__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
#if __CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000
val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
val = max(val, __shfl_xor(val, mask, warpSize));
#endif
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
static __shared__ T shared[WARP_SIZE];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceMax(val, mask);
if (lane == 0) shared[wid] = val;
__syncthreads();
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (threadIdx.x < block_span) ? shared[lane] : -1e10f;
val = warpReduceMax(val, mask);
return val;
}
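// A minimal sketch of how these reductions are typically combined in a
// softmax-style kernel (illustrative only; SoftmaxRowKernel and its launch
// configuration are assumptions, and blockDim.x must cover one full row).
// The block-wide results are only guaranteed on thread 0, so they are
// broadcast through shared memory before being reused:
//
//   __global__ void SoftmaxRowKernel(const float *in, float *out, int width) {
//     __shared__ float s_max, s_sum;
//     int row = blockIdx.x;
//     int col = threadIdx.x;
//     float val = (col < width) ? in[row * width + col] : -1e10f;
//     float max_val = blockReduceMax<float>(val, FINAL_MASK);
//     if (threadIdx.x == 0) s_max = max_val;
//     __syncthreads();
//     float exp_val = (col < width) ? exp_func<float>(val - s_max) : 0.0f;
//     float sum_val = blockReduceSum<float>(exp_val, FINAL_MASK);
//     if (threadIdx.x == 0) s_sum = sum_val + 1e-6f;
//     __syncthreads();
//     if (col < width) out[row * width + col] = exp_val / s_sum;
//   }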
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -252,7 +252,15 @@ class TensorRTEngineOp : public framework::OperatorBase {
bind_index, inference::tensorrt::Vec2TRT_Dims(t_shape, x, true));
#endif
}
buffers[bind_index] = static_cast<void *>(t.data<float>());
auto type = t.type();
if (type == framework::proto::VarType::FP32) {
buffers[bind_index] = static_cast<void *>(t.data<float>());
} else if (type == framework::proto::VarType::INT64) {
buffers[bind_index] = static_cast<void *>(t.data<int64_t>());
} else {
PADDLE_THROW(platform::errors::Fatal(
"The TRT Engine OP only support float and int64_t input."));
}
}
// Bind output tensor to TRT.
......
......@@ -412,13 +412,16 @@ void BindAnalysisConfig(py::module *m) {
py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
py::arg("min_subgraph_size") = 3,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
py::arg("use_static") = false, py::arg("use_calib_mode") = true,
py::arg("use_static") = false, py::arg("use_calib_mode") = true)
.def("set_trt_dynamic_shape_info",
&AnalysisConfig::SetTRTDynamicShapeInfo,
py::arg("min_input_shape") =
std::map<std::string, std::vector<int>>({}),
py::arg("max_input_shape") =
std::map<std::string, std::vector<int>>({}),
py::arg("optim_input_shape") =
std::map<std::string, std::vector<int>>({}))
std::map<std::string, std::vector<int>>({}),
py::arg("disable_trt_plugin_fp16") = false)
.def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
.def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
py::arg("x") = true)
......
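For reference, a minimal C++ sketch of how the dynamic-shape interface bound above is used on the native side (the input name "src_ids" and all shapes are placeholders for an Ernie-style model, not values taken from this patch):

#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/inference/api/paddle_inference_api.h"

void ConfigureErnieDynamicShape(paddle::AnalysisConfig *config) {
  // Shapes are [batch, seq_len, 1]; adjust them to the real model inputs.
  std::map<std::string, std::vector<int>> min_shape{{"src_ids", {1, 1, 1}}};
  std::map<std::string, std::vector<int>> max_shape{{"src_ids", {8, 128, 1}}};
  std::map<std::string, std::vector<int>> opt_shape{{"src_ids", {4, 64, 1}}};
  config->EnableTensorRtEngine(1 << 30 /* workspace_size */,
                               8 /* max_batch_size */,
                               5 /* min_subgraph_size */,
                               paddle::AnalysisConfig::Precision::kHalf,
                               false /* use_static */,
                               false /* use_calib_mode */);
  // Passing true as the last argument keeps the TRT plugins running in FP32
  // even when the engine itself runs in half precision.
  config->SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape,
                                 false /* disable_trt_plugin_fp16 */);
}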