未验证 提交 32f07216 编写于 作者: P Pei Yang 提交者: GitHub

fix bert bug using trt6 when compile with CUDA_ARCH_NAME=All (#24576)

test=develop
Co-authored-by: Nnhzlx <nhzlx.dragon@gmail.com>
上级 db0c1ea8
......@@ -91,9 +91,9 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
#else
PADDLE_THROW(
platform::errors::Fatal("use EmbEltwiseLayernormPluginDynamic "
"FP16, but GPU doesn't have FP16."));
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
input_embs, bias, scale, emb_sizes, bias_size, scale_size, hidden,
eps);
#endif
} else {
plugin = new plugin::EmbEltwiseLayernormPluginDynamic<float>(
......
......@@ -29,7 +29,6 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("fused_embedding_eltwise_layernorm");
teller_set.insert("multihead_matmul");
teller_set.insert("skip_layernorm");
teller_set.insert("slice");
#endif
}
......
......@@ -120,7 +120,7 @@ void trt_ernie(bool with_fp16, std::vector<float> result) {
if (with_fp16) {
precision = AnalysisConfig::Precision::kHalf;
}
config.EnableTensorRtEngine(1 << 30, 1, 1, precision, false, true);
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, true);
config.SetTRTDynamicShapeInfo(min_input_shape, max_input_shape,
opt_input_shape);
std::vector<float> out_data;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册