未验证 提交 a5ef246c 编写于 作者: P Pei Yang 提交者: GitHub

Optimize emb_eltwise_layernorm_plugin and support fp16 (#27128)

上级 4c5cfdea
......@@ -107,6 +107,9 @@ function(select_nvcc_arch_flags out_variable)
elseif(${CUDA_ARCH_NAME} STREQUAL "Maxwell")
set(cuda_arch_bin "50")
elseif(${CUDA_ARCH_NAME} STREQUAL "Pascal")
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
add_definitions("-DSUPPORTS_CUDA_FP16")
endif()
set(cuda_arch_bin "60 61")
elseif(${CUDA_ARCH_NAME} STREQUAL "Volta")
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
......
......@@ -80,10 +80,10 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
plugin::DynamicPluginTensorRT* plugin = nullptr;
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
auto use_fp16 = engine_->WithFp16();
auto plugin = new plugin::EmbEltwiseLayernormPluginDynamic(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
eps, use_fp16);
layer = engine_->AddPluginV2(input_ids.data(), input_num, plugin);
} else {
PADDLE_THROW(platform::errors::Fatal(
......
......@@ -32,13 +32,34 @@ namespace plugin {
#if IS_TRT_VERSION_GE(6000)
template <typename T>
int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
EmbEltwiseLayernormPluginDynamicImpl<
T>::~EmbEltwiseLayernormPluginDynamicImpl() {
this->terminate();
}
inline half fp32tofp16(float x) { return static_cast<half>(x); }
template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::initialize() {
embs_gpu_.resize(embs_.size());
for (int i = 0; i < embs_.size(); i++) {
if (embs_[i]) {
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
cudaMemcpy(embs_gpu_[i], embs_[i], emb_sizes_[i] * sizeof(float),
T *host_ptr;
auto size = emb_sizes_[i];
if (std::is_same<T, half>::value) {
host_ptr = new T[size];
std::transform(embs_[i], (embs_[i] + size), host_ptr, fp32tofp16);
} else {
host_ptr = reinterpret_cast<T *>(embs_[i]);
}
cudaMalloc(&embs_gpu_[i], sizeof(T) * size);
cudaMemcpy(embs_gpu_[i], host_ptr, size * sizeof(T),
cudaMemcpyHostToDevice);
if (std::is_same<T, half>::value) {
delete[] host_ptr;
}
}
}
......@@ -53,11 +74,105 @@ int EmbEltwiseLayernormPluginDynamic<T>::initialize() {
cudaMemcpyHostToDevice);
}
int input_num = embs_.size();
in_ptr_tensor_.Resize({input_num});
emb_ptr_tensor_.Resize({input_num});
cudaGetDevice(&device_id_);
auto emb_ptr_gpu_d =
emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
cudaMemcpy(emb_ptr_gpu_d, embs_gpu_.data(), sizeof(uintptr_t) * input_num,
cudaMemcpyHostToDevice);
return 0;
}
template <typename T>
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
void EmbEltwiseLayernormPluginDynamicImpl<T>::terminate() {
for (int i = 0; i < embs_gpu_.size(); ++i) {
if (embs_gpu_[i]) {
cudaFree(embs_gpu_[i]);
embs_gpu_[i] = nullptr;
}
}
if (bias_gpu_) {
cudaFree(bias_gpu_);
bias_gpu_ = nullptr;
}
if (scale_gpu_) {
cudaFree(scale_gpu_);
scale_gpu_ = nullptr;
}
}
template <typename T>
int EmbEltwiseLayernormPluginDynamicImpl<T>::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto id_dims = input_desc[0].dims;
int batch = id_dims.d[0];
int seq_len = id_dims.d[1];
int input_num = embs_.size();
auto in_ptr_gpu_d =
in_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
auto emb_ptr_gpu_d =
emb_ptr_tensor_.mutable_data<int64_t>(platform::CUDAPlace(device_id_));
auto new_input_ptr = reinterpret_cast<uintptr_t>(inputs[0]);
if (old_input_ptr_ != new_input_ptr) {
old_input_ptr_ = new_input_ptr;
cudaMemcpyAsync(in_ptr_gpu_d, reinterpret_cast<const void *>(inputs),
sizeof(uintptr_t) * input_num, cudaMemcpyHostToDevice,
stream);
}
auto out_type = output_desc[0].type;
if (std::is_same<T, float>::value) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp32 input."));
} else if (std::is_same<T, half>::value) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kHALF, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp16 input."));
} else {
PADDLE_THROW(platform::errors::Fatal(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."));
}
auto *output_d = reinterpret_cast<T *>(outputs[0]);
operators::math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
eps_, input_num, stream);
return cudaGetLastError() != cudaSuccess;
}
template class EmbEltwiseLayernormPluginDynamicImpl<float>;
#ifdef SUPPORTS_CUDA_FP16
template class EmbEltwiseLayernormPluginDynamicImpl<half>;
#endif // SUPPORTS_CUDA_FP16
int EmbEltwiseLayernormPluginDynamic::initialize() {
impl_->initialize();
return 0;
}
void EmbEltwiseLayernormPluginDynamic::terminate() { impl_->terminate(); }
nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
nvinfer1::IExprBuilder &expr_builder) { // NOLINT
PADDLE_ENFORCE_EQ(output_index, 0,
......@@ -76,18 +191,7 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
return ret;
}
template <typename T>
void EmbEltwiseLayernormPluginDynamic<T>::terminate() {
for (auto ptr : embs_gpu_) {
if (ptr) cudaFree(ptr);
}
if (bias_gpu_) cudaFree(bias_gpu_);
if (scale_gpu_) cudaFree(scale_gpu_);
}
template <typename T>
bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
bool EmbEltwiseLayernormPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc *in_out, int nb_inputs,
int nb_outputs) {
PADDLE_ENFORCE_NOT_NULL(
......@@ -98,6 +202,11 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs.",
nb_outputs));
PADDLE_ENFORCE_EQ(nb_outputs, 1,
platform::errors::InvalidArgument(
"The EmbEltwiseLayerNorm's output should be one"
"but it's (%d) outputs.",
nb_outputs));
PADDLE_ENFORCE_LT(
pos, nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
......@@ -122,7 +231,7 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
}
if (pos == all_nums - 1) {
if (sizeof(T) == sizeof(float)) {
if (with_fp16_ == false) {
return desc.type == nvinfer1::DataType::kFLOAT;
} else {
return desc.type == nvinfer1::DataType::kHALF;
......@@ -131,84 +240,27 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
return false;
}
template <typename T>
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic<T>::getOutputDataType(
nvinfer1::DataType EmbEltwiseLayernormPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType *input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(
index, 0, platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return nvinfer1::DataType::kFLOAT;
if (with_fp16_)
return nvinfer1::DataType::kHALF;
else
return nvinfer1::DataType::kFLOAT;
}
template <typename T>
int EmbEltwiseLayernormPluginDynamic<T>::enqueue(
int EmbEltwiseLayernormPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc *input_desc,
const nvinfer1::PluginTensorDesc *output_desc, const void *const *inputs,
void *const *outputs, void *workspace, cudaStream_t stream) {
auto id_dims = input_desc[0].dims;
int batch = id_dims.d[0];
int seq_len = id_dims.d[1];
int input_num = embs_.size();
framework::Tensor in_ptr_tensor, emb_ptr_tensor;
int device_id;
cudaGetDevice(&device_id);
in_ptr_tensor.Resize({input_num});
emb_ptr_tensor.Resize({input_num});
int64_t *in_ptr_gpu_d =
in_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
int64_t *emb_ptr_gpu_d =
emb_ptr_tensor.mutable_data<int64_t>(platform::CUDAPlace(device_id));
std::vector<uintptr_t> in_ptr, emb_ptr;
for (int i = 0; i < input_num; i++) {
in_ptr.push_back(reinterpret_cast<uintptr_t>(inputs[i]));
emb_ptr.push_back(reinterpret_cast<uintptr_t>(embs_gpu_[i]));
}
cudaMemcpyAsync(in_ptr_gpu_d, in_ptr.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, stream);
cudaMemcpyAsync(emb_ptr_gpu_d, emb_ptr.data(), sizeof(int64_t) * input_num,
cudaMemcpyHostToDevice, stream);
auto out_type = output_desc[0].type;
const unsigned tpb = 256;
const dim3 grid(seq_len, batch, 1);
const dim3 block(tpb, 1, 1);
if (sizeof(T) == sizeof(float)) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kFLOAT, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp32 input."));
} else if (sizeof(T) == sizeof(int16_t)) {
PADDLE_ENFORCE_EQ(
out_type == nvinfer1::DataType::kHALF, true,
platform::errors::InvalidArgument(
"The EmbEltwiseLayernorm Plugin only support fp16 input."));
} else {
PADDLE_THROW(platform::errors::Fatal(
"Unsupport data type, the out type of EmbEltwiseLayernorm should be "
"float or half."));
}
T *output_d = static_cast<T *>(outputs[0]);
operators::math::EmbEltwiseLayerNormFunctor<T> emb_eltwise_layernorm_func;
emb_eltwise_layernorm_func(batch, seq_len, hidden_size_, in_ptr_gpu_d,
scale_gpu_, bias_gpu_, emb_ptr_gpu_d, output_d,
eps_, input_num, stream);
impl_->enqueue(input_desc, output_desc, inputs, outputs, workspace, stream);
return cudaGetLastError() != cudaSuccess;
}
template class EmbEltwiseLayernormPluginDynamic<float>;
#ifdef SUPPORTS_CUDA_FP16
template class EmbEltwiseLayernormPluginDynamic<half>;
#endif // SUPPORTS_CUDA_FP16
#endif
} // namespace plugin
......
......@@ -27,14 +27,76 @@ namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class EmbEltwiseLayernormPluginDynamicImplBase {
public:
EmbEltwiseLayernormPluginDynamicImplBase() {}
virtual ~EmbEltwiseLayernormPluginDynamicImplBase() {}
virtual int initialize() = 0;
virtual void terminate() = 0;
virtual int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs,
void* workspace, cudaStream_t stream) = 0;
};
template <typename T>
class EmbEltwiseLayernormPluginDynamicImpl
: public EmbEltwiseLayernormPluginDynamicImplBase {
public:
explicit EmbEltwiseLayernormPluginDynamicImpl(std::vector<float*> input_embs,
float* bias, float* scale,
std::vector<int> emb_sizes,
int bias_size, int scale_size,
int hidden_size, float eps)
: embs_(input_embs),
bias_(bias),
scale_(scale),
emb_sizes_(emb_sizes),
bias_size_(bias_size),
scale_size_(scale_size),
hidden_size_(hidden_size),
eps_(eps) {}
~EmbEltwiseLayernormPluginDynamicImpl();
int initialize();
void terminate();
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream);
private:
std::vector<float*> embs_;
float* bias_{nullptr};
float* scale_{nullptr};
// data on devices
float* bias_gpu_{nullptr};
float* scale_gpu_{nullptr};
std::vector<T*> embs_gpu_;
std::vector<int> emb_sizes_;
int bias_size_;
int scale_size_;
int hidden_size_;
float eps_;
framework::Tensor in_ptr_tensor_, emb_ptr_tensor_;
int device_id_{0};
uintptr_t old_input_ptr_{0};
};
class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
public:
explicit EmbEltwiseLayernormPluginDynamic(std::vector<float*> input_embs,
float* bias, float* scale,
std::vector<int> emb_sizes,
int bias_size, int scale_size,
int hidden_size, float eps)
int hidden_size, float eps,
bool with_fp16)
: embs_(input_embs),
bias_(bias),
scale_(scale),
......@@ -42,51 +104,81 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
bias_size_(bias_size),
scale_size_(scale_size),
hidden_size_(hidden_size),
eps_(eps) {}
eps_(eps),
with_fp16_(with_fp16),
own_host_buff_(false) {
if (with_fp16) {
#ifdef SUPPORTS_CUDA_FP16
impl_ = new EmbEltwiseLayernormPluginDynamicImpl<half>(
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_,
hidden_size_, eps_);
#else
PADDLE_THROW(platform::errors::Fatal(
"Unsupported data type, current GPU doesn't support half."));
#endif // SUPPORTS_CUDA_FP16
} else {
impl_ = new EmbEltwiseLayernormPluginDynamicImpl<float>(
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_,
hidden_size_, eps_);
}
}
EmbEltwiseLayernormPluginDynamic(void const* serial_data,
size_t serial_length) {
size_t serial_length)
: own_host_buff_(true) {
DeserializeValue(&serial_data, &serial_length, &emb_sizes_);
embs_gpu_.resize(emb_sizes_.size());
embs_.resize(emb_sizes_.size());
for (size_t i = 0; i < emb_sizes_.size(); i++) {
cudaMalloc(&embs_gpu_[i], sizeof(float) * emb_sizes_[i]);
cudaMemcpy(embs_gpu_[i], serial_data, emb_sizes_[i] * sizeof(float),
cudaMemcpyHostToDevice);
auto size = emb_sizes_[i];
auto ptr = new float[size];
memcpy(ptr, serial_data, sizeof(float) * size);
embs_[i] = ptr;
reinterpret_cast<char const*&>(serial_data) +=
emb_sizes_[i] * sizeof(float);
serial_length -= emb_sizes_[i] * sizeof(float);
embs_[i] = nullptr;
}
DeserializeValue(&serial_data, &serial_length, &bias_size_);
DeserializeValue(&serial_data, &serial_length, &scale_size_);
cudaMalloc(&bias_gpu_, sizeof(float) * bias_size_);
cudaMemcpy(bias_gpu_, serial_data, bias_size_ * sizeof(float),
cudaMemcpyHostToDevice);
bias_ = nullptr;
if (bias_size_) {
bias_ = new float[bias_size_];
memcpy(bias_, serial_data, sizeof(float) * bias_size_);
}
reinterpret_cast<char const*&>(serial_data) += bias_size_ * sizeof(float);
serial_length -= bias_size_ * sizeof(float);
cudaMalloc(&scale_gpu_, sizeof(float) * scale_size_);
cudaMemcpy(scale_gpu_, serial_data, scale_size_ * sizeof(float),
cudaMemcpyHostToDevice);
scale_ = nullptr;
if (scale_size_) {
scale_ = new float[scale_size_];
memcpy(scale_, serial_data, sizeof(float) * scale_size_);
}
reinterpret_cast<char const*&>(serial_data) += scale_size_ * sizeof(float);
serial_length -= scale_size_ * sizeof(float);
DeserializeValue(&serial_data, &serial_length, &hidden_size_);
DeserializeValue(&serial_data, &serial_length, &eps_);
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
if (with_fp16_) {
#ifdef SUPPORTS_CUDA_FP16
impl_ = new EmbEltwiseLayernormPluginDynamicImpl<half>(
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_,
hidden_size_, eps_);
#else
PADDLE_THROW(platform::errors::Fatal(
"Unsupported data type, current GPU doesn't support half."));
#endif // SUPPORTS_CUDA_FP16
} else {
impl_ = new EmbEltwiseLayernormPluginDynamicImpl<float>(
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_,
hidden_size_, eps_);
}
}
nvinfer1::IPluginV2DynamicExt* clone() const override {
auto ptr = new EmbEltwiseLayernormPluginDynamic(
embs_, bias_, scale_, emb_sizes_, bias_size_, scale_size_, hidden_size_,
eps_);
ptr->embs_gpu_ = embs_gpu_;
ptr->bias_gpu_ = bias_gpu_;
ptr->scale_gpu_ = scale_gpu_;
eps_, with_fp16_);
return ptr;
}
......@@ -95,6 +187,7 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
}
int getNbOutputs() const override { return 1; }
int initialize() override;
void terminate() override;
size_t getSerializationSize() const override {
int sum_num = 0;
......@@ -110,24 +203,32 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
sum_num += (bias_size_ + scale_size_) * sizeof(float);
sum_num += SerializedSize(hidden_size_);
sum_num += SerializedSize(eps_);
// sum_num += SerializedSize(with_fp16_);
sum_num += SerializedSize(with_fp16_);
return sum_num;
}
void terminate() override;
void serialize(void* buffer) const override {
// SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, emb_sizes_);
for (size_t i = 0; i < emb_sizes_.size(); i++) {
SerializeCudaPointer(&buffer, embs_gpu_[i], emb_sizes_[i]);
auto size = emb_sizes_[i];
for (int j = 0; j < size; ++j) {
SerializeValue(&buffer, embs_[i][j]);
}
}
SerializeValue(&buffer, bias_size_);
SerializeValue(&buffer, scale_size_);
SerializeCudaPointer(&buffer, bias_gpu_, bias_size_);
SerializeCudaPointer(&buffer, scale_gpu_, scale_size_);
for (int i = 0; i < bias_size_; ++i) {
SerializeValue(&buffer, bias_[i]);
}
for (int i = 0; i < scale_size_; ++i) {
SerializeValue(&buffer, scale_[i]);
}
SerializeValue(&buffer, hidden_size_);
SerializeValue(&buffer, eps_);
SerializeValue(&buffer, with_fp16_);
}
nvinfer1::DimsExprs getOutputDimensions(
......@@ -158,23 +259,33 @@ class EmbEltwiseLayernormPluginDynamic : public DynamicPluginTensorRT {
const nvinfer1::DataType* input_types,
int nb_inputs) const override;
void destroy() override { delete this; }
void destroy() override {
if (own_host_buff_) {
for (auto ptr : embs_) {
delete[] ptr;
}
delete[] bias_;
delete[] scale_;
}
delete impl_;
delete this;
}
private:
std::vector<float*> embs_;
float* bias_;
float* scale_;
// data on devices
float* bias_gpu_;
float* scale_gpu_;
std::vector<float*> embs_gpu_;
std::vector<int> emb_sizes_;
int bias_size_;
int scale_size_;
int hidden_size_;
float eps_;
bool with_fp16_;
bool own_host_buff_{false};
EmbEltwiseLayernormPluginDynamicImplBase* impl_{nullptr};
};
class EmbEltwiseLayernormPluginV2Creator : public nvinfer1::IPluginCreator {
......@@ -198,8 +309,7 @@ class EmbEltwiseLayernormPluginV2Creator : public nvinfer1::IPluginCreator {
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
return new EmbEltwiseLayernormPluginDynamic<float>(serial_data,
serial_length);
return new EmbEltwiseLayernormPluginDynamic(serial_data, serial_length);
}
void setPluginNamespace(const char* lib_namespace) override {
......
......@@ -151,7 +151,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
run(config, &out_data); // serialize
run(*config_deser, &out_data); // deserialize
for (size_t i = 0; i < out_data.size(); i++) {
EXPECT_NEAR(result[i], out_data[i], 1e-6);
EXPECT_NEAR(result[i], out_data[i], 1e-2);
}
}
......@@ -159,13 +159,11 @@ TEST(AnalysisPredictor, no_fp16) {
std::vector<float> result = {0.597841, 0.219972, 0.182187};
trt_ernie(false, result);
}
TEST(AnalysisPredictor, fp16) {
#ifdef SUPPORTS_CUDA_FP16
std::vector<float> result = {0.598336, 0.219558, 0.182106};
TEST(AnalysisPredictor, fp16) {
std::vector<float> result = {0.59923654, 0.21923761, 0.18152587};
trt_ernie(true, result);
#endif
}
#endif // SUPPORTS_CUDA_FP16
} // namespace inference
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册