未验证 提交 bb1216f5 编写于 作者: W Wangzheee 提交者: GitHub

fix trt convert fc_op'oss (#33566)

上级 63b03cf5
...@@ -48,6 +48,7 @@ class FcOpConverter : public OpConverter { ...@@ -48,6 +48,7 @@ class FcOpConverter : public OpConverter {
} }
// Declare inputs // Declare inputs
auto* X = engine_->GetITensor(op_desc.Input(i_name).front()); auto* X = engine_->GetITensor(op_desc.Input(i_name).front());
auto x_dim = X->getDimensions();
// Declare weights // Declare weights
auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); auto* Y_v = scope.FindVar(op_desc.Input(w_name).front());
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
...@@ -138,7 +139,13 @@ class FcOpConverter : public OpConverter { ...@@ -138,7 +139,13 @@ class FcOpConverter : public OpConverter {
("fc_layer_before(Output: " + output_name + ")").c_str()); ("fc_layer_before(Output: " + output_name + ")").c_str());
// add shuffle after fc // add shuffle after fc
nvinfer1::Dims reshape_after_fc_dim; nvinfer1::Dims reshape_after_fc_dim;
reshape_after_fc_dim.nbDims = x_num_col_dims + 1; if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) {
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
reshape_after_fc_dim.nbDims = 4;
} else {
reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
}
for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
reshape_after_fc_dim.d[i] = 0; reshape_after_fc_dim.d[i] = 0;
} }
...@@ -181,11 +188,15 @@ class FcOpConverter : public OpConverter { ...@@ -181,11 +188,15 @@ class FcOpConverter : public OpConverter {
static_cast<void*>(bias_data), static_cast<void*>(bias_data),
static_cast<size_t>(bias_num)}; static_cast<size_t>(bias_num)};
auto x_dim = X->getDimensions();
// Running the TRT Static Shape mode: x_num_col_dims-1 // Running the TRT Static Shape mode: x_num_col_dims-1
if (!engine_->with_dynamic_shape()) { if (!engine_->with_dynamic_shape()) {
x_num_col_dims--; x_num_col_dims--;
} }
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
x_num_col_dims = 1;
}
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
x_dim.nbDims, x_num_col_dims, x_dim.nbDims, x_num_col_dims,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册