Unverified commit 20e8bf1f, authored by baoachun, committed by GitHub

fix FlattenContiguousRangeOpConverter out dim error (#42087) (#42184)

* fix FlattenContiguousRangeOpConverter out dim error

* update code
Parent commit: 8c3c6dae
......@@ -30,14 +30,17 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid flatten_contiguous_range op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
int dims = input->getDimensions().nbDims;
const auto input_dim = input->getDimensions();
const int dims = input_dim.nbDims;
int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis"));
int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis"));
nvinfer1::IShuffleLayer* layer = nullptr;
nvinfer1::IShuffleLayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (!engine_->with_dynamic_shape()) {
if (start_axis < 0) start_axis += dims + 1;
if (stop_axis < 0) stop_axis += dims + 1;
......@@ -46,7 +49,7 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
flatten_dim.nbDims = dims - (stop_axis - start_axis);
for (int i = 0, j = 0; i < dims; ++i) {
if (start_axis <= i + 1 && i + 1 <= stop_axis) {
int dim_i = input->getDimensions().d[i];
int dim_i = input_dim.d[i];
PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument(
"flatten_contiguous_range input dim "
"should be > 0, but got %d.",
......@@ -56,14 +59,42 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
flatten_dim.d[j++] = dim_prod;
}
} else {
flatten_dim.d[j++] = input->getDimensions().d[i];
flatten_dim.d[j++] = input_dim.d[i];
}
}
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setReshapeDimensions(flatten_dim);
} else {
if (start_axis < 0) start_axis += dims;
if (stop_axis < 0) stop_axis += dims;
int dim_prod = 1;
int dim_negative = 0;
nvinfer1::Dims flatten_dim;
flatten_dim.nbDims = dims - (stop_axis - start_axis);
bool need_slice = false;
for (int i = 0, j = 0; i < dims; ++i) {
int dim_i = input_dim.d[i];
if (start_axis <= i && i <= stop_axis) {
if (dim_i < 0) {
need_slice = true;
break;
}
dim_prod *= dim_i;
if (i == stop_axis) {
flatten_dim.d[j++] = dim_prod;
}
} else {
if (dim_i < 0) dim_negative++;
if (dim_negative > 1) {
need_slice = true;
break;
}
flatten_dim.d[j++] = input_dim.d[i];
}
}
if (need_slice) {
VLOG(3) << "slice input dim when the input dimension has -1";
auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
auto* shape_layer_itensor = shape_layer->getOutput(0);
......@@ -75,8 +106,8 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
size_dim.d[0] = stop_axis - start_axis + 1;
stride_dim.d[0] = 1;
auto* slice_layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor, start_dim,
size_dim, stride_dim);
TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor,
start_dim, size_dim, stride_dim);
uint32_t reduce_dim = 1;
auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *(slice_layer->getOutput(0)),
......@@ -119,9 +150,12 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
concat_layer->setAxis(0);
input_shape = concat_layer->getOutput(0);
}
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setInput(1, *input_shape);
} else {
layer->setReshapeDimensions(flatten_dim);
}
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name},
test_mode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.