未验证 提交 f68d4fb3 编写于 作者: Z Zhaolong Xing 提交者: GitHub

fix bert bug using trt6 when compile with CUDA_ARCH_NAME=All (#24517)

test=develop
上级 a13a4dbc
......@@ -91,9 +91,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
#else
PADDLE_THROW(
platform::errors::Fatal("use EmbEltwiseLayernormPluginDynamic "
"FP16, but GPU doesn't have FP16."));
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
#endif
} else {
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
......
......@@ -29,7 +29,6 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm");
teller_set.insert("slice");
#endif
}
......
......@@ -120,7 +120,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
if (with_fp16) {
precision = AnalysisConfig::Precision::kHalf;
}
config.EnableTensorRtEngine(1 << 30, 1, 1, precision, false, true);
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, true);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
std::vector<float> out_data;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册