未验证 提交 ea1c05d0 编写于 作者: P Pei Yang 提交者: GitHub

fix bert bug using trt6 when compile with CUDA_ARCH_NAME=All (#24574)

test=develop
Co-authored-by: Nnhzlx <nhzlx.dragon@gmail.com>
上级 9e4fabc0
...@@ -91,9 +91,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -91,9 +91,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden, input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps); eps);
#else #else
PADDLE_THROW( plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
platform::errors::Fatal("use EmbEltwiseLayernormPluginDynamic " input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
"FP16, but GPU doesn't have FP16.")); eps);
#endif #endif
} else { } else {
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>( plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
......
...@@ -29,7 +29,6 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -29,7 +29,6 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("fused_embedding_eltwise_layernorm"); teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul"); teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm"); teller_set.insert("skip_layernorm");
teller_set.insert("slice");
#endif #endif
} }
......
...@@ -120,7 +120,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) { ...@@ -120,7 +120,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
if (with_fp16) { if (with_fp16) {
precision = AnalysisConfig::Precision::kHalf; precision = AnalysisConfig::Precision::kHalf;
} }
config.EnableTensorRtEngine(1 << 30, 1, 1, precision, false, true); config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, true);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape); opt_input_shape);
std::vector<float> out_data; std::vector<float> out_data;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册