Unverified commit 3ca8cf44, authored by Wang Bojun, committed by GitHub

Layernorm shape bugfix (#45431)

* fix bug fix

* add shape size check

* polish code

* multi -1 shape fix

* code style improve

* bug fix

* code style fix
Parent 14f6c74b
@@ -56,14 +56,10 @@ class LayerNormOpConverter : public OpConverter {
     nvinfer1::ILayer* layernorm_layer = nullptr;
     if (engine_->with_dynamic_shape()) {
-      int statis_num = 1;
       // For dynamic shape,
-      // the batch num will be taken into account in plugin runtime.
-      for (int i = 1; i < begin_norm_axis; i++) {
-        statis_num *= X->getDimensions().d[i];
-      }
-      std::vector<int64_t> mean_shape{statis_num};
-      std::vector<int64_t> variance_shape{statis_num};
+      // the shape of mean and variance will be determine in configuPlugin.
+      std::vector<int64_t> mean_shape{1};
+      std::vector<int64_t> variance_shape{1};
       plugin::LayerNormPluginDynamic* plugin =
           new plugin::LayerNormPluginDynamic(
               static_cast<const float*>(bias_weight.get().values),
@@ -77,7 +73,7 @@ class LayerNormOpConverter : public OpConverter {
       layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
     } else {
       int statis_num = 1;
-      for (int i = 0; i < begin_norm_axis; i++) {
+      for (int i = 1; i < begin_norm_axis; i++) {
        statis_num *= X->getDimensions().d[i];
      }
      std::vector<int64_t> mean_shape{statis_num};
...
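For reference, below is a minimal standalone sketch of the statistics-count computation the static-shape branch of the converter now performs: the product of the dimensions after the leading axis and before begin_norm_axis. It is plain C++, not the PaddlePaddle or TensorRT API; the helper name, the example dims, and the axis value are assumptions made for illustration.

```cpp
// Hypothetical stand-in for the converter's loop over X->getDimensions().d[i].
#include <cstdint>
#include <iostream>
#include <vector>

int64_t StatisticsCount(const std::vector<int64_t>& dims, int begin_norm_axis) {
  int64_t statis_num = 1;
  // Skip d[0] and stop before the axes that are normalized over.
  for (int i = 1; i < begin_norm_axis; ++i) {
    statis_num *= dims[i];
  }
  return statis_num;
}

int main() {
  // Assumed example: dims = {4, 128, 768}, begin_norm_axis = 2.
  std::cout << StatisticsCount({4, 128, 768}, 2) << "\n";  // prints 128
  return 0;
}
```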
@@ -53,6 +53,22 @@ int LayerNormPlugin::enqueue(int batch_size,
   int begin_norm_axis = begin_norm_axis_;
   float eps = eps_;
+  PADDLE_ENFORCE_EQ(1,
+                    mean_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of mean_shape vector should be equal to 1,"
+                        "but got Size of mean_shape vector:%d",
+                        mean_shape_.size()));
+  PADDLE_ENFORCE_EQ(1,
+                    variance_shape_.size(),
+                    platform::errors::InvalidArgument(
+                        "Size of variance_shape vector should be equal to 1,"
+                        "but got Size of mean_shape vector:%d",
+                        mean_shape_.size()));
+  int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
+  int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
   std::vector<int> input_shape;
   input_shape.push_back(batch_size);
   for (int i = 0; i < input_dims.nbDims; i++) {
@@ -78,8 +94,8 @@ int LayerNormPlugin::enqueue(int batch_size,
   scale_t.Resize(phi::make_ddim({feature_size}));
   bias_t.Resize(phi::make_ddim({feature_size}));
-  mean_t.Resize(phi::make_ddim(mean_shape_));
-  variance_t.Resize(phi::make_ddim(variance_shape_));
+  mean_t.Resize(phi::make_ddim({batched_mean_shape}));
+  variance_t.Resize(phi::make_ddim({batched_variance_shape}));
   int device_id;
   cudaGetDevice(&device_id);
   float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
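The arithmetic behind the new buffer sizing in the static-shape enqueue can be sketched as follows; this is plain C++ with assumed values (no TensorRT types, no phi), only to show that the per-sample statistics count stored in mean_shape_[0] is scaled by the leading input dimension before the temporary mean/variance tensors are resized.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> mean_shape{128};      // assumed per-sample statistics count
  std::vector<int64_t> variance_shape{128};
  int64_t leading_dim = 4;                   // stands in for input_dims.d[0]

  // Mirrors the added lines: scale by the leading dimension.
  int64_t batched_mean_shape = mean_shape[0] * leading_dim;
  int64_t batched_variance_shape = variance_shape[0] * leading_dim;

  std::cout << batched_mean_shape << " " << batched_variance_shape << "\n";  // 512 512
  return 0;
}
```

The real code then passes these scalars to mean_t.Resize(phi::make_ddim({batched_mean_shape})) and the variance equivalent, as shown in the hunk above.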
@@ -147,6 +163,20 @@ bool LayerNormPluginDynamic::supportsFormatCombination(
   return in.type == prev.type && in.format == prev.format;
 }
 
+void LayerNormPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *in,
+    int nbInputs,
+    const nvinfer1::DynamicPluginTensorDesc *out,
+    int nbOutputs) TRT_NOEXCEPT {
+  const auto &input_dims = in[0].desc.dims;
+  int statis_num = 1;
+  for (int i = 0; i < begin_norm_axis_; i++) {
+    statis_num *= input_dims.d[i];
+  }
+  mean_shape_[0] = statis_num;
+  variance_shape_[0] = statis_num;
+}
+
 nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
     int index,
     const nvinfer1::DataType *input_types,
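A minimal sketch of the shape update the new configurePlugin performs once the actual input dimensions are known: the number of statistics is the product of all dimensions before begin_norm_axis, and it overwrites the placeholder {1} set by the converter. The class, member names, and example dims below are stand-ins, not the plugin's real interface.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeLayerNormPlugin {                 // hypothetical stand-in for the plugin
  int begin_norm_axis = 2;
  std::vector<int64_t> mean_shape{1};        // placeholder written by the converter
  std::vector<int64_t> variance_shape{1};

  // Mimics configurePlugin: recompute the statistics count from real dims.
  void Configure(const std::vector<int64_t>& input_dims) {
    int64_t statis_num = 1;
    for (int i = 0; i < begin_norm_axis; ++i) {
      statis_num *= input_dims[i];
    }
    mean_shape[0] = statis_num;
    variance_shape[0] = statis_num;
  }
};

int main() {
  FakeLayerNormPlugin plugin;
  plugin.Configure({4, 128, 768});            // dynamic-shape dims include the batch
  std::cout << plugin.mean_shape[0] << "\n";  // prints 512
  return 0;
}
```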
@@ -189,8 +219,19 @@ int LayerNormPluginDynamic::enqueue(
                         "Size of variance_shape vector should be equal to 1,"
                         "but got Size of mean_shape vector:%d",
                         mean_shape_.size()));
-  int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
-  int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
+  PADDLE_ENFORCE_GE(mean_shape_[0],
+                    0,
+                    platform::errors::InvalidArgument(
+                        "The size of mean vector should be positive,"
+                        "but got:%d",
+                        mean_shape_[0]));
+  PADDLE_ENFORCE_GE(variance_shape_[0],
+                    0,
+                    platform::errors::InvalidArgument(
+                        "The size of mean vector should be positive,"
+                        "but got:%d",
+                        variance_shape_[0]));
   const auto input_ddim = phi::make_ddim(input_shape);
   auto matrix_dim = phi::flatten_to_2d(input_ddim, begin_norm_axis);
   int feature_size = static_cast<int>(matrix_dim[1]);
@@ -217,8 +258,8 @@ int LayerNormPluginDynamic::enqueue(
   float *output = static_cast<float *>(outputs[0]);
   scale_t.Resize(phi::make_ddim({feature_size}));
   bias_t.Resize(phi::make_ddim({feature_size}));
-  mean_t.Resize(phi::make_ddim({batched_mean_shape}));
-  variance_t.Resize(phi::make_ddim({batched_variance_shape}));
+  mean_t.Resize(phi::make_ddim(mean_shape_));
+  variance_t.Resize(phi::make_ddim(variance_shape_));
   float *scale_d =
       scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
...
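The feature_size used above comes from flattening the input to two dimensions at begin_norm_axis. The sketch below shows that split in plain C++ (no phi), assuming flatten_to_2d's usual semantics of [product of dims before the axis, product of dims from the axis on]; the helper name and example dims are illustrative only.

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Assumed behavior of the 2-D flatten: {rows, feature_size}.
std::pair<int64_t, int64_t> FlattenTo2d(const std::vector<int64_t>& dims,
                                        int begin_norm_axis) {
  int64_t rows = 1, cols = 1;
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < begin_norm_axis ? rows : cols) *= dims[i];
  }
  return {rows, cols};
}

int main() {
  auto [rows, feature_size] = FlattenTo2d({4, 128, 768}, 2);
  std::cout << rows << " x " << feature_size << "\n";  // prints 512 x 768
  return 0;
}
```

Here rows matches the statistics count written by configurePlugin, so the mean/variance tensors resized from mean_shape_ and variance_shape_ hold one entry per normalized row.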
@@ -215,7 +215,7 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
   void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                        int nbInputs,
                        const nvinfer1::DynamicPluginTensorDesc* out,
-                       int nbOutputs) TRT_NOEXCEPT override {}
+                       int nbOutputs) TRT_NOEXCEPT override;
   size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                           int nbInputs,
...