未验证 提交 1c44d3e2 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] support ernie quant model with interleaved (#39424)

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved

* support ernie quant model with interleaved
上级 7e52beae
......@@ -334,6 +334,9 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
if (end_patter_layernorms[k]->Op()->HasAttr("out_threshold")) {
new_op_desc.SetAttr("enable_int8", true);
new_op_desc.SetAttr(
"out_threshold",
end_patter_layernorms[k]->Op()->GetAttr("out_threshold"));
}
auto* embedding_eltwise_layernorm = graph->CreateOpNode(&new_op_desc);
......
......@@ -501,7 +501,6 @@ PDNode* MultiHeadMatmulV3Pattern::operator()() {
auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())
->assert_is_op_output("reshape2");
reshape2_qkv_out_var->assert_is_ops_input(matmul_ops);
// Second path to matmul
auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_ops(matmul_ops);
auto* mul1_w_var = pattern->NewNode(mul1_w_repr())
......@@ -671,6 +670,7 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
......@@ -687,6 +687,7 @@ MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
......@@ -761,7 +762,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out,
Node* softmax_qk, Node* eltadd0, Node* eltadd1, Node* eltadd2,
Node* matmul_qk) {
Node* matmul_qk, Node* reshape2_qkv) {
auto scale_attr = BOOST_GET_CONST(float, scale->Op()->GetAttr("scale"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
......@@ -905,7 +906,10 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
multihead_op_desc.SetAttr("dp_probs", qkv_plugin_scale);
}
}
if (reshape2_qkv->Op()->HasAttr("out_threshold")) {
multihead_op_desc.SetAttr("out_threshold",
reshape2_qkv->Op()->GetAttr("out_threshold"));
}
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(input0, multihead);
......@@ -1008,7 +1012,7 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
fuse_creater(input0, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, mul0_w,
mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, eltadd_qk_b,
reshape2_0, reshape2_qkv_out, scale, scale_out, softmax_qk,
eltadd0, eltadd1, eltadd2, matmul_qk);
eltadd0, eltadd1, eltadd2, matmul_qk, reshape2_qkv);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
......@@ -1130,6 +1134,7 @@ MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("shape") // -->(B, S, H, N) <--(B, S, N*H)
......@@ -1146,6 +1151,7 @@ MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis") // {0, 2, 1, 3}
......
......@@ -158,8 +158,10 @@ void SkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
new_desc.SetInput("Scale", {layer_norm_scale->Name()});
new_desc.SetInput("Bias", {layer_norm_bias->Name()});
if (elementwise->Op()->HasAttr("out_threshold")) {
if (layer_norm->Op()->HasAttr("out_threshold")) {
new_desc.SetAttr("enable_int8", true);
new_desc.SetAttr("out_threshold",
layer_norm->Op()->GetAttr("out_threshold"));
}
// outputs
......
......@@ -142,7 +142,6 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
{"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1},
};
// remember to free
nvinfer1::PluginFieldCollection* plugin_ptr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_ptr) +
......@@ -168,6 +167,11 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
shuffle_layer->setName(
("Embeltwise_Shuffle_reshape (Output: max_seqlen " +
op_desc.Output("Out")[0] + ")")
.c_str());
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
......@@ -178,12 +182,40 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
creator->createPlugin("CustomEmbLayerNormPluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
layer = plugin_layer;
plugin_layer->setName(("CustomEmbLayerNormPluginDynamic_V2(Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
free(plugin_ptr);
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "emb_eltwise_layernorm",
{output_name, std::string("qkv_plugin_mask")},
test_mode);
if (enable_int8) {
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_scale);
engine_->SetTensorDynamicRange(plugin_layer->getOutput(1), out_scale);
}
if (engine_->with_interleaved()) {
VLOG(4)
<< "fused emb_eltwise_layernorm op: use_oss and with_interleaved";
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
}
auto* shuffler_embed = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *(plugin_layer->getOutput(0)));
nvinfer1::Permutation transpose_embed{2, 1, 0, 3};
shuffler_embed->setSecondTranspose(transpose_embed);
engine_->SetITensor(op_desc.Output("Out")[0],
shuffler_embed->getOutput(0));
shuffler_embed->setName(
("Emb_eltwise_out_shuffler_transpose (Output: " +
op_desc.Output("Out")[0] + ")")
.c_str());
} else {
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "CustomEmbLayerNormPluginDynamic_V2",
{output_name, std::string("qkv_plugin_mask")},
test_mode);
}
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
......
......@@ -85,67 +85,31 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
int head_size = hidden_out / head_number;
// [3, head_number, head_size, hidden_in] -> [head_number, 3, head_size,
// hidden_in]
auto transpose_weight_v2 = [](const float* src, float* dst, int three,
int head_number, int head_size,
int hidden_in) {
const int HH = head_size * hidden_in;
for (int i = 0; i < three; ++i) {
for (int n = 0; n < head_number; ++n) {
for (int hh = 0; hh < HH; ++hh) {
dst[n * three * HH + i * HH + hh] =
src[i * head_number * HH + n * HH + hh];
}
}
}
};
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto transpose_bias_v2 = [](const float* src, float* dst, int N,
int H) {
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(), weight_data, three,
head_number, head_size, hidden_in);
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
std::vector<float> bias_data_tmp;
bias_data_tmp.reserve(bias_t->numel());
memcpy(bias_data_tmp.data(), bias_data,
bias_t->numel() * sizeof(float));
transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number,
head_size);
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_data),
static_cast<int32_t>(bias_t->numel())};
nvinfer1::ILayer* fc_layer = nullptr;
float dp_probs = 1.0 / 127.0;
if (enable_int8) {
if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved";
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
}
nvinfer1::ILayer* fc_layer = nullptr;
float dp_probs = 1.0 / 127.0;
nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n,
nv_ksize, weight, bias);
} else {
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
weight, bias);
}
if (enable_int8) {
fc_layer->setName(
("Multihead: Convolution/FullyConnected: (Output: " +
output_name + ")")
.c_str());
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers in int8 mode"));
"must have out_threshold in multihead layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
......@@ -153,73 +117,194 @@ class MultiheadMatMulOpConverter : public OpConverter {
dp_probs =
BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
}
}
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "3");
assert(creator != nullptr);
std::vector<nvinfer1::PluginField> fields{
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32,
1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32,
1}};
if (qkv2context_plugin_int8) {
fields.push_back({"dq_probs", &dp_probs,
nvinfer1::PluginFieldType::kFLOAT32, 1});
}
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(malloc(
sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic",
plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
if (engine_->Has("ernie_pos_name")) {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->Get<std::string>("ernie_pos_name")));
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f);
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
shuffle_layer->setName(
("Multihead: Shuffle: (Output: " + output_name + ")").c_str());
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
} else {
int head_size = hidden_out / head_number;
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
// head_size,
// hidden_in]
auto transpose_weight_v2 = [](const float* src, float* dst, int three,
int head_number, int head_size,
int hidden_in) {
const int HH = head_size * hidden_in;
for (int i = 0; i < three; ++i) {
for (int n = 0; n < head_number; ++n) {
for (int hh = 0; hh < HH; ++hh) {
dst[n * three * HH + i * HH + hh] =
src[i * head_number * HH + n * HH + hh];
}
}
}
};
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto transpose_bias_v2 = [](const float* src, float* dst, int N,
int H) {
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int h = 0; h < H; ++h) {
dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h];
}
}
}
};
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(), weight_data, three,
head_number, head_size, hidden_in);
std::vector<float> bias_data_tmp;
bias_data_tmp.reserve(bias_t->numel());
memcpy(bias_data_tmp.data(), bias_data,
bias_t->numel() * sizeof(float));
transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number,
head_size);
nvinfer1::ILayer* fc_layer = nullptr;
float dp_probs = 1.0 / 127.0;
if (enable_int8) {
nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n,
nv_ksize, weight, bias);
} else {
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n,
weight, bias);
}
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
if (enable_int8) {
PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers "
"in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
if (qkv2context_plugin_int8) {
dp_probs =
BOOST_GET_CONST(float, op_desc.GetAttr("dp_probs")) / 127.0;
}
}
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kHALF);
auto mask_tensor = engine_->GetITensor("qkv_plugin_mask");
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
}
bool has_mask = true;
int var_seqlen = 1;
std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32,
1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32,
1}};
if (qkv2context_plugin_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
fields.push_back({"dq_probs", &dp_probs,
nvinfer1::PluginFieldType::kFLOAT32, 1});
}
}
bool has_mask = true;
int var_seqlen = 1;
std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}};
if (qkv2context_plugin_int8) {
fields.push_back(
{"dq_probs", &dp_probs, nvinfer1::PluginFieldType::kFLOAT32, 1});
}
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic",
plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor);
if (engine_->Has("ernie_pos_name")) {
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(malloc(
sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
plugin_collection->fields = fields.data();
auto plugin = creator->createPlugin("CustomQKVToContextPluginDynamic",
plugin_collection);
free(plugin_collection);
std::vector<nvinfer1::ITensor*> plugin_inputs;
plugin_inputs.emplace_back(fc_layer->getOutput(0));
plugin_inputs.emplace_back(mask_tensor);
if (engine_->Has("ernie_pos_name")) {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->Get<std::string>("ernie_pos_name")));
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
engine_->SetTensorDynamicRange(shuffle_layer->getOutput(0), 1.0f);
plugin_inputs.emplace_back(
engine_->GetITensor(engine_->Get<std::string>("ernie_pos_name")));
} else {
plugin_inputs.emplace_back(engine_->GetITensor(
engine_->network()
->getInput(2)
->getName())); // cu_seqlens, eval_placeholder_2
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
}
auto max_seqlen_tensor =
engine_->GetITensor(engine_->network()->getInput(3)->getName());
auto* shuffle_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle,
*const_cast<nvinfer1::ITensor*>(max_seqlen_tensor));
nvinfer1::Dims shape_dim;
shape_dim.nbDims = 1;
shape_dim.d[0] = -1;
shuffle_layer->setReshapeDimensions(shape_dim);
plugin_inputs.emplace_back(
shuffle_layer->getOutput(0)); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
} else {
PADDLE_ENFORCE_EQ(
input->getDimensions().nbDims, 3,
......
......@@ -54,47 +54,85 @@ class SkipLayerNormOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->use_oss()) {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "2");
PADDLE_ENFORCE_NE(
creator, nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension
PADDLE_ENFORCE_GT(ld, 0, platform::errors::InvalidArgument(
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"));
if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kHALF);
if (engine_->with_interleaved()) {
VLOG(4) << "fused skip_layernorm op: use_oss and with_interleaved";
if (!enable_int8) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
}
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "3");
PADDLE_ENFORCE_NE(
creator, nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
const std::vector<nvinfer1::PluginField> fields{
{"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
{ "gamma",
scale,
nvinfer1::PluginFieldType::kFLOAT32,
scale_size }};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() * sizeof(nvinfer1::PluginField)));
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj = creator->createPlugin(
"CustomSkipLayerNormPluginDynamic", pluginPtr);
auto plugin_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
} else {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomSkipLayerNormPluginDynamic", "2");
PADDLE_ENFORCE_NE(
creator, nullptr,
platform::errors::InvalidArgument(
"fail to get creator of CustomSkipLayerNormPluginDynamic"));
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
int ld = input1->getDimensions().d[2]; // hidden dimension
PADDLE_ENFORCE_GT(ld, 0,
platform::errors::InvalidArgument(
"in CustomSkipLayerNormPluginDynamic hidden "
"dimension should > 0"));
if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1},
{"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
{"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size},
};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj = creator->createPlugin(
"CustomSkipLayerNormPluginDynamic", pluginPtr);
auto plugin_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
}
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"ld", &ld, nvinfer1::PluginFieldType::kINT32, 1},
{"beta", bias, nvinfer1::PluginFieldType::kFLOAT32, bias_size},
{"gamma", scale, nvinfer1::PluginFieldType::kFLOAT32, scale_size},
};
nvinfer1::PluginFieldCollection* pluginPtr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*pluginPtr) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
pluginPtr->nbFields = static_cast<int>(fields.size());
pluginPtr->fields = fields.data();
auto pluginObj =
creator->createPlugin("CustomSkipLayerNormPluginDynamic", pluginPtr);
auto plugin_layer = engine_->network()->addPluginV2(
inputs.data(), inputs.size(), *pluginObj);
PADDLE_ENFORCE_NE(
plugin_layer, nullptr,
platform::errors::InvalidArgument(
"fail to add CustomSkipLayerNormPluginDynamic layer"));
layer = plugin_layer;
} else {
float eps = BOOST_GET_CONST(float, op_desc.GetAttr("epsilon"));
bool with_fp16 =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册