Unverified commit 3ca8cf44, authored by Wang Bojun, committed by GitHub

Layernorm shape bugfix (#45431)

* fix bug

* add shape size check

* polish code

* multi -1 shape fix

* code style improve

* bug fix

* code style fix
Parent 14f6c74b
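For context, the failure mode that the dynamic-shape part of this patch addresses can be sketched as follows. Under dynamic shape, `X->getDimensions()` reports unknown dimensions as -1 at build time, so a product of those dimensions is meaningless; the patch therefore passes a placeholder size of 1 from the converter and recomputes the real size in configurePlugin. The profile [-1, -1, 768] and begin_norm_axis = 2 below are illustrative values, not taken from the patch:

    #include <cstdio>

    int main() {
      // Hypothetical build-time dims of a dynamic-shape input: [-1, -1, 768].
      int dims[3] = {-1, -1, 768};
      int begin_norm_axis = 2;
      int statis_num = 1;
      for (int i = 1; i < begin_norm_axis; i++) {
        statis_num *= dims[i];  // multiplies in a -1 placeholder
      }
      // statis_num == -1 here: sizing the mean/variance tensors from this
      // build-time value is what produced the wrong shapes before the fix.
      std::printf("statis_num = %d\n", statis_num);
      return 0;
    }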
......@@ -56,14 +56,10 @@ class LayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layernorm_layer = nullptr;
if (engine_->with_dynamic_shape()) {
- int statis_num = 1;
// For dynamic shape,
- // the batch num will be taken into account in plugin runtime.
- for (int i = 1; i < begin_norm_axis; i++) {
- statis_num *= X->getDimensions().d[i];
- }
- std::vector<int64_t> mean_shape{statis_num};
- std::vector<int64_t> variance_shape{statis_num};
+ // the shape of mean and variance will be determined in configurePlugin.
+ std::vector<int64_t> mean_shape{1};
+ std::vector<int64_t> variance_shape{1};
plugin::LayerNormPluginDynamic* plugin =
new plugin::LayerNormPluginDynamic(
static_cast<const float*>(bias_weight.get().values),
......@@ -77,7 +73,7 @@ class LayerNormOpConverter : public OpConverter {
layernorm_layer = engine_->AddDynamicPlugin(&X, 1, plugin);
} else {
int statis_num = 1;
- for (int i = 0; i < begin_norm_axis; i++) {
+ for (int i = 1; i < begin_norm_axis; i++) {
statis_num *= X->getDimensions().d[i];
}
std::vector<int64_t> mean_shape{statis_num};
......
......@@ -53,6 +53,22 @@ int LayerNormPlugin::enqueue(int batch_size,
int begin_norm_axis = begin_norm_axis_;
float eps = eps_;
+ PADDLE_ENFORCE_EQ(1,
+ mean_shape_.size(),
+ platform::errors::InvalidArgument(
+ "Size of mean_shape vector should be equal to 1,"
+ "but got Size of mean_shape vector:%d",
+ mean_shape_.size()));
+ PADDLE_ENFORCE_EQ(1,
+ variance_shape_.size(),
+ platform::errors::InvalidArgument(
+ "Size of variance_shape vector should be equal to 1,"
+ "but got Size of variance_shape vector:%d",
+ variance_shape_.size()));
+ int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
+ int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
std::vector<int> input_shape;
input_shape.push_back(batch_size);
for (int i = 0; i < input_dims.nbDims; i++) {
......@@ -78,8 +94,8 @@ int LayerNormPlugin::enqueue(int batch_size,
scale_t.Resize(phi::make_ddim({feature_size}));
bias_t.Resize(phi::make_ddim({feature_size}));
- mean_t.Resize(phi::make_ddim(mean_shape_));
- variance_t.Resize(phi::make_ddim(variance_shape_));
+ mean_t.Resize(phi::make_ddim({batched_mean_shape}));
+ variance_t.Resize(phi::make_ddim({batched_variance_shape}));
int device_id;
cudaGetDevice(&device_id);
float *scale_d = scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
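In the static-shape LayerNormPlugin::enqueue, the patch first checks that mean_shape_ and variance_shape_ each hold exactly one element, then multiplies that element by input_dims.d[0] before resizing the temporary mean/variance tensors. A small numeric sketch with made-up values (128 and 4 are illustrative, not from the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t mean_shape0 = 128;  // the single element carried in mean_shape_
      int64_t leading_dim = 4;    // input_dims.d[0] observed at enqueue time
      int64_t batched_mean_shape = mean_shape0 * leading_dim;
      // mean_t (and likewise variance_t) is resized to {512} rather than {128}.
      std::printf("batched_mean_shape = %lld\n",
                  static_cast<long long>(batched_mean_shape));
      return 0;
    }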
......@@ -147,6 +163,20 @@ bool LayerNormPluginDynamic::supportsFormatCombination(
return in.type == prev.type && in.format == prev.format;
}
+ void LayerNormPluginDynamic::configurePlugin(
+ const nvinfer1::DynamicPluginTensorDesc *in,
+ int nbInputs,
+ const nvinfer1::DynamicPluginTensorDesc *out,
+ int nbOutputs) TRT_NOEXCEPT {
+ const auto &input_dims = in[0].desc.dims;
+ int statis_num = 1;
+ for (int i = 0; i < begin_norm_axis_; i++) {
+ statis_num *= input_dims.d[i];
+ }
+ mean_shape_[0] = statis_num;
+ variance_shape_[0] = statis_num;
+ }
nvinfer1::DataType LayerNormPluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType *input_types,
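The new configurePlugin body is what allows the converter to pass the {1} placeholder: once concrete dimensions are known, it recomputes the statistics size in place. A self-contained walk-through with a hypothetical input shape (the dims [2, 128, 768] are chosen only for illustration):

    #include <cstdio>

    int main() {
      // Hypothetical concrete dims seen by configurePlugin: [batch, seq, hidden].
      int dims[3] = {2, 128, 768};
      int begin_norm_axis = 2;  // normalize over the trailing dimension
      int statis_num = 1;
      for (int i = 0; i < begin_norm_axis; i++) {
        statis_num *= dims[i];  // 2 * 128 = 256
      }
      // mean_shape_[0] and variance_shape_[0] are both set to 256:
      // one mean/variance value per normalized row.
      std::printf("statis_num = %d\n", statis_num);
      return 0;
    }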
......@@ -189,8 +219,19 @@ int LayerNormPluginDynamic::enqueue(
"Size of variance_shape vector should be equal to 1,"
"but got Size of mean_shape vector:%d",
mean_shape_.size()));
- int64_t batched_mean_shape = mean_shape_[0] * input_dims.d[0];
- int64_t batched_variance_shape = variance_shape_[0] * input_dims.d[0];
+ PADDLE_ENFORCE_GE(mean_shape_[0],
+ 0,
+ platform::errors::InvalidArgument(
+ "The size of mean vector should be positive,"
+ "but got:%d",
+ mean_shape_[0]));
+ PADDLE_ENFORCE_GE(variance_shape_[0],
+ 0,
+ platform::errors::InvalidArgument(
+ "The size of variance vector should be positive,"
+ "but got:%d",
+ variance_shape_[0]));
const auto input_ddim = phi::make_ddim(input_shape);
auto matrix_dim = phi::flatten_to_2d(input_ddim, begin_norm_axis);
int feature_size = static_cast<int>(matrix_dim[1]);
......@@ -217,8 +258,8 @@ int LayerNormPluginDynamic::enqueue(
float *output = static_cast<float *>(outputs[0]);
scale_t.Resize(phi::make_ddim({feature_size}));
bias_t.Resize(phi::make_ddim({feature_size}));
- mean_t.Resize(phi::make_ddim({batched_mean_shape}));
- variance_t.Resize(phi::make_ddim({batched_variance_shape}));
+ mean_t.Resize(phi::make_ddim(mean_shape_));
+ variance_t.Resize(phi::make_ddim(variance_shape_));
float *scale_d =
scale_t.mutable_data<float>(platform::CUDAPlace(device_id));
......
......@@ -215,7 +215,7 @@ class LayerNormPluginDynamic : public DynamicPluginTensorRT {
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out,
- int nbOutputs) TRT_NOEXCEPT override {}
+ int nbOutputs) TRT_NOEXCEPT override;
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nbInputs,
......
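The header change replaces the previously empty inline configurePlugin override with a plain declaration, so the shape-recomputing definition added in the .cu file is the one that actually runs. A simplified sketch of that pattern, using an invented stand-in interface rather than the real TensorRT one:

    // Invented, simplified stand-in for the plugin interface (illustration only).
    struct PluginBase {
      virtual void configurePlugin(const int* dims, int nbDims) = 0;
      virtual ~PluginBase() = default;
    };

    struct LayerNormLikePlugin : PluginBase {
      // Before: an empty inline body `override {}` meant the stored sizes were
      // never refreshed. After: only the declaration lives in the header.
      void configurePlugin(const int* dims, int nbDims) override;
      int mean_size_ = 1;
    };

    void LayerNormLikePlugin::configurePlugin(const int* dims, int nbDims) {
      int statis_num = 1;
      for (int i = 0; i < nbDims - 1; i++) statis_num *= dims[i];
      mean_size_ = statis_num;  // refreshed from the actual input dims
    }

    int main() {
      LayerNormLikePlugin p;
      int dims[3] = {2, 128, 768};  // hypothetical concrete dims
      p.configurePlugin(dims, 3);   // mean_size_ becomes 256
      return p.mean_size_ == 256 ? 0 : 1;
    }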