From bb1216f5cef9fdbbc5af95f3764831fad1b6fa7c Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 17 Jun 2021 10:28:46 +0800 Subject: [PATCH] fix trt convert fc_op'oss (#33566) --- paddle/fluid/inference/tensorrt/convert/fc_op.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index d2dcd4d11b..74bb854e55 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -48,6 +48,7 @@ class FcOpConverter : public OpConverter { } // Declare inputs auto* X = engine_->GetITensor(op_desc.Input(i_name).front()); + auto x_dim = X->getDimensions(); // Declare weights auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); PADDLE_ENFORCE_NOT_NULL( @@ -138,7 +139,13 @@ class FcOpConverter : public OpConverter { ("fc_layer_before(Output: " + output_name + ")").c_str()); // add shuffle after fc nvinfer1::Dims reshape_after_fc_dim; - reshape_after_fc_dim.nbDims = x_num_col_dims + 1; + if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && + x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) { + // If use tensorrt'oss, the x_dim and x_num_col_dims need change + reshape_after_fc_dim.nbDims = 4; + } else { + reshape_after_fc_dim.nbDims = x_num_col_dims + 1; + } for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { reshape_after_fc_dim.d[i] = 0; } @@ -181,11 +188,15 @@ class FcOpConverter : public OpConverter { static_cast(bias_data), static_cast(bias_num)}; - auto x_dim = X->getDimensions(); // Running the TRT Static Shape mode: x_num_col_dims-1 if (!engine_->with_dynamic_shape()) { x_num_col_dims--; } + // If use tensorrt'oss, the x_dim and x_num_col_dims need change + if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && + x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) { + x_num_col_dims = 1; + } PADDLE_ENFORCE_GT( x_dim.nbDims, x_num_col_dims, platform::errors::InvalidArgument( -- GitLab