Unverified · Commit b7a1ae22 authored by ccrrong, committed by GitHub

add layer_norm trt fp16 support (#45043)

* add fp16 support

* update

* update half

* code format

* fix unittest

* fix rocm compile error

* code format

* code format

* fix rocm compile error

* fix rocm compile error
Parent dc31d2aa
@@ -60,6 +60,8 @@ class LayerNormOpConverter : public OpConverter {
     // the shape of mean and variance will be determine in configuPlugin.
     std::vector<int64_t> mean_shape{1};
     std::vector<int64_t> variance_shape{1};
+    bool with_fp16 =
+        engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     plugin::LayerNormPluginDynamic* plugin =
         new plugin::LayerNormPluginDynamic(
             static_cast<const float*>(bias_weight.get().values),
@@ -69,7 +71,8 @@ class LayerNormOpConverter : public OpConverter {
             begin_norm_axis,
             eps,
             mean_shape,
-            variance_shape);
+            variance_shape,
+            with_fp16);
     layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
   } else {
     int statis_num = 1;
@@ -78,6 +81,8 @@ class LayerNormOpConverter : public OpConverter {
     }
     std::vector<int64_t> mean_shape{statis_num};
     std::vector<int64_t> variance_shape{statis_num};
+    bool with_fp16 =
+        engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
     plugin::LayerNormPlugin* plugin = new plugin::LayerNormPlugin(
         static_cast<const float*>(bias_weight.get().values),
         bias_weight.get().count,
@@ -86,7 +91,8 @@ class LayerNormOpConverter : public OpConverter {
         begin_norm_axis,
         eps,
         mean_shape,
-        variance_shape);
+        variance_shape,
+        with_fp16);
     layernorm_layer = engine_->AddPlugin(
         &X, 1, reinterpret_cast<plugin::PluginTensorRT*>(plugin));
   }
...
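For context on where `with_fp16` comes from: the converter enables the fp16 kernels only when the whole engine was built in half precision and plugin fp16 was not explicitly disabled. Below is a minimal sketch of how a user reaches this path through Paddle's public C++ inference API (`paddle_infer`); the model path and numeric arguments are placeholders, and exact signatures vary across releases.

```cpp
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/256, /*device_id=*/0);
  // Requesting kHalf precision makes engine_->WithFp16() true inside the
  // converter above, so the LayerNorm plugins get with_fp16 = true.
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kHalf,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}
```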
@@ -26,7 +26,30 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 
-int LayerNormPlugin::initialize() TRT_NOEXCEPT { return 0; }
+int LayerNormPlugin::initialize() TRT_NOEXCEPT {
+  cudaMalloc(&bias_gpu_, sizeof(float) * bias_.size());
+  cudaMemcpy(bias_gpu_,
+             bias_.data(),
+             bias_.size() * sizeof(float),
+             cudaMemcpyHostToDevice);
+  cudaMalloc(&scale_gpu_, sizeof(float) * scale_.size());
+  cudaMemcpy(scale_gpu_,
+             scale_.data(),
+             scale_.size() * sizeof(float),
+             cudaMemcpyHostToDevice);
+  return 0;
+}
+
+void LayerNormPlugin::terminate() TRT_NOEXCEPT {
+  if (bias_gpu_) {
+    cudaFree(bias_gpu_);
+    bias_gpu_ = nullptr;
+  }
+  if (scale_gpu_) {
+    cudaFree(scale_gpu_);
+    scale_gpu_ = nullptr;
+  }
+}
 
 nvinfer1::Dims LayerNormPlugin::getOutputDimensions(
     int index, const nvinfer1::Dims *inputDims, int nbInputs) TRT_NOEXCEPT {
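The new initialize()/terminate() pair moves the weight upload out of the per-inference path: the host-side `bias_`/`scale_` vectors are copied to the GPU once when TensorRT initializes the plugin, and freed when it is destroyed. A standalone sketch of that lifecycle (not Paddle code; error handling trimmed to the essentials):

```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

struct WeightBuffer {
  std::vector<float> host;
  float* device{nullptr};

  // Upload once, up front, instead of re-copying on every enqueue().
  int initialize() {
    if (cudaMalloc(&device, host.size() * sizeof(float)) != cudaSuccess)
      return -1;
    cudaMemcpy(device, host.data(), host.size() * sizeof(float),
               cudaMemcpyHostToDevice);
    return 0;
  }
  // The null check makes terminate() safe to call more than once.
  void terminate() {
    if (device) {
      cudaFree(device);
      device = nullptr;
    }
  }
};

int main() {
  WeightBuffer bias{{0.f, 1.f, 2.f}};
  if (bias.initialize() == 0) std::puts("weights resident on device");
  bias.terminate();
  return 0;
}
```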
@@ -37,6 +60,18 @@ nvinfer1::Dims LayerNormPlugin::getOutputDimensions(
   return output_dims;
 }
 
+bool LayerNormPlugin::supportsFormat(
+    nvinfer1::DataType type,
+    nvinfer1::PluginFormat format) const TRT_NOEXCEPT {
+  if (with_fp16_) {
+    return ((type == nvinfer1::DataType::kFLOAT ||
+             type == nvinfer1::DataType::kHALF) &&
+            (format == nvinfer1::PluginFormat::kLINEAR));
+  } else {
+    return ((type == nvinfer1::DataType::kFLOAT) &&
+            (format == nvinfer1::PluginFormat::kLINEAR));
+  }
+}
+
 int LayerNormPlugin::enqueue(int batch_size,
                              const void *const *inputs,
 #if IS_TRT_VERSION_LT(8000)
@@ -48,8 +83,6 @@ int LayerNormPlugin::enqueue(int batch_size,
 #endif
                              cudaStream_t stream) TRT_NOEXCEPT {
   const auto &input_dims = this->getInputDims(0);
-  const float *input = reinterpret_cast<const float *>(inputs[0]);
-  float *output = reinterpret_cast<float *const *>(outputs)[0];
   int begin_norm_axis = begin_norm_axis_;
   float eps = eps_;
@@ -92,42 +125,76 @@ int LayerNormPlugin::enqueue(int batch_size,
           feature_size,
           bias_.size()));
 
-  scale_t.Resize(phi::make_ddim({feature_size}));
-  bias_t.Resize(phi::make_ddim({feature_size}));
-  mean_t.Resize(phi::make_ddim({batched_mean_shape}));
-  variance_t.Resize(phi::make_ddim({batched_variance_shape}));
   int device_id;
   cudaGetDevice(&device_id);
-  float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
+  mean_t.Resize(phi::make_ddim({batched_mean_shape}));
+  variance_t.Resize(phi::make_ddim({batched_variance_shape}));
   float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
   float *variance_d =
       variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  cudaMemcpyAsync(scale_d,
-                  scale_.data(),
-                  sizeof(float) * feature_size,
-                  cudaMemcpyHostToDevice,
-                  stream);
-  cudaMemcpyAsync(bias_d,
-                  bias_.data(),
-                  sizeof(float) * feature_size,
-                  cudaMemcpyHostToDevice,
-                  stream);
-
-  phi::LayerNormDirectCUDAFunctor<float> layer_norm;
-  layer_norm(stream,
-             input,
-             input_shape,
-             bias_d,
-             scale_d,
-             output,
-             mean_d,
-             variance_d,
-             begin_norm_axis,
-             eps);
+  auto input_type = getDataType();
+  if (input_type == nvinfer1::DataType::kFLOAT) {
+    VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp32";
+    const float *input = reinterpret_cast<const float *>(inputs[0]);
+    float *output = static_cast<float *>(outputs[0]);
+    phi::LayerNormDirectCUDAFunctor<float, float> layer_norm;
+    layer_norm(stream,
+               input,
+               input_shape,
+               bias_gpu_,
+               scale_gpu_,
+               output,
+               mean_d,
+               variance_d,
+               begin_norm_axis,
+               eps);
+  } else if (input_type == nvinfer1::DataType::kHALF) {
+    VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp16";
+    const half *input = reinterpret_cast<const half *>(inputs[0]);
+    half *output = static_cast<half *>(outputs[0]);
+    phi::LayerNormDirectCUDAFunctor<half, float> layer_norm;
+    layer_norm(stream,
+               input,
+               input_shape,
+               bias_gpu_,
+               scale_gpu_,
+               output,
+               mean_d,
+               variance_d,
+               begin_norm_axis,
+               eps);
+  } else {
+    PADDLE_THROW(platform::errors::Fatal(
+        "The LayerNorm TRT Plugin's input type should be float or half."));
+  }
   return cudaGetLastError() != cudaSuccess;
 }
+
+int LayerNormPluginDynamic::initialize() TRT_NOEXCEPT {
+  cudaMalloc(&bias_gpu_, sizeof(float) * bias_.size());
+  cudaMemcpy(bias_gpu_,
+             bias_.data(),
+             bias_.size() * sizeof(float),
+             cudaMemcpyHostToDevice);
+  cudaMalloc(&scale_gpu_, sizeof(float) * scale_.size());
+  cudaMemcpy(scale_gpu_,
+             scale_.data(),
+             scale_.size() * sizeof(float),
+             cudaMemcpyHostToDevice);
+  return 0;
+}
+
+void LayerNormPluginDynamic::terminate() TRT_NOEXCEPT {
+  if (bias_gpu_) {
+    cudaFree(bias_gpu_);
+    bias_gpu_ = nullptr;
+  }
+  if (scale_gpu_) {
+    cudaFree(scale_gpu_);
+    scale_gpu_ = nullptr;
+  }
+}
 
 nvinfer1::DimsExprs LayerNormPluginDynamic::getOutputDimensions(
     int output_index,
     const nvinfer1::DimsExprs *inputDims,
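Both enqueue() bodies above now follow the same dispatch pattern: look up the data type TensorRT negotiated, then reinterpret the type-erased buffers as float or half before launching the kernel. A standalone sketch of the pattern (local stand-ins for the nvinfer1 types, not the real API):

```cpp
#include <cuda_fp16.h>

enum class DataType { kFLOAT, kHALF };  // stand-in for nvinfer1::DataType

// Stub launcher: the real code constructs LayerNormDirectCUDAFunctor<T, float>
// and launches the CUDA kernel here.
template <typename T>
void launch_layer_norm(const void* in, void* out) {
  const T* input = static_cast<const T*>(in);
  T* output = static_cast<T*>(out);
  (void)input;
  (void)output;
}

int dispatch(DataType t, const void* in, void* out) {
  switch (t) {
    case DataType::kFLOAT: launch_layer_norm<float>(in, out); return 0;
    case DataType::kHALF:  launch_layer_norm<half>(in, out);  return 0;
  }
  return -1;  // unsupported type: the real plugin throws here
}
```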
@@ -154,9 +221,14 @@ bool LayerNormPluginDynamic::supportsFormatCombination(
                         nb_inputs + nb_outputs));
   const nvinfer1::PluginTensorDesc &in = in_out[pos];
   if (pos == 0) {
-    // TODO(Shangzhizhou) FP16 support
-    return (in.type == nvinfer1::DataType::kFLOAT) &&
-           (in.format == nvinfer1::TensorFormat::kLINEAR);
+    if (with_fp16_) {
+      return ((in.type == nvinfer1::DataType::kFLOAT ||
+               in.type == nvinfer1::DataType::kHALF) &&
+              (in.format == nvinfer1::PluginFormat::kLINEAR));
+    } else {
+      return (in.type == nvinfer1::DataType::kFLOAT) &&
+             (in.format == nvinfer1::TensorFormat::kLINEAR);
+    }
   }
   const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
   // output
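The dynamic-shape variant expresses the same policy through supportsFormatCombination(): position 0 (the input) may be float, or additionally half when `with_fp16_` is set, and later positions are checked against their predecessor. A standalone sketch of that rule; the output branch mirrors the usual pattern, since the lines after `// output` are elided in this diff:

```cpp
enum class Type { kFLOAT, kHALF };
struct Desc { Type type; bool linear; };  // stand-in for PluginTensorDesc

bool supports(const Desc* in_out, int pos, bool with_fp16) {
  const Desc& in = in_out[pos];
  if (pos == 0) {  // the input tensor decides the precision
    bool type_ok = in.type == Type::kFLOAT ||
                   (with_fp16 && in.type == Type::kHALF);
    return type_ok && in.linear;
  }
  // Every other tensor (the output) must match the one before it.
  const Desc& prev = in_out[pos - 1];
  return in.type == prev.type && in.linear;
}
```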
@@ -187,6 +259,11 @@ nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
                         "The LayerNormPlugin only has one input, so the "
                         "index value should be 0, but get %d.",
                         index));
+  PADDLE_ENFORCE_EQ((input_types[0] == nvinfer1::DataType::kFLOAT ||
+                     input_types[0] == nvinfer1::DataType::kHALF),
+                    true,
+                    platform::errors::InvalidArgument(
+                        "The input type should be half or float"));
   return input_types[0];
 }
@@ -249,42 +326,40 @@ int LayerNormPluginDynamic::enqueue(
                         "but got feature_size:%d, bias's size:%d.",
                         feature_size,
                         bias_.size()));
   int device_id;
   cudaGetDevice(&device_id);
+  mean_t.Resize(phi::make_ddim(mean_shape_));
+  variance_t.Resize(phi::make_ddim(variance_shape_));
+  float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
+  float *variance_d =
+      variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
   auto input_type = input_desc[0].type;
   if (input_type == nvinfer1::DataType::kFLOAT) {
     VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp32";
     const float *input = reinterpret_cast<const float *>(inputs[0]);
     float *output = static_cast<float *>(outputs[0]);
-    scale_t.Resize(phi::make_ddim({feature_size}));
-    bias_t.Resize(phi::make_ddim({feature_size}));
-    mean_t.Resize(phi::make_ddim(mean_shape_));
-    variance_t.Resize(phi::make_ddim(variance_shape_));
-
-    float *scale_d =
-        scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    float *bias_d = bias_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
-    float *variance_d =
-        variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-
-    cudaMemcpyAsync(scale_d,
-                    scale_.data(),
-                    sizeof(float) * feature_size,
-                    cudaMemcpyHostToDevice,
-                    stream);
-    cudaMemcpyAsync(bias_d,
-                    bias_.data(),
-                    sizeof(float) * feature_size,
-                    cudaMemcpyHostToDevice,
-                    stream);
-    phi::LayerNormDirectCUDAFunctor<float> layer_norm;
+    phi::LayerNormDirectCUDAFunctor<float, float> layer_norm;
+    layer_norm(stream,
+               input,
+               input_shape,
+               bias_gpu_,
+               scale_gpu_,
+               output,
+               mean_d,
+               variance_d,
+               begin_norm_axis,
+               eps);
+  } else if (input_type == nvinfer1::DataType::kHALF) {
+    VLOG(1) << "TRT Plugin DataType selected. LayerNorm-->fp16";
+    const half *input = reinterpret_cast<const half *>(inputs[0]);
+    half *output = static_cast<half *>(outputs[0]);
+    phi::LayerNormDirectCUDAFunctor<half, float> layer_norm;
     layer_norm(stream,
                input,
                input_shape,
-               bias_d,
-               scale_d,
+               bias_gpu_,
+               scale_gpu_,
                output,
                mean_d,
                variance_d,
@@ -292,7 +367,7 @@ int LayerNormPluginDynamic::enqueue(
                eps);
   } else {
     PADDLE_THROW(platform::errors::Fatal(
-        "The LayerNorm TRT Plugin's input type should be float."));
+        "The LayerNorm TRT Plugin's input type should be float or half."));
   }
   return cudaGetLastError() != cudaSuccess;
 }
...
@@ -31,8 +31,6 @@ namespace plugin {
 class LayerNormPlugin : public PluginTensorRT {
   std::vector<float> bias_;
   std::vector<float> scale_;
-  framework::Tensor scale_t;
-  framework::Tensor bias_t;
   framework::Tensor mean_t;
   framework::Tensor variance_t;
   int begin_norm_axis_;
@@ -40,12 +38,16 @@ class LayerNormPlugin : public PluginTensorRT {
   std::vector<int64_t> mean_shape_;
   std::vector<int64_t> variance_shape_;
+  // data on devices
+  float* bias_gpu_{nullptr};
+  float* scale_gpu_{nullptr};
 
  public:
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     return getBaseSerializationSize() + SerializedSize(bias_) +
            SerializedSize(scale_) + SerializedSize(begin_norm_axis_) +
            SerializedSize(eps_) + SerializedSize(mean_shape_) +
-           SerializedSize(variance_shape_);
+           SerializedSize(variance_shape_) + SerializedSize(with_fp16_);
   }
 
   // TRT will call this func when we need to serialize the configuration of
@@ -59,6 +61,7 @@ class LayerNormPlugin : public PluginTensorRT {
     SerializeValue(&buffer, eps_);
     SerializeValue(&buffer, mean_shape_);
     SerializeValue(&buffer, variance_shape_);
+    SerializeValue(&buffer, with_fp16_);
   }
 
   LayerNormPlugin(const float* bias,
@@ -68,11 +71,13 @@ class LayerNormPlugin : public PluginTensorRT {
                   int begin_norm_axis,
                   float eps,
                   std::vector<int64_t> mean_shape,
-                  std::vector<int64_t> variance_shape)
+                  std::vector<int64_t> variance_shape,
+                  bool with_fp16)
       : begin_norm_axis_(begin_norm_axis),
         eps_(eps),
         mean_shape_(mean_shape),
         variance_shape_(variance_shape) {
+    with_fp16_ = with_fp16;
     bias_.resize(bias_num);
     scale_.resize(scale_num);
     std::copy(bias, bias + bias_num, bias_.data());
@@ -89,24 +94,33 @@ class LayerNormPlugin : public PluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &eps_);
     DeserializeValue(&serialData, &serialLength, &mean_shape_);
     DeserializeValue(&serialData, &serialLength, &variance_shape_);
+    DeserializeValue(&serialData, &serialLength, &with_fp16_);
   }
   ~LayerNormPlugin() {}
   int initialize() TRT_NOEXCEPT override;
+  void terminate() TRT_NOEXCEPT override;
 
   LayerNormPlugin* clone() const TRT_NOEXCEPT override {
-    return new LayerNormPlugin(bias_.data(),
-                               bias_.size(),
-                               scale_.data(),
-                               scale_.size(),
-                               begin_norm_axis_,
-                               eps_,
-                               mean_shape_,
-                               variance_shape_);
+    auto ptr = new LayerNormPlugin(bias_.data(),
+                                   bias_.size(),
+                                   scale_.data(),
+                                   scale_.size(),
+                                   begin_norm_axis_,
+                                   eps_,
+                                   mean_shape_,
+                                   variance_shape_,
+                                   with_fp16_);
+    ptr->bias_gpu_ = bias_gpu_;
+    ptr->scale_gpu_ = scale_gpu_;
+    return ptr;
   }
 
   const char* getPluginType() const TRT_NOEXCEPT override {
     return "layernorm_plugin";
   }
+  bool supportsFormat(nvinfer1::DataType type, nvinfer1::PluginFormat format)
+      const TRT_NOEXCEPT override;
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
   nvinfer1::Dims getOutputDimensions(int index,
                                      const nvinfer1::Dims* inputs,
@@ -150,11 +164,13 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
                          int begin_norm_axis,
                          float eps,
                          std::vector<int64_t> mean_shape,
-                         std::vector<int64_t> variance_shape)
+                         std::vector<int64_t> variance_shape,
+                         bool with_fp16)
       : begin_norm_axis_(begin_norm_axis),
         eps_(eps),
         mean_shape_(mean_shape),
         variance_shape_(variance_shape) {
+    with_fp16_ = with_fp16;
     bias_.resize(bias_num);
     scale_.resize(scale_num);
     std::copy(bias, bias + bias_num, bias_.data());
@@ -168,28 +184,35 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
     DeserializeValue(&serialData, &serialLength, &eps_);
     DeserializeValue(&serialData, &serialLength, &mean_shape_);
     DeserializeValue(&serialData, &serialLength, &variance_shape_);
+    DeserializeValue(&serialData, &serialLength, &with_fp16_);
   }
   nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
-    return new LayerNormPluginDynamic(bias_.data(),
-                                      bias_.size(),
-                                      scale_.data(),
-                                      scale_.size(),
-                                      begin_norm_axis_,
-                                      eps_,
-                                      mean_shape_,
-                                      variance_shape_);
+    auto ptr = new LayerNormPluginDynamic(bias_.data(),
+                                          bias_.size(),
+                                          scale_.data(),
+                                          scale_.size(),
+                                          begin_norm_axis_,
+                                          eps_,
+                                          mean_shape_,
+                                          variance_shape_,
+                                          with_fp16_);
+    ptr->bias_gpu_ = bias_gpu_;
+    ptr->scale_gpu_ = scale_gpu_;
+    return ptr;
   }
 
   const char* getPluginType() const TRT_NOEXCEPT override {
     return "layernorm_plugin_dynamic";
   }
   int getNbOutputs() const TRT_NOEXCEPT override { return 1; }
-  int initialize() TRT_NOEXCEPT override { return 0; }
+  int initialize() TRT_NOEXCEPT override;
+  void terminate() TRT_NOEXCEPT override;
 
   size_t getSerializationSize() const TRT_NOEXCEPT override {
     return SerializedSize(bias_) + SerializedSize(scale_) +
            SerializedSize(begin_norm_axis_) + SerializedSize(eps_) +
-           SerializedSize(mean_shape_) + SerializedSize(variance_shape_);
+           SerializedSize(mean_shape_) + SerializedSize(variance_shape_) +
+           SerializedSize(with_fp16_);
   }
 
   void serialize(void* buffer) const TRT_NOEXCEPT override {
@@ -199,6 +222,7 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
     SerializeValue(&buffer, eps_);
     SerializeValue(&buffer, mean_shape_);
     SerializeValue(&buffer, variance_shape_);
+    SerializeValue(&buffer, with_fp16_);
   }
 
   nvinfer1::DimsExprs getOutputDimensions(int output_index,
@@ -240,14 +264,15 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
  private:
   std::vector<float> bias_;
   std::vector<float> scale_;
-  framework::Tensor scale_t;
-  framework::Tensor bias_t;
   framework::Tensor mean_t;
   framework::Tensor variance_t;
   int begin_norm_axis_;
   float eps_;
   std::vector<int64_t> mean_shape_;
   std::vector<int64_t> variance_shape_;
+  // data on devices
+  float* bias_gpu_{nullptr};
+  float* scale_gpu_{nullptr};
 };
 
 class LayerNormPluginDynamicCreator : public TensorRTPluginCreator {
...
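`with_fp16_` is appended symmetrically to getSerializationSize(), serialize(), and the deserialization constructors because the buffer is read back positionally: writer and reader must agree on field order, or every later field is corrupted. A toy round-trip illustrating the invariant (plain memcpy stands in for Paddle's SerializeValue/DeserializeValue helpers):

```cpp
#include <cassert>
#include <cstring>

template <typename T>
void write(char** p, const T& v) {
  std::memcpy(*p, &v, sizeof(T));
  *p += sizeof(T);
}
template <typename T>
void read(const char** p, T* v) {
  std::memcpy(v, *p, sizeof(T));
  *p += sizeof(T);
}

int main() {
  char buf[16];
  char* w = buf;
  float eps = 1e-5f;
  bool with_fp16 = true;
  write(&w, eps);        // serialize(): old fields first...
  write(&w, with_fp16);  // ...new field appended last

  const char* r = buf;
  float eps2;
  bool fp16_2;
  read(&r, &eps2);       // deserialization ctor reads in the same order
  read(&r, &fp16_2);
  assert(eps2 == eps && fp16_2 == with_fp16);
  return 0;
}
```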
@@ -379,7 +379,8 @@ __global__ void LayerNormForward(
   var_val = BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
-    auto scale = static_cast<float>(1.) / static_cast<float>(feature_size);
+    auto scale = static_cast<U>(static_cast<float>(1.) /
+                                static_cast<float>(feature_size));
     auto tmp = mean_val * scale;
     mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
     var_share = static_cast<U>(var_val * scale - mean_share * mean_share);
...
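The cast above keeps the reciprocal in the statistics type U rather than in float, so `mean_val * scale` is a homogeneous U * U operation. A hedged standalone illustration (plain `__half` stands in for Paddle's U; the ROCm connection is an inference from the "fix rocm compile error" commits above, since mixed half/float operands are not guaranteed an operator on every toolchain):

```cpp
#include <cuda_fp16.h>

// U is the statistics type used by LayerNormForward.
template <typename U>
__device__ U scaled_mean(U mean_val, int feature_size) {
  // Compute the reciprocal in float for accuracy, then cast once to U so
  // the multiply below never mixes U with float.
  auto scale = static_cast<U>(static_cast<float>(1.) /
                              static_cast<float>(feature_size));
  return mean_val * scale;  // U * U: no mixed-type operator required
}
```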
@@ -21,24 +21,24 @@
 namespace phi {
 
-template <typename T>
-void LayerNormDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
-                                               const T *input,
-                                               std::vector<int> input_shape,
-                                               const T *bias,
-                                               const T *scale,
-                                               T *output,
-                                               T *mean,
-                                               T *variance,
-                                               int begin_norm_axis,
-                                               float eps) {
+template <typename T, typename U>
+void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
+                                                  const T *input,
+                                                  std::vector<int> input_shape,
+                                                  const U *bias,
+                                                  const U *scale,
+                                                  T *output,
+                                                  U *mean,
+                                                  U *variance,
+                                                  int begin_norm_axis,
+                                                  float eps) {
   const auto x_dims = phi::make_ddim(input_shape);
   auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
   int64_t batch_size = static_cast<int64_t>(matrix_dim[0]);
   int64_t feature_size = static_cast<int64_t>(matrix_dim[1]);
   switch (paddle::operators::GetDesiredBlockDim(feature_size)) {
     FIXED_BLOCK_DIM_CASE(
-        paddle::operators::LayerNormForward<T, T, kBlockDim>
+        paddle::operators::LayerNormForward<T, U, kBlockDim>
         <<<batch_size, kBlockDim, 0, stream>>>(
             input, scale, bias, output, mean, variance, eps, feature_size));
     default:
@@ -49,7 +49,10 @@ void LayerNormDirectCUDAFunctor<T, U>::operator()(gpuStream_t stream,
   }
 }
 
-template class LayerNormDirectCUDAFunctor<float>;
+template class LayerNormDirectCUDAFunctor<float, float>;
+#if defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
+template class LayerNormDirectCUDAFunctor<half, float>;
+#endif
 
 template <typename T, typename Context>
 void LayerNormKernel(const Context &dev_ctx,
...
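With the second template parameter, T is the activation type and U the type of the scale/bias weights and the saved statistics, so the half instantiation still accumulates mean and variance in float. A hedged usage sketch (the include path is assumed from the header shown below; per the `#if` above, the half instantiation exists only in CUDA builds, not HIP; error checking omitted):

```cpp
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "paddle/phi/kernels/layer_norm_kernel.h"  // assumed header path

void run(cudaStream_t stream) {
  const int rows = 2, cols = 768;
  half *input, *output;
  float *scale, *bias, *mean, *variance;
  cudaMalloc(&input, rows * cols * sizeof(half));
  cudaMalloc(&output, rows * cols * sizeof(half));
  cudaMalloc(&scale, cols * sizeof(float));
  cudaMalloc(&bias, cols * sizeof(float));
  cudaMalloc(&mean, rows * sizeof(float));
  cudaMalloc(&variance, rows * sizeof(float));

  // T = half (activations), U = float (weights and statistics).
  phi::LayerNormDirectCUDAFunctor<half, float> layer_norm;
  layer_norm(stream, input, {rows, cols}, bias, scale, output, mean,
             variance, /*begin_norm_axis=*/1, /*eps=*/1e-5f);
}

int main() {
  run(/*stream=*/nullptr);  // default stream
  cudaDeviceSynchronize();
  return 0;
}
```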
@@ -32,17 +32,17 @@ void LayerNormKernel(const Context& ctx,
                      DenseTensor* variance);
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-template <typename T>
+template <typename T, typename U>
 class LayerNormDirectCUDAFunctor {
  public:
   void operator()(gpuStream_t stream,
                   const T* input,
                   std::vector<int> input_shape,
-                  const T* bias,
-                  const T* scale,
+                  const U* bias,
+                  const U* scale,
                   T* output,
-                  T* mean,
-                  T* variance,
+                  U* mean,
+                  U* variance,
                   int begin_norm_axis,
                   float eps);
 };
...