未验证 提交 2bcec75a 编写于 作者: B baoachun 提交者: GitHub

fix FlattenContiguousRangeOpConverter out dim error (#42087)

* fix FlattenContiguousRangeOpConverter out dim error

* update code
上级 ccafd2e5
...@@ -30,14 +30,17 @@ class FlattenContiguousRangeOpConverter : public OpConverter { ...@@ -30,14 +30,17 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
public: public:
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid flatten_contiguous_range op to tensorrt layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
// Declare inputs // Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]); auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
int dims = input->getDimensions().nbDims; const auto input_dim = input->getDimensions();
const int dims = input_dim.nbDims;
int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis")); int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis"));
int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis")); int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis"));
nvinfer1::IShuffleLayer* layer = nullptr; nvinfer1::IShuffleLayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (!engine_->with_dynamic_shape()) { if (!engine_->with_dynamic_shape()) {
if (start_axis < 0) start_axis += dims + 1; if (start_axis < 0) start_axis += dims + 1;
if (stop_axis < 0) stop_axis += dims + 1; if (stop_axis < 0) stop_axis += dims + 1;
...@@ -46,7 +49,7 @@ class FlattenContiguousRangeOpConverter : public OpConverter { ...@@ -46,7 +49,7 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
flatten_dim.nbDims = dims - (stop_axis - start_axis); flatten_dim.nbDims = dims - (stop_axis - start_axis);
for (int i = 0, j = 0; i < dims; ++i) { for (int i = 0, j = 0; i < dims; ++i) {
if (start_axis <= i + 1 && i + 1 <= stop_axis) { if (start_axis <= i + 1 && i + 1 <= stop_axis) {
int dim_i = input->getDimensions().d[i]; int dim_i = input_dim.d[i];
PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument(
"flatten_contiguous_range input dim " "flatten_contiguous_range input dim "
"should be > 0, but got %d.", "should be > 0, but got %d.",
...@@ -56,14 +59,42 @@ class FlattenContiguousRangeOpConverter : public OpConverter { ...@@ -56,14 +59,42 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
flatten_dim.d[j++] = dim_prod; flatten_dim.d[j++] = dim_prod;
} }
} else { } else {
flatten_dim.d[j++] = input->getDimensions().d[i]; flatten_dim.d[j++] = input_dim.d[i];
} }
} }
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setReshapeDimensions(flatten_dim); layer->setReshapeDimensions(flatten_dim);
} else { } else {
if (start_axis < 0) start_axis += dims; if (start_axis < 0) start_axis += dims;
if (stop_axis < 0) stop_axis += dims; if (stop_axis < 0) stop_axis += dims;
int dim_prod = 1;
int dim_negative = 0;
nvinfer1::Dims flatten_dim;
flatten_dim.nbDims = dims - (stop_axis - start_axis);
bool need_slice = false;
for (int i = 0, j = 0; i < dims; ++i) {
int dim_i = input_dim.d[i];
if (start_axis <= i && i <= stop_axis) {
if (dim_i < 0) {
need_slice = true;
break;
}
dim_prod *= dim_i;
if (i == stop_axis) {
flatten_dim.d[j++] = dim_prod;
}
} else {
if (dim_i < 0) dim_negative++;
if (dim_negative > 1) {
need_slice = true;
break;
}
flatten_dim.d[j++] = input_dim.d[i];
}
}
if (need_slice) {
VLOG(3) << "slice input dim when the input dimension has -1";
auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input); auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
auto* shape_layer_itensor = shape_layer->getOutput(0); auto* shape_layer_itensor = shape_layer->getOutput(0);
...@@ -75,8 +106,8 @@ class FlattenContiguousRangeOpConverter : public OpConverter { ...@@ -75,8 +106,8 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
size_dim.d[0] = stop_axis - start_axis + 1; size_dim.d[0] = stop_axis - start_axis + 1;
stride_dim.d[0] = 1; stride_dim.d[0] = 1;
auto* slice_layer = auto* slice_layer =
TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor, start_dim, TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor,
size_dim, stride_dim); start_dim, size_dim, stride_dim);
uint32_t reduce_dim = 1; uint32_t reduce_dim = 1;
auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER( auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *(slice_layer->getOutput(0)), engine_, Reduce, *(slice_layer->getOutput(0)),
...@@ -119,9 +150,12 @@ class FlattenContiguousRangeOpConverter : public OpConverter { ...@@ -119,9 +150,12 @@ class FlattenContiguousRangeOpConverter : public OpConverter {
concat_layer->setAxis(0); concat_layer->setAxis(0);
input_shape = concat_layer->getOutput(0); input_shape = concat_layer->getOutput(0);
} }
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
layer->setInput(1, *input_shape); layer->setInput(1, *input_shape);
} else {
layer->setReshapeDimensions(flatten_dim);
} }
}
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name}, RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name},
test_mode); test_mode);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册