Unverified commit ac0553a0, authored by Yuanle Liu, committed by GitHub

fused_embedding_eltwise_layernorm_op and skip_layernorm_op support fp16 (#44969)

Parent commit: 3512bf11
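The fp16 path added in this commit is only taken when the TensorRT engine itself runs in half precision. As a rough usage sketch (model path and shapes are placeholders, not part of this commit), a configuration that exercises these converters with fp16 looks like:

#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Build an AnalysisConfig whose TensorRT engine runs in half precision, so
// engine_->WithFp16() is true and the fused embedding_eltwise_layernorm /
// skip_layernorm plugins receive fp16 weights.
paddle::AnalysisConfig MakeFp16Config(const std::string& model_dir) {
  paddle::AnalysisConfig config;
  config.SetModel(model_dir);            // placeholder path
  config.EnableUseGpu(100 /*MB*/, 0 /*device id*/);
  config.EnableTensorRtEngine(1 << 30,   // workspace size
                              1,         // max batch size
                              5,         // min subgraph size
                              paddle::AnalysisConfig::Precision::kHalf,
                              false,     // use_static
                              false);    // use_calib_mode
  // The dynamic-shape plugins also need min/max/opt shapes, e.g. via
  // config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
  //                               opt_input_shape, false);
  // passing true as the last argument (disable_trt_plugin_fp16) would force
  // these plugins back to fp32.
  return config;
}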
@@ -166,7 +166,6 @@ if(WITH_TENSORRT)
   pass_library(trt_map_matmul_to_mul_pass inference)
   pass_library(trt_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(trt_multihead_matmul_fuse_pass inference)
-  pass_library(trt_skip_layernorm_fuse_pass inference)
   pass_library(preln_embedding_eltwise_layernorm_fuse_pass inference)
   pass_library(preln_skip_layernorm_fuse_pass inference)
   pass_library(set_transformer_input_convert_pass inference)
@@ -177,6 +176,7 @@ endif()
 if(WITH_GPU OR WITH_ROCM)
   pass_library(cudnn_placement_pass base DEPS placement_pass_base)
   pass_library(embedding_eltwise_layernorm_fuse_pass inference)
+  pass_library(trt_skip_layernorm_fuse_pass inference)
 endif()
 if(WITH_MKLDNN)
......
@@ -165,12 +165,17 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "gpu_cpu_map_matmul_v2_to_matmul_pass",
     "fc_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
+    "embedding_eltwise_layernorm_fuse_pass",
+    "trt_skip_layernorm_fuse_pass",
+    "runtime_context_cache_pass",
 };

 const std::vector<std::string> kTrtLowerPrecisionPasses{
     "simplify_with_basic_ops_pass",
     // "conv_bn_fuse_pass",
     // "conv_eltwiseadd_bn_fuse_pass",
+    "trt_embedding_eltwise_layernorm_fuse_pass",
+    "trt_skip_layernorm_fuse_pass",
     "trt_map_matmul_v2_to_mul_pass",
     "trt_map_matmul_v2_to_matmul_pass",
     "trt_map_matmul_to_mul_pass",
@@ -186,6 +191,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "conv_bn_fuse_pass",                      //
         "conv_eltwiseadd_bn_fuse_pass",           //
         "embedding_eltwise_layernorm_fuse_pass",  //
+        "trt_skip_layernorm_fuse_pass",           //
         "multihead_matmul_fuse_pass_v2",          //
         "gpu_cpu_squeeze2_matmul_fuse_pass",      //
         "gpu_cpu_reshape2_matmul_fuse_pass",      //
......
@@ -133,6 +133,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       return weight;
     };

+    auto GetFp16Weight = [&](const std::string& var_name,
+                             framework::DDim* dim) -> TensorRTEngine::Weight {
+      auto* temp_var = scope.FindVar(var_name);
+      auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+      *dim = temp_tensor->dims();
+      auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor);
+      return weight;
+    };
+
     auto GetFp32Weight = [&](const std::string& var_name,
                              framework::DDim* dim) -> TensorRTEngine::Weight {
       auto* temp_var = scope.FindVar(var_name);
@@ -141,7 +150,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       auto weight = engine_->GetFp32TrtWeight(var_name, *temp_tensor);
       return weight;
     };
+    bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     int hidden = 0;
     for (int i = 0; i < input_num; i++) {
       framework::DDim emb_dims;
@@ -149,7 +158,11 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       if (flag_varseqlen) {
         weight = GetWeight(emb_names[i], &emb_dims);
       } else {
-        weight = GetFp32Weight(emb_names[i], &emb_dims);
+        if (with_fp16) {
+          weight = GetFp16Weight(emb_names[i], &emb_dims);
+        } else {
+          weight = GetFp32Weight(emb_names[i], &emb_dims);
+        }
       }
       input_embs.push_back(weight.get());
       emb_sizes.push_back(weight.get().count);
@@ -167,8 +180,15 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       bias_weight = GetWeight(op_desc.Input("Bias").front(), &bias_dims);
       scale_weight = GetWeight(op_desc.Input("Scale").front(), &scale_dims);
     } else {
-      bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims);
-      scale_weight = GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims);
+      if (with_fp16) {
+        bias_weight = GetFp16Weight(op_desc.Input("Bias").front(), &bias_dims);
+        scale_weight =
+            GetFp16Weight(op_desc.Input("Scale").front(), &scale_dims);
+      } else {
+        bias_weight = GetFp32Weight(op_desc.Input("Bias").front(), &bias_dims);
+        scale_weight =
+            GetFp32Weight(op_desc.Input("Scale").front(), &scale_dims);
+      }
     }

     int64_t bias_size = phi::product(bias_dims);
@@ -282,21 +302,18 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
             test_mode);
       }
     } else {
-      bool with_fp16 =
-          engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
       float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
       plugin::DynamicPluginTensorRT* plugin = nullptr;
-      std::vector<float*> input_embs_data;
+      std::vector<void*> input_embs_data;
       for (size_t i = 0; i < input_embs.size(); ++i) {
-        input_embs_data.push_back(const_cast<float*>(
-            static_cast<const float*>(input_embs[i].values)));
+        input_embs_data.push_back(const_cast<void*>(
+            reinterpret_cast<const void*>(input_embs[i].values)));
       }
       plugin = new plugin::EmbEltwiseLayernormPluginDynamic(
           input_embs_data,
-          const_cast<float*>(
-              static_cast<const float*>(bias_weight.get().values)),
-          const_cast<float*>(
-              static_cast<const float*>(scale_weight.get().values)),
+          const_cast<void*>(static_cast<const void*>(bias_weight.get().values)),
+          const_cast<void*>(
+              static_cast<const void*>(scale_weight.get().values)),
          emb_sizes,
          bias_size,
          scale_size,
......
@@ -150,6 +150,15 @@ class SkipLayerNormOpConverter : public OpConverter {
         layer = plugin_layer;
       }
     } else {
+      auto GetFp16Weight =
+          [&](const std::string& arg_name) -> TensorRTEngine::Weight {
+        std::string var_name = op_desc.Input(arg_name).front();
+        auto* temp_var = scope.FindVar(var_name);
+        auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
+        auto weight = engine_->GetFp16TrtWeight(var_name, *temp_tensor);
+        return weight;
+      };
+
       auto GetFp32Weight =
           [&](const std::string& arg_name) -> TensorRTEngine::Weight {
         std::string var_name = op_desc.Input(arg_name).front();
@@ -159,20 +168,29 @@ class SkipLayerNormOpConverter : public OpConverter {
         return weight;
       };

-      auto bias_weight = GetFp32Weight("Bias").get();
-      auto scale_weight = GetFp32Weight("Scale").get();
+      // bool with_fp16 = engine_->WithFp16() &&
+      //                  !engine_->disable_trt_plugin_fp16() &&
+      //                  (input1->getType() == nvinfer1::DataType::kHALF);
+      bool with_fp16 = false;
+      TensorRTEngine::Weight bias_weight, scale_weight;
+      if (with_fp16) {
+        bias_weight = GetFp16Weight("Bias");
+        scale_weight = GetFp16Weight("Scale");
+      } else {
+        bias_weight = GetFp32Weight("Bias");
+        scale_weight = GetFp32Weight("Scale");
+      }

       float eps = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon"));
-      // bool with_fp16 =
-      //     engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
-      bool with_fp16 = false;
       plugin::SkipLayerNormPluginDynamic* plugin =
           new plugin::SkipLayerNormPluginDynamic(
-              static_cast<const float*>(bias_weight.values),
-              static_cast<const float*>(scale_weight.values),
-              bias_weight.count,
-              scale_weight.count,
+              const_cast<void*>(
+                  static_cast<const void*>(bias_weight.get().values)),
+              const_cast<void*>(
+                  static_cast<const void*>(scale_weight.get().values)),
+              bias_weight.get().count,
+              scale_weight.get().count,
               eps,
               with_fp16);
       layer = engine_->AddDynamicPlugin(inputs.data(), 2, plugin);
......
@@ -31,7 +31,7 @@ namespace inference {
 namespace tensorrt {

 void TensorRTEngine::Weight::SetDataType(phi::DataType type) {
-  nvinfer1::DataType nv_type;
+  nvinfer1::DataType nv_type = nvinfer1::DataType::kFLOAT;
   switch (type) {
     case phi::DataType::FLOAT32:
       nv_type = nvinfer1::DataType::kFLOAT;
@@ -455,6 +455,67 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
   runtime_batch_ = batch_size;
 }

+// Note: Only for support plugin.
+TensorRTEngine::Weight TensorRTEngine::GetFp16TrtWeight(
+    const std::string &name, const framework::Tensor &weight_tensor) {
+  static int name_suffix_counter = 0;
+  std::string name_suffix = std::to_string(name_suffix_counter);
+  std::string splitter = "__";
+  std::string name_with_suffix = name + splitter + name_suffix;
+  platform::CPUPlace cpu_place;
+  PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix),
+                    0,
+                    platform::errors::AlreadyExists(
+                        "The weight named %s is set into the weight map "
+                        "twice in TRT OP converter.",
+                        name_with_suffix));
+  weight_map[name_with_suffix].reset(new framework::Tensor());
+  weight_map[name_with_suffix]->Resize(weight_tensor.dims());
+
+  TensorRTEngine::Weight weight;
+  weight.SetCount(weight_tensor.numel());
+  weight.SetDataType(nvinfer1::DataType::kHALF);
+  // if trt not support dtype, we need to cast to fp16.
+  if (weight_tensor.dtype() == phi::DataType::BFLOAT16) {
+    framework::Tensor bf16_tensor;
+    bf16_tensor.clear();
+    paddle::framework::TensorCopySync(
+        weight_tensor, platform::CPUPlace(), &bf16_tensor);
+    weight_map[name_with_suffix]->set_type(
+        paddle::experimental::DataType::FLOAT16);
+    weight_map[name_with_suffix]->Resize(weight_tensor.dims());
+    auto *fp16_data = weight_map[name_with_suffix]->mutable_data<float16>(
+        platform::CPUPlace());
+    auto *bf16_data = bf16_tensor.mutable_data<bfloat16>(platform::CPUPlace());
+    for (int i = 0; i < weight_tensor.numel(); i++) {
+      fp16_data[i] = static_cast<float16>(bf16_data[i]);
+    }
+  } else if (weight_tensor.dtype() == phi::DataType::FLOAT32) {
+    framework::Tensor fp32_tensor;
+    fp32_tensor.clear();
+    paddle::framework::TensorCopySync(
+        weight_tensor, platform::CPUPlace(), &fp32_tensor);
+    weight_map[name_with_suffix]->set_type(
+        paddle::experimental::DataType::FLOAT16);
+    weight_map[name_with_suffix]->Resize(weight_tensor.dims());
+    auto *fp16_data = weight_map[name_with_suffix]->mutable_data<float16>(
+        platform::CPUPlace());
+    auto *fp32_data = fp32_tensor.mutable_data<float>(platform::CPUPlace());
+    for (int i = 0; i < weight_tensor.numel(); i++) {
+      fp16_data[i] = static_cast<float16>(fp32_data[i]);
+    }
+  } else {
+    paddle::framework::TensorCopySync(
+        weight_tensor, cpu_place, weight_map[name_with_suffix].get());
+  }
+  weight.SetValues(weight_map[name_with_suffix]->data());
+  name_suffix_counter += 1;
+  return weight;
+}
+
+// Note: Only for support plugin.
 TensorRTEngine::Weight TensorRTEngine::GetFp32TrtWeight(
     const std::string &name, const framework::Tensor &weight_tensor) {
   static int name_suffix_counter = 0;
......
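GetFp16TrtWeight keeps the converted copy of each weight alive in the engine's weight_map. As a minimal standalone illustration of the host-side fp32-to-fp16 cast it performs for FLOAT32 tensors, using CUDA's __half instead of paddle::platform::float16 (a sketch, not the engine code):

#include <cuda_fp16.h>

#include <cstdint>
#include <vector>

// Cast an fp32 weight buffer to fp16 on the host. The converted buffer has to
// outlive the TensorRT engine that references it, which is why the real code
// stores it in weight_map rather than in a local.
std::vector<__half> CastWeightToFp16(const float* src, int64_t count) {
  std::vector<__half> dst(static_cast<size_t>(count));
  for (int64_t i = 0; i < count; ++i) {
    dst[i] = __float2half(src[i]);
  }
  return dst;
}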
@@ -421,6 +421,10 @@ class TensorRTEngine {
     quant_dynamic_range_[tensor] = range;
   }

+  // Get fp16 trt weight. If src weight is not fp16, we will cast.
+  Weight GetFp16TrtWeight(const std::string& name,
+                          const framework::Tensor& weight_tensor);
+
   // Get fp32 trt weight. If src weight is not fp32, we will cast.
   Weight GetFp32TrtWeight(const std::string& name,
                           const framework::Tensor& weight_tensor);
......
@@ -16,6 +16,7 @@
 #include <cassert>

 #include <cub/cub.cuh>  // NOLINT
+#include <type_traits>
 #include <vector>

 #include "glog/logging.h"
@@ -32,12 +33,6 @@ namespace plugin {
 // Dynamic shape plugin requires TRT version greater than 6.0.
 #if IS_TRT_VERSION_GE(6000)

-template <typename T>
-EmbEltwiseLayernormPluginDynamicImpl<
-    T>::~EmbEltwiseLayernormPluginDynamicImpl() {}
-
-inline half fp32tofp16(float x) { return static_cast<half>(x); }
-
 template <typename T>
 void EmbEltwiseLayernormPluginDynamicImpl<T>::shareGPUData(
     const EmbEltwiseLayernormPluginDynamicImplBase *anthor) {
@@ -62,36 +57,24 @@ int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
   embs_gpu_.resize(embs_.size());
   for (int i = 0; i < embs_.size(); i++) {
     if (embs_[i]) {
-      T *host_ptr;
+      T *host_ptr = embs_[i];
       auto size = emb_sizes_[i];
-      if (std::is_same<T, half>::value) {
-        host_ptr = new T[size];
-        std::transform(embs_[i], (embs_[i] + size), host_ptr, fp32tofp16);
-      } else {
-        host_ptr = reinterpret_cast<T *>(embs_[i]);
-      }
       cudaMalloc(&embs_gpu_[i], sizeof(T) * size);
       cudaMemcpy(
           embs_gpu_[i], host_ptr, size * sizeof(T), cudaMemcpyHostToDevice);
-      if (std::is_same<T, half>::value) {
-        delete[] host_ptr;
-      }
     }
   }

   if (bias_) {
-    cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
+    cudaMalloc(&bias_gpu_, sizeof(T) * bias_size_);
     cudaMemcpy(
-        bias_gpu_, bias_, bias_size_ * sizeof(float), cudaMemcpyHostToDevice);
+        bias_gpu_, bias_, bias_size_ * sizeof(T), cudaMemcpyHostToDevice);
   }
   if (scale_) {
-    cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
-    cudaMemcpy(scale_gpu_,
-               scale_,
-               scale_size_ * sizeof(float),
-               cudaMemcpyHostToDevice);
+    cudaMalloc(&scale_gpu_, sizeof(T) * scale_size_);
+    cudaMemcpy(
+        scale_gpu_, scale_, scale_size_ * sizeof(T), cudaMemcpyHostToDevice);
   }

   int input_num = embs_.size();
@@ -239,22 +222,14 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
                         "The EmbEltwiseLayerNorm's output should be one"
                         "but it's (%d) outputs.",
                         nb_outputs));
-  PADDLE_ENFORCE_EQ(nb_outputs,
-                    1,
-                    platform::errors::InvalidArgument(
-                        "The EmbEltwiseLayerNorm's output should be one"
-                        "but it's (%d) outputs.",
-                        nb_outputs));
-
+  int all_nums = nb_inputs + nb_outputs;
   PADDLE_ENFORCE_LT(
       pos,
-      nb_inputs + nb_outputs,
+      all_nums,
       platform::errors::InvalidArgument("The pos(%d) should be less than the "
                                         "num(%d) of the input and the output.",
                                         pos,
-                                        nb_inputs + nb_outputs));
-  int all_nums = nb_inputs + nb_outputs;
+                                        all_nums));

   const nvinfer1::PluginTensorDesc &desc = in_out[pos];
   if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
     return false;
@@ -269,7 +244,7 @@ bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
     return desc.type == nvinfer1::DataType::kINT32 &&
            desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
   }
-
+  // output
   if (pos == all_nums - 1) {
     if (with_fp16_ == false) {
       return desc.type == nvinfer1::DataType::kFLOAT;
@@ -288,7 +263,7 @@ nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
       index,
       0,
       platform::errors::InvalidArgument(
-          "The EmbEltwiseLayernorm Plugin only has one input, so the "
+          "The EmbEltwiseLayernorm Plugin only has one output, so the "
          "index value should be 0, but get %d.",
          index));
   if (with_fp16_)
......
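With the plugin Impl templated on T, every host-to-device copy in initialize() now sizes its buffers by sizeof(T) instead of a hard-coded float. A minimal sketch of that upload step (hypothetical helper, not the plugin code itself):

#include <cuda_runtime.h>

// Allocate a device buffer and copy `count` elements of type T into it;
// sizeof(T) is 2 for half weights and 4 for float weights.
template <typename T>
T* UploadWeight(const T* host, int count) {
  T* device = nullptr;
  cudaMalloc(&device, sizeof(T) * count);
  cudaMemcpy(device, host, sizeof(T) * count, cudaMemcpyHostToDevice);
  return device;
}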
@@ -15,6 +15,7 @@
 #pragma once
 #include <algorithm>
+#include <cstddef>
 #include <string>
 #include <vector>
@@ -49,9 +50,9 @@ template <typename T>
 class EmbEltwiseLayernormPluginDynamicImpl
     : public EmbEltwiseLayernormPluginDynamicImplBase {
  public:
-  explicit EmbEltwiseLayernormPluginDynamicImpl(std::vector<float*> input_embs,
-                                                float* bias,
-                                                float* scale,
+  explicit EmbEltwiseLayernormPluginDynamicImpl(std::vector<T*> input_embs,
+                                                T* bias,
+                                                T* scale,
                                                 std::vector<int> emb_sizes,
                                                 int bias_size,
                                                 int scale_size,
@@ -66,7 +67,7 @@ class EmbEltwiseLayernormPluginDynamicImpl
         hidden_size_(hidden_size),
         eps_(eps) {}

-  ~EmbEltwiseLayernormPluginDynamicImpl();
+  ~EmbEltwiseLayernormPluginDynamicImpl() {}

   int initialize();
   void terminate();
@@ -79,13 +80,13 @@ class EmbEltwiseLayernormPluginDynamicImpl
   void shareGPUData(const EmbEltwiseLayernormPluginDynamicImplBase* anthor);

  private:
-  std::vector<float*> embs_;
-  float* bias_{nullptr};
-  float* scale_{nullptr};
+  std::vector<T*> embs_;
+  T* bias_{nullptr};
+  T* scale_{nullptr};
   // data on devices
-  float* bias_gpu_{nullptr};
-  float* scale_gpu_{nullptr};
+  T* bias_gpu_{nullptr};
+  T* scale_gpu_{nullptr};
   std::vector<T*> embs_gpu_;

   std::vector<int> emb_sizes_;
@@ -101,9 +102,9 @@ class EmbEltwiseLayernormPluginDynamicImpl
 class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
  public:
-  explicit EmbEltwiseLayernormPluginDynamic(std::vector<float*> input_embs,
-                                            float* bias,
-                                            float* scale,
+  explicit EmbEltwiseLayernormPluginDynamic(std::vector<void*> input_embs,
+                                            void* bias,
+                                            void* scale,
                                             std::vector<int> emb_sizes,
                                             int bias_size,
                                             int scale_size,
@@ -123,14 +124,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
     if (with_fp16_) {
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
       VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp16";
-      impl_ = new EmbEltwiseLayernormPluginDynamicImpl<half>(embs_,
-                                                             bias_,
-                                                             scale_,
-                                                             emb_sizes_,
-                                                             bias_size_,
-                                                             scale_size_,
-                                                             hidden_size_,
-                                                             eps_);
+      instantiateImpl<half>();
 #else
       PADDLE_THROW(platform::errors::Fatal(
           "The Ernie(Bert) tensorRT plugin should be "
@@ -141,63 +135,74 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
 #endif
     } else {
       VLOG(1) << "TRT Plugin DataType selected. EmbEltwiseLayerNorm-->fp32";
-      impl_ = new EmbEltwiseLayernormPluginDynamicImpl<float>(embs_,
-                                                              bias_,
-                                                              scale_,
-                                                              emb_sizes_,
-                                                              bias_size_,
-                                                              scale_size_,
-                                                              hidden_size_,
-                                                              eps_);
+      instantiateImpl<float>();
     }
   }

   EmbEltwiseLayernormPluginDynamic(void const* serial_data,
                                    size_t serial_length)
       : own_host_buff_(true) {
+    // the first var is with_fp16, we will use it.
+    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
     DeserializeValue(&serial_data, &serial_length, &emb_sizes_);
-
-    embs_.resize(emb_sizes_.size());
-    for (size_t i = 0; i < emb_sizes_.size(); i++) {
-      auto size = emb_sizes_[i];
-      auto ptr = new float[size];
-      memcpy(ptr, serial_data, sizeof(float) * size);
-      embs_[i] = ptr;
-      reinterpret_cast<char const*&>(serial_data) +=
-          emb_sizes_[i] * sizeof(float);
-      serial_length -= emb_sizes_[i] * sizeof(float);
-    }
     DeserializeValue(&serial_data, &serial_length, &bias_size_);
     DeserializeValue(&serial_data, &serial_length, &scale_size_);

-    if (bias_size_) {
-      bias_ = new float[bias_size_];
-      memcpy(bias_, serial_data, sizeof(float) * bias_size_);
-    }
-    reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
-    serial_length -= bias_size_ * sizeof(float);
+    embs_.resize(emb_sizes_.size());

-    if (scale_size_) {
-      scale_ = new float[scale_size_];
-      memcpy(scale_, serial_data, sizeof(float) * scale_size_);
-    }
-    reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(float);
-    serial_length -= scale_size_ * sizeof(float);
+    if (with_fp16_) {
+      for (size_t i = 0; i < emb_sizes_.size(); i++) {
+        auto size = emb_sizes_[i];
+        auto ptr = new half[size];
+        memcpy(ptr, serial_data, sizeof(half) * size);
+        embs_[i] = ptr;
+        reinterpret_cast<char const*&>(serial_data) += size * sizeof(half);
+        serial_length -= size * sizeof(half);
+      }
+      if (bias_size_) {
+        bias_ = new half[bias_size_];
+        memcpy(bias_, serial_data, sizeof(half) * bias_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(half);
+      serial_length -= bias_size_ * sizeof(half);
+      if (scale_size_) {
+        scale_ = new half[scale_size_];
+        memcpy(scale_, serial_data, sizeof(half) * scale_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(half);
+      serial_length -= scale_size_ * sizeof(half);
+    } else {
+      for (size_t i = 0; i < emb_sizes_.size(); i++) {
+        auto size = emb_sizes_[i];
+        auto ptr = new float[size];
+        memcpy(ptr, serial_data, sizeof(float) * size);
+        embs_[i] = ptr;
+        reinterpret_cast<char const*&>(serial_data) += size * sizeof(float);
+        serial_length -= size * sizeof(float);
+      }
+      if (bias_size_) {
+        bias_ = new float[bias_size_];
+        memcpy(bias_, serial_data, sizeof(float) * bias_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
+      serial_length -= bias_size_ * sizeof(float);
+      if (scale_size_) {
+        scale_ = new float[scale_size_];
+        memcpy(scale_, serial_data, sizeof(float) * scale_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) +=
+          scale_size_ * sizeof(float);
+      serial_length -= scale_size_ * sizeof(float);
+    }

     DeserializeValue(&serial_data, &serial_length, &hidden_size_);
     DeserializeValue(&serial_data, &serial_length, &eps_);
-    DeserializeValue(&serial_data, &serial_length, &with_fp16_);

     if (with_fp16_) {
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
-      impl_ = new EmbEltwiseLayernormPluginDynamicImpl<half>(embs_,
-                                                             bias_,
-                                                             scale_,
-                                                             emb_sizes_,
-                                                             bias_size_,
-                                                             scale_size_,
-                                                             hidden_size_,
-                                                             eps_);
+      instantiateImpl<half>();
 #else
       PADDLE_THROW(platform::errors::Fatal(
           "The Ernie(Bert) tensorRT plugin should be "
@@ -207,14 +212,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
           "AnalysisConfig::Precision::kFloat32, false, false) "));
 #endif
     } else {
-      impl_ = new EmbEltwiseLayernormPluginDynamicImpl<float>(embs_,
-                                                              bias_,
-                                                              scale_,
-                                                              emb_sizes_,
-                                                              bias_size_,
-                                                              scale_size_,
-                                                              hidden_size_,
-                                                              eps_);
+      instantiateImpl<float>();
     }
   }
@@ -241,44 +239,68 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     int sum_num = 0;
+    sum_num += SerializedSize(with_fp16_);
     sum_num += SerializedSize(emb_sizes_);

-    for (size_t i = 0; i < emb_sizes_.size(); i++) {
-      sum_num += emb_sizes_[i] * sizeof(float);
+    if (with_fp16_) {
+      for (size_t i = 0; i < emb_sizes_.size(); i++) {
+        sum_num += emb_sizes_[i] * sizeof(half);
+      }
+      sum_num += (bias_size_ + scale_size_) * sizeof(half);
+    } else {
+      for (size_t i = 0; i < emb_sizes_.size(); i++) {
+        sum_num += emb_sizes_[i] * sizeof(float);
+      }
+      sum_num += (bias_size_ + scale_size_) * sizeof(float);
     }

     sum_num += SerializedSize(bias_size_);
     sum_num += SerializedSize(scale_size_);
-    sum_num += (bias_size_ + scale_size_) * sizeof(float);
     sum_num += SerializedSize(hidden_size_);
     sum_num += SerializedSize(eps_);
-    sum_num += SerializedSize(with_fp16_);

     return sum_num;
   }

   void serialize(void* buffer) const TRT_NOEXCEPT override {
+    // the first var is for with_fp16, we will use it later;
+    SerializeValue(&buffer, with_fp16_);
     SerializeValue(&buffer, emb_sizes_);
-    for (size_t i = 0; i < emb_sizes_.size(); i++) {
-      auto size = emb_sizes_[i];
-      for (int j = 0; j < size; ++j) {
-        SerializeValue(&buffer, embs_[i][j]);
-      }
-    }
     SerializeValue(&buffer, bias_size_);
     SerializeValue(&buffer, scale_size_);
-    for (int i = 0; i < bias_size_; ++i) {
-      SerializeValue(&buffer, bias_[i]);
-    }
-    for (int i = 0; i < scale_size_; ++i) {
-      SerializeValue(&buffer, scale_[i]);
+    if (with_fp16_) {
+      for (size_t i = 0; i < emb_sizes_.size(); i++) {
+        auto size = emb_sizes_[i];
+        for (int j = 0; j < size; ++j) {
+          SerializeValue(&buffer, reinterpret_cast<half*>(embs_[i])[j]);
+        }
+      }
+      for (int i = 0; i < bias_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<half*>(bias_)[i]);
+      }
+      for (int i = 0; i < scale_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<half*>(scale_)[i]);
+      }
+    } else {
+      for (size_t i = 0; i < emb_sizes_.size(); i++) {
+        auto size = emb_sizes_[i];
+        for (int j = 0; j < size; ++j) {
+          SerializeValue(&buffer, reinterpret_cast<float*>(embs_[i])[j]);
+        }
+      }
+      for (int i = 0; i < bias_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<float*>(bias_)[i]);
+      }
+      for (int i = 0; i < scale_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<float*>(scale_)[i]);
+      }
     }
     SerializeValue(&buffer, hidden_size_);
     SerializeValue(&buffer, eps_);
-    SerializeValue(&buffer, with_fp16_);
   }

   nvinfer1::DimsExprs getOutputDimensions(int output_index,
@@ -317,21 +339,28 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
   void destroy() TRT_NOEXCEPT override {
     if (own_host_buff_) {
-      for (auto ptr : embs_) {
-        delete[] ptr;
+      if (with_fp16_) {
+        for (auto ptr : embs_) {
+          delete[] reinterpret_cast<half*>(ptr);
+        }
+        delete[] reinterpret_cast<half*>(bias_);
+        delete[] reinterpret_cast<half*>(scale_);
+      } else {
+        for (auto ptr : embs_) {
+          delete[] reinterpret_cast<float*>(ptr);
+        }
+        delete[] reinterpret_cast<float*>(bias_);
+        delete[] reinterpret_cast<float*>(scale_);
       }
-      delete[] bias_;
-      delete[] scale_;
     }
     delete impl_;
     delete this;
   }

  private:
-  std::vector<float*> embs_;
-  float* bias_;
-  float* scale_;
+  std::vector<void*> embs_;
+  void* bias_{nullptr};
+  void* scale_{nullptr};

   std::vector<int> emb_sizes_;
   int bias_size_;
@@ -345,6 +374,24 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
   void shareGPUData(const EmbEltwiseLayernormPluginDynamic* anthor) {
     impl_->shareGPUData(anthor->impl_);
   }
+
+  template <typename U>
+  void instantiateImpl() {
+    std::vector<U*> embs;
+    embs.resize(embs_.size());
+    for (size_t i = 0; i < embs_.size(); ++i) {
+      embs[i] = reinterpret_cast<U*>(embs_[i]);
+    }
+    impl_ = new EmbEltwiseLayernormPluginDynamicImpl<U>(
+        embs,
+        reinterpret_cast<U*>(bias_),
+        reinterpret_cast<U*>(scale_),
+        emb_sizes_,
+        bias_size_,
+        scale_size_,
+        hidden_size_,
+        eps_);
+  }
 };

 class EmbEltwiseLayernormPluginDynamicCreator
......
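The serialization layout above now leads with with_fp16_, because the deserializing constructor must know whether the raw weight bytes that follow are half or float before it can size and copy them. A small sketch of that dependency (hypothetical helper, not the plugin code):

#include <cuda_fp16.h>

#include <cstddef>

// Number of bytes occupied by a serialized weight section: the element size
// depends on the with_fp16 flag that is written (and read back) first.
size_t SerializedWeightBytes(bool with_fp16, int count) {
  return static_cast<size_t>(count) *
         (with_fp16 ? sizeof(half) : sizeof(float));
}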
@@ -31,31 +31,61 @@ namespace plugin {
 // Dynamic Plugin below.
 #if IS_TRT_VERSION_GE(6000)

-int SkipLayerNormPluginDynamic::initialize() TRT_NOEXCEPT {
-  cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
-  cudaMemcpy(bias_gpu_,
-             bias_.data(),
-             bias_size_ * sizeof(float),
-             cudaMemcpyHostToDevice);
-  cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
-  cudaMemcpy(scale_gpu_,
-             scale_.data(),
-             scale_size_ * sizeof(float),
-             cudaMemcpyHostToDevice);
+template <typename T>
+void SkipLayerNormPluginDynamicImpl<T>::shareGPUData(
+    const SkipLayerNormPluginDynamicImplBase *anthor) {
+  auto *ptr = dynamic_cast<const SkipLayerNormPluginDynamicImpl<T> *>(anthor);
+  if (!ptr->is_initialized_) {
+    return;
+  }
+  scale_gpu_ = ptr->scale_gpu_;
+  bias_gpu_ = ptr->bias_gpu_;
+}
+
+template <typename T>
+int SkipLayerNormPluginDynamicImpl<T>::initialize() {
+  if (is_initialized_) {
+    return 0;
+  }
+  if (bias_) {
+    cudaMalloc(&bias_gpu_, sizeof(T) * bias_size_);
+    cudaMemcpy(
+        bias_gpu_, bias_, bias_size_ * sizeof(T), cudaMemcpyHostToDevice);
+  }
+  if (scale_) {
+    cudaMalloc(&scale_gpu_, sizeof(T) * scale_size_);
+    cudaMemcpy(
+        scale_gpu_, scale_, scale_size_ * sizeof(T), cudaMemcpyHostToDevice);
+  }
+  is_initialized_ = true;
   return 0;
 }

-void SkipLayerNormPluginDynamic::terminate() TRT_NOEXCEPT {
+template <typename T>
+void SkipLayerNormPluginDynamicImpl<T>::terminate() {
   if (bias_gpu_) {
     cudaFree(bias_gpu_);
     bias_gpu_ = nullptr;
   }
   if (scale_gpu_) {
     cudaFree(scale_gpu_);
     scale_gpu_ = nullptr;
   }
 }

+int SkipLayerNormPluginDynamic::initialize() TRT_NOEXCEPT {
+  impl_->initialize();
+  return 0;
+}
+
+void SkipLayerNormPluginDynamic::terminate() TRT_NOEXCEPT {
+  impl_->terminate();
+}
+
 nvinfer1::DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions(
     int output_index,
     const nvinfer1::DimsExprs *inputs,
@@ -73,6 +103,12 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
       in_out,
       platform::errors::InvalidArgument(
           "The input of swish plugin shoule not be nullptr."));
+  PADDLE_ENFORCE_EQ(nb_outputs,
+                    1,
+                    platform::errors::InvalidArgument(
+                        "The SkipLayerNorm's output should be one"
+                        "but it's (%d) outputs.",
+                        nb_outputs));

   PADDLE_ENFORCE_LT(
       pos,
@@ -82,30 +118,27 @@ bool SkipLayerNormPluginDynamic::supportsFormatCombination(
                                         pos,
                                         nb_inputs + nb_outputs));

-  const nvinfer1::PluginTensorDesc &in = in_out[pos];
+  const nvinfer1::PluginTensorDesc &desc = in_out[pos];
   if (pos == 0) {
     if (with_fp16_) {
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
-      return (in.type == nvinfer1::DataType::kFLOAT ||
-              in.type == nvinfer1::DataType::kHALF) &&
-             (in.format == nvinfer1::TensorFormat::kLINEAR);
+      return (desc.type == nvinfer1::DataType::kHALF) &&
+             (desc.format == nvinfer1::TensorFormat::kLINEAR);
 #else
-      return (in.type == nvinfer1::DataType::kFLOAT) &&
-             (in.format == nvinfer1::TensorFormat::kLINEAR);
+      return (desc.type == nvinfer1::DataType::kFLOAT) &&
+             (desc.format == nvinfer1::TensorFormat::kLINEAR);
 #endif
     } else {
-      return (in.type == nvinfer1::DataType::kFLOAT) &&
-             (in.format == nvinfer1::TensorFormat::kLINEAR);
+      return (desc.type == nvinfer1::DataType::kFLOAT) &&
+             (desc.format == nvinfer1::TensorFormat::kLINEAR);
     }
   }
   const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
   if (pos == 1) {
-    return in.type == prev.type && in.format == prev.format;
+    return desc.type == prev.type && desc.format == prev.format;
   }
   // output
-  return in.type == prev.type && in.format == prev.format;
+  return desc.type == prev.type && desc.format == prev.format;
 }

 nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
@@ -115,7 +148,7 @@ nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
   PADDLE_ENFORCE_EQ(index,
                     0,
                     platform::errors::InvalidArgument(
-                        "The SkipLayerNorm Plugin only has one input, so the "
+                        "The SkipLayerNorm Plugin only has one output, so the "
                         "index value should be 0, but get %d.",
                         index));
   PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
@@ -126,7 +159,8 @@ nvinfer1::DataType SkipLayerNormPluginDynamic::getOutputDataType(
   return input_types[0];
 }

-int SkipLayerNormPluginDynamic::enqueue(
+template <typename T>
+int SkipLayerNormPluginDynamicImpl<T>::enqueue(
     const nvinfer1::PluginTensorDesc *input_desc,
     const nvinfer1::PluginTensorDesc *output_desc,
     const void *const *inputs,
@@ -138,51 +172,45 @@ int SkipLayerNormPluginDynamic::enqueue(
   int hidden = input_dims.d[2];
   auto input_type = input_desc[0].type;
-  if (input_type == nvinfer1::DataType::kFLOAT) {
-    VLOG(1) << "TRT Plugin DataType selected. SkipLayerNorm-->fp32";
-    const float *input1 = static_cast<const float *>(inputs[0]);
-    const float *input2 = static_cast<const float *>(inputs[1]);
-    float *output = static_cast<float *>(outputs[0]);
-    operators::math::SkipLayerNormFunctor<float> skip_layer_norm_func;
-    skip_layer_norm_func(num,
-                         hidden,
-                         input1,
-                         input2,
-                         scale_gpu_,
-                         bias_gpu_,
-                         output,
-                         eps_,
-                         stream);
-  } else if (input_type == nvinfer1::DataType::kHALF) {
-#ifdef TRT_PLUGIN_FP16_AVALIABLE
-    VLOG(1) << "TRT Plugin DataType selected. SkipLayerNorm-->fp16";
-    const half *input1 = static_cast<const half *>(inputs[0]);
-    const half *input2 = static_cast<const half *>(inputs[1]);
-    half *output = static_cast<half *>(outputs[0]);
-    operators::math::SkipLayerNormFunctor<half> skip_layer_norm_func;
-    skip_layer_norm_func(num,
-                         hidden,
-                         input1,
-                         input2,
-                         scale_gpu_,
-                         bias_gpu_,
-                         output,
-                         static_cast<half>(eps_),
-                         stream);
-#else
-    PADDLE_THROW(platform::errors::Fatal(
-        "The Ernie(Bert) tensorRT plugin should be "
-        "complied with CUDA version >= 10.0 when running with fp16. "
-        "Please recomplie it or try to use fp32 by set "
-        "config.SetTRTDynamicShapeInfo(min_input_shape, "
-        "max_input_shape, opt_input_shape, true"));
-#endif
+
+  if (std::is_same<T, float>::value) {
+    PADDLE_ENFORCE_EQ(input_type == nvinfer1::DataType::kFLOAT,
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The SkipLayernorm Plugin only support fp32 input."));
+  } else if (std::is_same<T, half>::value) {
+    PADDLE_ENFORCE_EQ(input_type == nvinfer1::DataType::kHALF,
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The SkipLayernorm Plugin only support fp16 input."));
   } else {
     PADDLE_THROW(platform::errors::Fatal(
-        "The SkipLayerNorm TRT Plugin's input type should be float or half."));
+        "Unsupport data type, the out type of SkipLayernorm should be "
+        "float or half."));
   }
+  auto *output_d = reinterpret_cast<T *>(outputs[0]);
+
+  const T *input1 = reinterpret_cast<const T *>(inputs[0]);
+  const T *input2 = reinterpret_cast<const T *>(inputs[1]);
+  auto *output = reinterpret_cast<T *>(outputs[0]);
+  operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;
+  skip_layer_norm_func(
+      num, hidden, input1, input2, scale_gpu_, bias_gpu_, output, eps_, stream);
+  return cudaGetLastError() != cudaSuccess;
+}
+
+int SkipLayerNormPluginDynamic::enqueue(
+    const nvinfer1::PluginTensorDesc *input_desc,
+    const nvinfer1::PluginTensorDesc *output_desc,
+    const void *const *inputs,
+    void *const *outputs,
+    void *workspace,
+    cudaStream_t stream) TRT_NOEXCEPT {
+  impl_->enqueue(input_desc, output_desc, inputs, outputs, workspace, stream);
   return cudaGetLastError() != cudaSuccess;
 }
 #endif

 }  // namespace plugin
......
@@ -15,11 +15,13 @@
 #pragma once
 #include <algorithm>
+#include <cstddef>
 #include <string>
 #include <vector>

 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
+#include "paddle/phi/common/data_type.h"

 namespace paddle {
 namespace inference {
@@ -27,36 +29,155 @@ namespace tensorrt {
 namespace plugin {

 #if IS_TRT_VERSION_GE(6000)
+
+class SkipLayerNormPluginDynamicImplBase {
+ public:
+  SkipLayerNormPluginDynamicImplBase() {}
+  virtual ~SkipLayerNormPluginDynamicImplBase() {}
+
+  virtual int initialize() = 0;
+  virtual void terminate() = 0;
+  virtual int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+                      const nvinfer1::PluginTensorDesc* outputDesc,
+                      const void* const* inputs,
+                      void* const* outputs,
+                      void* workspace,
+                      cudaStream_t stream) = 0;
+  virtual void shareGPUData(
+      const SkipLayerNormPluginDynamicImplBase* anthor) = 0;
+};
+
+template <typename T>
+class SkipLayerNormPluginDynamicImpl
+    : public SkipLayerNormPluginDynamicImplBase {
+ public:
+  explicit SkipLayerNormPluginDynamicImpl(
+      T* bias, T* scale, int bias_size, int scale_size, const float eps)
+      : bias_(bias),
+        scale_(scale),
+        bias_size_(bias_size),
+        scale_size_(scale_size),
+        eps_(eps) {}
+
+  ~SkipLayerNormPluginDynamicImpl() {}
+
+  int initialize();
+  void terminate();
+  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
+              const nvinfer1::PluginTensorDesc* outputDesc,
+              const void* const* inputs,
+              void* const* outputs,
+              void* workspace,
+              cudaStream_t stream) TRT_NOEXCEPT;
+  void shareGPUData(const SkipLayerNormPluginDynamicImplBase* anthor);
+
+ private:
+  T* bias_{nullptr};
+  T* scale_{nullptr};
+
+  // data on devices
+  T* bias_gpu_{nullptr};
+  T* scale_gpu_{nullptr};
+
+  int bias_size_;
+  int scale_size_;
+
+  float eps_;
+  bool is_initialized_{false};
+};
+
 class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
  public:
-  explicit SkipLayerNormPluginDynamic(const float* bias,
-                                      const float* scale,
+  explicit SkipLayerNormPluginDynamic(void* bias,
+                                      void* scale,
                                       int bias_size,
                                       int scale_size,
-                                      const float eps,
+                                      float eps,
                                       bool with_fp16)
-      : bias_size_(bias_size), scale_size_(scale_size), eps_(eps) {
+      : bias_(bias),
+        scale_(scale),
+        bias_size_(bias_size),
+        scale_size_(scale_size),
+        eps_(eps),
+        own_host_buff_(false) {
     with_fp16_ = with_fp16;
-    bias_.resize(bias_size);
-    scale_.resize(scale_size);
-    std::copy(bias, bias + bias_size, bias_.data());
-    std::copy(scale, scale + scale_size, scale_.data());
+    if (with_fp16_) {
+#ifdef TRT_PLUGIN_FP16_AVALIABLE
+      VLOG(1) << "TRT Plugin DataType selected. SkipLayerNorm-->fp16";
+      instantiateImpl<half>();
+#else
+      PADDLE_THROW(platform::errors::Fatal(
+          "The Ernie(Bert) tensorRT plugin should be "
+          "complied with CUDA version >= 10.0 when running with fp16. "
+          "Please recomplie it or try to use fp32 by set "
+          "config.EnableTensorRtEngine(1 << 30, 1, 5, "
+          "AnalysisConfig::Precision::kFloat32, false, false) "));
+#endif
+    } else {
+      VLOG(1) << "TRT Plugin DataType selected. SkipLayerNorm-->fp32";
+      instantiateImpl<float>();
+    }
   }

-  SkipLayerNormPluginDynamic(void const* serial_data, size_t serial_length) {
-    DeserializeValue(&serial_data, &serial_length, &bias_);
-    DeserializeValue(&serial_data, &serial_length, &scale_);
+  SkipLayerNormPluginDynamic(void const* serial_data, size_t serial_length)
+      : own_host_buff_(true) {
+    // the first var is with_fp16, we will use it.
+    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
     DeserializeValue(&serial_data, &serial_length, &bias_size_);
     DeserializeValue(&serial_data, &serial_length, &scale_size_);
     DeserializeValue(&serial_data, &serial_length, &eps_);
-    DeserializeValue(&serial_data, &serial_length, &with_fp16_);
+
+    if (with_fp16_) {
+      if (bias_size_) {
+        bias_ = new half[bias_size_];
+        memcpy(bias_, serial_data, sizeof(half) * bias_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(half);
+      serial_length -= bias_size_ * sizeof(half);
+      if (scale_size_) {
+        scale_ = new half[scale_size_];
+        memcpy(scale_, serial_data, sizeof(half) * scale_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(half);
+      serial_length -= scale_size_ * sizeof(half);
+    } else {
+      if (bias_size_) {
+        bias_ = new float[bias_size_];
+        memcpy(bias_, serial_data, sizeof(float) * bias_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
+      serial_length -= bias_size_ * sizeof(float);
+      if (scale_size_) {
+        scale_ = new float[scale_size_];
+        memcpy(scale_, serial_data, sizeof(float) * scale_size_);
+      }
+      reinterpret_cast<char const*&>(serial_data) +=
+          scale_size_ * sizeof(float);
+      serial_length -= scale_size_ * sizeof(float);
+    }
+    if (with_fp16_) {
+#ifdef TRT_PLUGIN_FP16_AVALIABLE
+      instantiateImpl<half>();
+#else
+      PADDLE_THROW(platform::errors::Fatal(
+          "The Ernie(Bert) tensorRT plugin should be "
+          "complied with CUDA version >= 10.0 when running with fp16. "
+          "Please recomplie it or try to use fp32 by set "
+          "config.EnableTensorRtEngine(1 << 30, 1, 5, "
+          "AnalysisConfig::Precision::kFloat32, false, false) "));
+#endif
+    } else {
+      instantiateImpl<float>();
+    }
   }

   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
     auto ptr = new SkipLayerNormPluginDynamic(
-        bias_.data(), scale_.data(), bias_size_, scale_size_, eps_, with_fp16_);
-    ptr->bias_gpu_ = bias_gpu_;
-    ptr->scale_gpu_ = scale_gpu_;
+        bias_, scale_, bias_size_, scale_size_, eps_, with_fp16_);
+    ptr->shareGPUData(this);
     return ptr;
   }
@@ -65,20 +186,48 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
   }
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
   int initialize() TRT_NOEXCEPT override;
+  void terminate() TRT_NOEXCEPT override;

   size_t getSerializationSize() const TRT_NOEXCEPT override {
-    size_t ser_size = SerializedSize(bias_) + SerializedSize(scale_) +
-                      SerializedSize(bias_size_) + SerializedSize(scale_size_) +
-                      SerializedSize(eps_) + SerializedSize(with_fp16_);
-    return ser_size;
+    size_t sum_num = 0;
+    sum_num += SerializedSize(with_fp16_);
+
+    if (with_fp16_) {
+      sum_num += (bias_size_ + scale_size_) * sizeof(half);
+    } else {
+      sum_num += (bias_size_ + scale_size_) * sizeof(float);
+    }
+
+    sum_num += SerializedSize(bias_size_);
+    sum_num += SerializedSize(scale_size_);
+    sum_num += SerializedSize(eps_);
+
+    return sum_num;
   }
   void serialize(void* buffer) const TRT_NOEXCEPT override {
-    SerializeValue(&buffer, bias_);
-    SerializeValue(&buffer, scale_);
+    // the first var is for with_fp16, we will use it later;
+    SerializeValue(&buffer, with_fp16_);
     SerializeValue(&buffer, bias_size_);
     SerializeValue(&buffer, scale_size_);
     SerializeValue(&buffer, eps_);
-    SerializeValue(&buffer, with_fp16_);
+    if (with_fp16_) {
+      for (int i = 0; i < bias_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<half*>(bias_)[i]);
+      }
+      for (int i = 0; i < scale_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<half*>(scale_)[i]);
+      }
+    } else {
+      for (int i = 0; i < bias_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<float*>(bias_)[i]);
+      }
+      for (int i = 0; i < scale_size_; ++i) {
+        SerializeValue(&buffer, reinterpret_cast<float*>(scale_)[i]);
+      }
+    }
   }

   nvinfer1::DimsExprs getOutputDimensions(int output_index,
@@ -115,20 +264,43 @@ class SkipLayerNormPluginDynamic : public DynamicPluginTensorRT {
                                        int nb_inputs) const
       TRT_NOEXCEPT override;

-  void destroy() TRT_NOEXCEPT override { delete this; }
-  void terminate() TRT_NOEXCEPT override;
+  void destroy() TRT_NOEXCEPT override {
+    if (own_host_buff_) {
+      if (with_fp16_) {
+        delete[] reinterpret_cast<half*>(bias_);
+        delete[] reinterpret_cast<half*>(scale_);
+      } else {
+        delete[] reinterpret_cast<float*>(bias_);
+        delete[] reinterpret_cast<float*>(scale_);
+      }
+    }
+    delete impl_;
+    delete this;
+  }

  private:
-  std::vector<float> bias_;
-  std::vector<float> scale_;
-
-  float* bias_gpu_{nullptr};
-  float* scale_gpu_{nullptr};
+  void* bias_{nullptr};
+  void* scale_{nullptr};

   int bias_size_;
   int scale_size_;

   float eps_;
+  bool own_host_buff_{false};
+
+  SkipLayerNormPluginDynamicImplBase* impl_{nullptr};
+
+  void shareGPUData(const SkipLayerNormPluginDynamic* anthor) {
+    impl_->shareGPUData(anthor->impl_);
+  }
+
+  template <typename U>
+  void instantiateImpl() {
+    impl_ = new SkipLayerNormPluginDynamicImpl<U>(reinterpret_cast<U*>(bias_),
+                                                  reinterpret_cast<U*>(scale_),
+                                                  bias_size_,
+                                                  scale_size_,
+                                                  eps_);
+  }
 };

 class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
@@ -154,8 +326,7 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
                                          const void* serial_data,
                                          size_t serial_length)
       TRT_NOEXCEPT override {
-    auto plugin = new SkipLayerNormPluginDynamic(serial_data, serial_length);
-    return plugin;
+    return new SkipLayerNormPluginDynamic(serial_data, serial_length);
   }

   void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
@@ -173,6 +344,7 @@ class SkipLayerNormPluginDynamicCreator : public nvinfer1::IPluginCreator {
   std::vector<nvinfer1::PluginField> plugin_attributes_;
 };
 REGISTER_TRT_PLUGIN_V2(SkipLayerNormPluginDynamicCreator);
+
 #endif

 }  // namespace plugin
......
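Both plugins now share the same dispatch shape: the plugin class holds a non-templated SkipLayerNormPluginDynamicImplBase pointer and constructs either the half or the float implementation once, so enqueue() stays precision-agnostic. A stripped-down sketch of that pattern (hypothetical names, not the plugin code):

#include <cuda_fp16.h>

#include <memory>

struct ImplBase {
  virtual ~ImplBase() = default;
  virtual int enqueue() = 0;  // would take the usual TensorRT enqueue arguments
};

template <typename T>
struct TypedImpl : ImplBase {
  int enqueue() override {
    // launch the kernel instantiated for T (float or half)
    return 0;
  }
};

class PluginLike {
 public:
  explicit PluginLike(bool with_fp16) {
    if (with_fp16) {
      impl_ = std::make_unique<TypedImpl<half>>();
    } else {
      impl_ = std::make_unique<TypedImpl<float>>();
    }
  }
  int enqueue() { return impl_->enqueue(); }

 private:
  std::unique_ptr<ImplBase> impl_;
};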
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <cuda_fp16.h>
 #include <cstring>
 #include <string>
 #include <type_traits>
@@ -46,10 +47,11 @@ template <typename T, class Enable = void>
 struct Serializer {};

 template <typename T>
-struct Serializer<T,
-                  typename std::enable_if<std::is_arithmetic<T>::value ||
-                                          std::is_enum<T>::value ||
-                                          std::is_pod<T>::value>::type> {
+struct Serializer<
+    T,
+    typename std::enable_if<std::is_arithmetic<T>::value ||
+                            std::is_enum<T>::value || std::is_pod<T>::value ||
+                            std::is_same<T, half>::value>::type> {
   static size_t SerializedSize(T const& value) { return sizeof(T); }

   static void Serialize(void** buffer, T const& value) {
@@ -86,10 +88,11 @@ struct Serializer<const char*> {
 };

 template <typename T>
-struct Serializer<std::vector<T>,
-                  typename std::enable_if<std::is_arithmetic<T>::value ||
-                                          std::is_enum<T>::value ||
-                                          std::is_pod<T>::value>::type> {
+struct Serializer<
+    std::vector<T>,
+    typename std::enable_if<std::is_arithmetic<T>::value ||
+                            std::is_enum<T>::value || std::is_pod<T>::value ||
+                            std::is_same<T, half>::value>::type> {
   static size_t SerializedSize(std::vector<T> const& value) {
     return sizeof(value.size()) + value.size() * sizeof(T);
   }
......
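The Serializer specialization gains std::is_same<T, half> because half (__half from cuda_fp16.h) is a class type to the host compiler, so it is not covered by std::is_arithmetic and, depending on the CUDA toolkit, may not be admitted by the other clauses either. A quick check of that assumption:

#include <cuda_fp16.h>

#include <type_traits>

// half is not a built-in arithmetic type on the host, which is why the
// enable_if filter above needs an explicit std::is_same<T, half> clause.
static_assert(!std::is_arithmetic<half>::value,
              "half is a class type, not a built-in arithmetic type");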
@@ -98,8 +98,9 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
   std::string model_dir = FLAGS_infer_model;
   // Delete serialization cache to perform serialization first rather than
   // deserialization.
-  std::string opt_cache_dir = FLAGS_infer_model + "/_opt_cache";
+  std::string opt_cache_dir = FLAGS_infer_model + "/opt_cache";
   delete_cache_files(opt_cache_dir);
+  config.SetOptimCacheDir(opt_cache_dir);

   SetConfig(&config, model_dir, true /* use_gpu */);
......
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
#include <paddle/fluid/platform/device_context.h> #include <paddle/fluid/platform/device_context.h>
#include <algorithm> #include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle { namespace paddle {
...@@ -99,19 +102,37 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> { ...@@ -99,19 +102,37 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
auto *output_d = out->mutable_data<T>(context.GetPlace()); auto *output_d = out->mutable_data<T>(context.GetPlace());
float eps = context.Attr<float>("epsilon"); float eps = context.Attr<float>("epsilon");
int shared_bytes = input_num * sizeof(int64_t); if (std::is_same<T, paddle::platform::float16>::value) {
math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func; const half *scale_new = reinterpret_cast<const half *>(scale_d);
emb_eltwise_layernorm_func(batch, const half *bias_new = reinterpret_cast<const half *>(bias_d);
seq_len, half *output_new = reinterpret_cast<half *>(output_d);
hidden,
in_ids_d, math::EmbEltwiseLayerNormFunctor<half> emb_eltwise_layernorm_func;
scale_d, emb_eltwise_layernorm_func(batch,
bias_d, seq_len,
in_embs_d, hidden,
output_d, in_ids_d,
eps, scale_new,
input_num, bias_new,
device_ctx.stream()); in_embs_d,
output_new,
eps,
input_num,
device_ctx.stream());
} else {
math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch,
seq_len,
hidden,
in_ids_d,
scale_d,
bias_d,
in_embs_d,
output_d,
eps,
input_num,
device_ctx.stream());
}
} }
}; };
...@@ -119,6 +140,14 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> { ...@@ -119,6 +140,14 @@ class EmbeddingEltWiseLayerNormKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
REGISTER_OP_CUDA_KERNEL(
fused_embedding_eltwise_layernorm,
ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext, float>,
ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext,
paddle::platform::float16>);
#else
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
fused_embedding_eltwise_layernorm, fused_embedding_eltwise_layernorm,
ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext, float>); ops::EmbeddingEltWiseLayerNormKernel<phi::GPUContext, float>);
#endif
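
The float16 registration here (and for skip_layernorm below) is guarded by CUDA_VERSION >= 10000 because, as a kernel comment later in this diff notes, the half arithmetic operators the device code relies on are only available from CUDA 10.0. A small illustrative kernel showing the same guard, assuming nothing beyond the CUDA headers (names are hypothetical):

#include <cuda.h>
#include <cuda_fp16.h>

__global__ void AddHalf(const half* a, const half* b, half* out, int n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 && CUDA_VERSION >= 10000
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = a[i] + b[i];  // native half operator+ (needs CUDA 10.0+ and sm_53+)
  }
#endif
}

Compiling the body out entirely on unsupported toolchains mirrors what the Paddle kernels in this diff do with CUDA_ARCH_FP16_SUPPORTED.
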
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <paddle/fluid/platform/device_context.h> #include <paddle/fluid/platform/device_context.h>
#include <algorithm> #include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
...@@ -53,15 +54,34 @@ class SkipLayerNormKernel : public framework::OpKernel<T> { ...@@ -53,15 +54,34 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
auto &device_ctx = context.template device_context<DeviceContext>(); auto &device_ctx = context.template device_context<DeviceContext>();
operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func; operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;
skip_layer_norm_func(num, if (std::is_same<T, paddle::platform::float16>::value) {
hidden, const half *X_new = reinterpret_cast<const half *>(X_d);
X_d, const half *Y_new = reinterpret_cast<const half *>(Y_d);
Y_d, const half *scale_new = reinterpret_cast<const half *>(scale_d);
scale_d, const half *bias_new = reinterpret_cast<const half *>(bias_d);
bias_d, half *output_new = reinterpret_cast<half *>(output_d);
output_d, operators::math::SkipLayerNormFunctor<half> skip_layer_norm_func;
epsilon, skip_layer_norm_func(num,
device_ctx.stream()); hidden,
X_new,
Y_new,
scale_new,
bias_new,
output_new,
epsilon,
device_ctx.stream());
} else {
operators::math::SkipLayerNormFunctor<T> skip_layer_norm_func;
skip_layer_norm_func(num,
hidden,
X_d,
Y_d,
scale_d,
bias_d,
output_d,
epsilon,
device_ctx.stream());
}
} }
}; };
...@@ -69,5 +89,13 @@ class SkipLayerNormKernel : public framework::OpKernel<T> { ...@@ -69,5 +89,13 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
REGISTER_OP_CUDA_KERNEL(
skip_layernorm,
ops::SkipLayerNormKernel<phi::GPUContext, float>,
ops::SkipLayerNormKernel<phi::GPUContext, paddle::platform::float16>);
#else
REGISTER_OP_CUDA_KERNEL(skip_layernorm, REGISTER_OP_CUDA_KERNEL(skip_layernorm,
ops::SkipLayerNormKernel<phi::GPUContext, float>); ops::SkipLayerNormKernel<phi::GPUContext, float>);
#endif
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <type_traits>
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
...@@ -42,8 +43,8 @@ __device__ inline void LayerNormSmall(T val, ...@@ -42,8 +43,8 @@ __device__ inline void LayerNormSmall(T val,
const phi::funcs::kvp<T> &thread_data, const phi::funcs::kvp<T> &thread_data,
const int ld, const int ld,
const int idx, const int idx,
const float *bias, const T *bias,
const float *scale, const T *scale,
T *output, T *output,
T eps) { T eps) {
using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>; using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
...@@ -70,8 +71,8 @@ template <typename T, int TPB> ...@@ -70,8 +71,8 @@ template <typename T, int TPB>
__device__ inline void LayerNorm(const phi::funcs::kvp<T> &thread_data, __device__ inline void LayerNorm(const phi::funcs::kvp<T> &thread_data,
const int ld, const int ld,
const int offset, const int offset,
const float *bias, const T *bias,
const float *scale, const T *scale,
T *output, T *output,
T eps) { T eps) {
using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>; using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
...@@ -100,8 +101,8 @@ template <typename T, typename T2, int TPB> ...@@ -100,8 +101,8 @@ template <typename T, typename T2, int TPB>
__device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data, __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
const int ld, const int ld,
const int offset, const int offset,
const float2 *bias, const T2 *bias,
const float2 *scale, const T2 *scale,
T2 *output, T2 *output,
T eps) { T eps) {
using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>; using BlockReduce = cub::BlockReduce<phi::funcs::kvp<T>, TPB>;
...@@ -120,8 +121,8 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data, ...@@ -120,8 +121,8 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
for (int i = threadIdx.x; i < ld; i += TPB) { for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i; const int idx = offset + i;
T2 val = output[idx]; T2 val = output[idx];
const float2 g = scale[i]; const T2 g = scale[i];
const float2 b = bias[i]; const T2 b = bias[i];
val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x); val.x = T(g.x) * (val.x - mu) * rsigma + T(b.x);
val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y); val.y = T(g.y) * (val.y - mu) * rsigma + T(b.y);
output[idx] = val; output[idx] = val;
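
For reference, the affine step applied by LayerNorm, LayerNormSmall, and LayerNorm2 above is the standard layer normalization, where mu is the mean over the hidden dimension, rsigma the reciprocal standard deviation produced by the block reduction, and scale/bias the per-channel gamma and beta:

y_i = \gamma_i \cdot \frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i

The type change in this hunk only swaps the precision in which gamma and beta are stored (float2 to T2); the formula itself is unchanged.
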
...@@ -131,11 +132,11 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data, ...@@ -131,11 +132,11 @@ __device__ inline void LayerNorm2(const phi::funcs::kvp<T> &thread_data,
template <typename T, unsigned TPB> template <typename T, unsigned TPB>
__global__ void EmbEltwiseLayernormKernel(int hidden, __global__ void EmbEltwiseLayernormKernel(int hidden,
const int64_t *ids, const int64_t *ids,
const float *scale, const T *scale,
const float *bias, const T *bias,
const int64_t *embs, const int64_t *embs,
T *output, T *output,
float eps, T eps,
int input_num) { int input_num) {
cub::Sum pair_sum; cub::Sum pair_sum;
// blockIdx.x: position in the sequence // blockIdx.x: position in the sequence
...@@ -179,11 +180,11 @@ __global__ void EmbEltwiseLayernormKernel(int hidden, ...@@ -179,11 +180,11 @@ __global__ void EmbEltwiseLayernormKernel(int hidden,
template <> template <>
__global__ void EmbEltwiseLayernormKernel<half, 256>(int hidden, __global__ void EmbEltwiseLayernormKernel<half, 256>(int hidden,
const int64_t *ids, const int64_t *ids,
const float *scale, const half *scale,
const float *bias, const half *bias,
const int64_t *embs, const int64_t *embs,
half *output, half *output,
float eps, half eps,
int input_num) { int input_num) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
cub::Sum pair_sum; cub::Sum pair_sum;
...@@ -231,8 +232,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(int batch, ...@@ -231,8 +232,8 @@ void EmbEltwiseLayerNormFunctor<T>::operator()(int batch,
int seq_len, int seq_len,
int hidden, int hidden,
const int64_t *ids, const int64_t *ids,
const float *scale, const T *scale,
const float *bias, const T *bias,
const int64_t *embs, const int64_t *embs,
T *output, T *output,
float eps, float eps,
...@@ -720,9 +721,9 @@ __global__ void SkipLayerNormSmallKernel(int num, ...@@ -720,9 +721,9 @@ __global__ void SkipLayerNormSmallKernel(int num,
const T *input1, const T *input1,
const T *input2, const T *input2,
T *output, T *output,
const float *scale, const T *scale,
const float *bias, const T *bias,
float eps) { T eps) {
const T rld = T(1) / T(hidden); const T rld = T(1) / T(hidden);
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
cub::Sum pair_sum; cub::Sum pair_sum;
...@@ -747,9 +748,9 @@ __global__ void SkipLayerNormSmallKernel<half, 32>(int num, ...@@ -747,9 +748,9 @@ __global__ void SkipLayerNormSmallKernel<half, 32>(int num,
const half *input1, const half *input1,
const half *input2, const half *input2,
half *output, half *output,
const float *scale, const half *scale,
const float *bias, const half *bias,
float eps) { half eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const half rld = half(1) / half(hidden); const half rld = half(1) / half(hidden);
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
...@@ -774,9 +775,9 @@ __global__ void SkipLayerNormSmallKernel<half, 128>(int num, ...@@ -774,9 +775,9 @@ __global__ void SkipLayerNormSmallKernel<half, 128>(int num,
const half *input1, const half *input1,
const half *input2, const half *input2,
half *output, half *output,
const float *scale, const half *scale,
const float *bias, const half *bias,
float eps) { half eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const half rld = half(1) / half(hidden); const half rld = half(1) / half(hidden);
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
...@@ -801,9 +802,9 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(int num, ...@@ -801,9 +802,9 @@ __global__ void SkipLayerNormSmallKernel<half, 384>(int num,
const half *input1, const half *input1,
const half *input2, const half *input2,
half *output, half *output,
const float *scale, const half *scale,
const float *bias, const half *bias,
float eps) { half eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const half rld = half(1) / half(hidden); const half rld = half(1) / half(hidden);
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
...@@ -829,9 +830,9 @@ __global__ void SkipLayerNormKernel(int num, ...@@ -829,9 +830,9 @@ __global__ void SkipLayerNormKernel(int num,
const T *input1, const T *input1,
const T *input2, const T *input2,
T *output, T *output,
const float *scale, const T *scale,
const float *bias, const T *bias,
float eps) { T eps) {
const T rld = T(1) / T(hidden); const T rld = T(1) / T(hidden);
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
cub::Sum pair_sum; cub::Sum pair_sum;
...@@ -856,9 +857,9 @@ __global__ void SkipLayerNormKernel<half, 256>(int num, ...@@ -856,9 +857,9 @@ __global__ void SkipLayerNormKernel<half, 256>(int num,
const half *input1, const half *input1,
const half *input2, const half *input2,
half *output, half *output,
const float *scale, const half *scale,
const float *bias, const half *bias,
float eps) { half eps) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
const half rld = half(1) / half(hidden); const half rld = half(1) / half(hidden);
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
...@@ -884,8 +885,8 @@ __global__ void SkipLayerNormKernel2(int num, ...@@ -884,8 +885,8 @@ __global__ void SkipLayerNormKernel2(int num,
const T2 *input1, const T2 *input1,
const T2 *input2, const T2 *input2,
T2 *output, T2 *output,
const float2 *scale, const T2 *scale,
const float2 *bias, const T2 *bias,
float eps) { float eps) {
const T rld = T(0.5f / hidden); // because hidden is hidden/2 const T rld = T(0.5f / hidden); // because hidden is hidden/2
const int offset = blockIdx.x * hidden; const int offset = blockIdx.x * hidden;
...@@ -912,8 +913,8 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(int num, ...@@ -912,8 +913,8 @@ __global__ void SkipLayerNormKernel2<half, half2, 256>(int num,
const half2 *input1, const half2 *input1,
const half2 *input2, const half2 *input2,
half2 *output, half2 *output,
const float2 *scale, const half2 *scale,
const float2 *bias, const half2 *bias,
float eps) { float eps) {
// operator "+" of half only suppotted after cuda version 10.0 // operator "+" of half only suppotted after cuda version 10.0
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000
...@@ -942,10 +943,10 @@ void SkipLayerNormFunctor<T>::operator()(const int num, ...@@ -942,10 +943,10 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
const int hidden, const int hidden,
const T *input1, const T *input1,
const T *input2, const T *input2,
const float *scale, const T *scale,
const float *bias, const T *bias,
T *output, T *output,
T eps, float eps,
gpuStream_t stream) { gpuStream_t stream) {
int block = num / hidden; int block = num / hidden;
if (hidden <= 32) { if (hidden <= 32) {
...@@ -984,8 +985,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num, ...@@ -984,8 +985,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
reinterpret_cast<const __half2 *>(input1), reinterpret_cast<const __half2 *>(input1),
reinterpret_cast<const __half2 *>(input2), reinterpret_cast<const __half2 *>(input2),
reinterpret_cast<__half2 *>(output), reinterpret_cast<__half2 *>(output),
reinterpret_cast<const float2 *>(scale), reinterpret_cast<const __half2 *>(scale),
reinterpret_cast<const float2 *>(bias), reinterpret_cast<const __half2 *>(bias),
eps); eps);
#endif #endif
} else { } else {
......
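
The hunk above is the vectorized half path: when hidden is even, the functor halves hidden and reinterprets every buffer -- now including scale and bias, which previously stayed in float2 -- as __half2, so each thread touches two fp16 values at once. A small self-contained sketch of that access pattern follows; it is illustrative only (the real kernel also performs the layer-norm reduction), and all names are hypothetical.

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Applies out = scale * x + bias on pairs of halves.
__global__ void ScaleBiasHalf2(const __half2* x, const __half2* scale,
                               const __half2* bias, __half2* out, int n2) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 && CUDA_VERSION >= 10000
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    // one fused multiply-add covers two fp16 lanes
    out[i] = __hfma2(scale[i], x[i], bias[i]);
  }
#endif
}

// Host-side launch: num halves become num/2 __half2 elements.
void Launch(const half* x, const half* scale, const half* bias, half* out,
            int num, cudaStream_t stream) {
  int n2 = num / 2;
  int threads = 256;
  int blocks = (n2 + threads - 1) / threads;
  ScaleBiasHalf2<<<blocks, threads, 0, stream>>>(
      reinterpret_cast<const __half2*>(x),
      reinterpret_cast<const __half2*>(scale),
      reinterpret_cast<const __half2*>(bias),
      reinterpret_cast<__half2*>(out), n2);
}
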
...@@ -68,8 +68,8 @@ class EmbEltwiseLayerNormFunctor { ...@@ -68,8 +68,8 @@ class EmbEltwiseLayerNormFunctor {
int seq_len, int seq_len,
int hidden, int hidden,
const int64_t *ids, const int64_t *ids,
const float *scale, const T *scale,
const float *bias, const T *bias,
const int64_t *embs, const int64_t *embs,
T *output, T *output,
float eps, float eps,
...@@ -125,10 +125,10 @@ class SkipLayerNormFunctor { ...@@ -125,10 +125,10 @@ class SkipLayerNormFunctor {
const int hidden, const int hidden,
const T *input1, const T *input1,
const T *input2, const T *input2,
const float *scale, const T *scale,
const float *bias, const T *bias,
T *output, T *output,
T eps, float eps,
gpuStream_t stream); gpuStream_t stream);
}; };
#endif #endif
......
...@@ -562,6 +562,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -562,6 +562,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
} }
runtime_batch = t_shape[0]; runtime_batch = t_shape[0];
VLOG(1) << "trt input [" << x << "] dtype is " << t.dtype(); VLOG(1) << "trt input [" << x << "] dtype is " << t.dtype();
auto indata_type = inference::tensorrt::PhiType2NvType(t.dtype()); auto indata_type = inference::tensorrt::PhiType2NvType(t.dtype());
auto intrt_index = engine->engine()->getBindingIndex(x.c_str()); auto intrt_index = engine->engine()->getBindingIndex(x.c_str());
auto intrt_type = engine->engine()->getBindingDataType(intrt_index); auto intrt_type = engine->engine()->getBindingDataType(intrt_index);
...@@ -570,6 +571,7 @@ class TensorRTEngineOp : public framework::OperatorBase { ...@@ -570,6 +571,7 @@ class TensorRTEngineOp : public framework::OperatorBase {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The TRT Engine OP's input type should equal " "The TRT Engine OP's input type should equal "
"to the input data type")); "to the input data type"));
auto type = framework::TransToProtoVarType(t.dtype()); auto type = framework::TransToProtoVarType(t.dtype());
if (type == framework::proto::VarType::FP32) { if (type == framework::proto::VarType::FP32) {
buffers[bind_index] = static_cast<void *>(t.data<float>()); buffers[bind_index] = static_cast<void *>(t.data<float>());
......