未验证 提交 bb1216f5 编写于 作者: W Wangzheee 提交者: GitHub

fix trt convert fc_op'oss (#33566)

上级 63b03cf5
......@@ -48,6 +48,7 @@ class FcOpConverter : public OpConverter {
}
// Declare inputs
auto* X = engine_->GetITensor(op_desc.Input(i_name).front());
auto x_dim = X->getDimensions();
// Declare weights
auto* Y_v = scope.FindVar(op_desc.Input(w_name).front());
PADDLE_ENFORCE_NOT_NULL(
......@@ -138,7 +139,13 @@ class FcOpConverter : public OpConverter {
("fc_layer_before(Output: " + output_name + ")").c_str());
// add shuffle after fc
nvinfer1::Dims reshape_after_fc_dim;
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) {
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
reshape_after_fc_dim.nbDims = 4;
} else {
reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
}
for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
reshape_after_fc_dim.d[i] = 0;
}
......@@ -181,11 +188,15 @@ class FcOpConverter : public OpConverter {
static_cast<void*>(bias_data),
static_cast<size_t>(bias_num)};
auto x_dim = X->getDimensions();
// Running the TRT Static Shape mode: x_num_col_dims-1
if (!engine_->with_dynamic_shape()) {
x_num_col_dims--;
}
// If use tensorrt'oss, the x_dim and x_num_col_dims need change
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
x_num_col_dims = 1;
}
PADDLE_ENFORCE_GT(
x_dim.nbDims, x_num_col_dims,
platform::errors::InvalidArgument(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册