未验证 提交 6af1ed14 编写于 作者: TomWildenhain-Microsoft 提交者: GitHub

Improve rank inference for Expand op (#3807)

* Improve rank inference for Expand op
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>

* Unnecessary dtype check in Expand shape inference
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>

* Add test for Expand rank inference
Signed-off-by: Tom Wildenhain <tomwi@microsoft.com>
上级 ee4d1e11
......@@ -2119,27 +2119,32 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
// For shape inference (and rank inference), we need both input shape
// and values in 'shape' tensor
// For shape inference, we need both input shape
const auto* shape_initializer = ctx.getInputData(1);
if (hasNInputShapes(ctx, 2) && nullptr != shape_initializer) {
const auto& shape_initializer_shape =
if (hasNInputShapes(ctx, 2)) {
const auto& shape_input_shape =
ctx.getInputType(1)->tensor_type().shape();
if (shape_initializer_shape.dim_size() != 1 ||
shape_initializer->data_type() != TensorProto::INT64) {
fail_shape_inference("'shape' input must be 1D tensor of type INT64");
if (shape_input_shape.dim_size() != 1) {
fail_shape_inference("'shape' input must be 1D tensor");
}
const auto& input_shape =
ctx.getInputType(0)->tensor_type().shape();
const auto& shape_data = ParseData<int64_t>(shape_initializer);
TensorShapeProto second_shape;
for (const auto& e : shape_data) {
auto* dim = second_shape.add_dim();
dim->set_dim_value(e);
}
if (nullptr != shape_initializer) {
const auto& shape_data = ParseData<int64_t>(shape_initializer);
for (const auto& e : shape_data) {
auto* dim = second_shape.add_dim();
dim->set_dim_value(e);
}
} else if (shape_input_shape.dim(0).has_dim_value()) {
// Attempt rank inference using shape of shape input
int64_t dim_value = shape_input_shape.dim(0).dim_value();
for (int64_t i = 0; i < dim_value; ++i) {
second_shape.add_dim();
}
}
bidirectionalBroadcastShapeInference(
input_shape, second_shape, *getOutputShape(ctx, 0));
}
......
......@@ -995,27 +995,32 @@ ONNX_OPERATOR_SET_SCHEMA(
propagateElemTypeFromInputToOutput(ctx, 0, 0);
// Shape inference
// For shape inference (and rank inference), we need both input shape
// and values in 'shape' tensor
// For shape inference, we need both input shape
const auto* shape_initializer = ctx.getInputData(1);
if (hasNInputShapes(ctx, 2) && nullptr != shape_initializer) {
const auto& shape_initializer_shape =
if (hasNInputShapes(ctx, 2)) {
const auto& shape_input_shape =
ctx.getInputType(1)->tensor_type().shape();
if (shape_initializer_shape.dim_size() != 1 ||
shape_initializer->data_type() != TensorProto::INT64) {
fail_shape_inference("'shape' input must be 1D tensor of type INT64");
if (shape_input_shape.dim_size() != 1) {
fail_shape_inference("'shape' input must be 1D tensor");
}
const auto& input_shape =
ctx.getInputType(0)->tensor_type().shape();
const auto& shape_data = ParseData<int64_t>(shape_initializer);
TensorShapeProto second_shape;
for (const auto& e : shape_data) {
auto* dim = second_shape.add_dim();
dim->set_dim_value(e);
if (nullptr != shape_initializer) {
const auto& shape_data = ParseData<int64_t>(shape_initializer);
for (const auto& e : shape_data) {
auto* dim = second_shape.add_dim();
dim->set_dim_value(e);
}
} else if (shape_input_shape.dim(0).has_dim_value()) {
// Attempt rank inference using shape of shape input
int64_t dim_value = shape_input_shape.dim(0).dim_value();
for (int64_t i = 0; i < dim_value; ++i) {
second_shape.add_dim();
}
}
bidirectionalBroadcastShapeInference(
input_shape, second_shape, *getOutputShape(ctx, 0));
}
......
......@@ -361,6 +361,17 @@ class TestShapeInference(unittest.TestCase):
graph,
[make_tensor_value_info('y', TensorProto.INT32, (3, 4))])
def test_expand_dynamic_shape(self): # type: () -> None
# Exercises the new Expand rank-inference path: 'shape' is a graph input
# with NO initializer (initializer=[]), so its values are unknown at
# inference time. Only its static shape (3,) is available, which fixes
# the output rank at 3 without fixing the dim values.
graph = self._make_graph(
[('x', TensorProto.INT32, (1, 2, None)),
('shape', TensorProto.INT64, (3,))],
[make_node("Expand", ['x', 'shape'], ['y'])],
[],
initializer=[])
# Broadcasting (1, 2, None) against three unknown dims: dim 0 (value 1)
# may be expanded -> None; dim 1 stays 2 only because broadcasting a
# known 2 with an unknown dim keeps 2 in this inference; dim 2 unknown.
self._assert_inferred(
graph,
[make_tensor_value_info('y', TensorProto.INT32, (None, 2, None))])
def test_resize_size(self): # type: () -> None
graph = self._make_graph(
[('x', TensorProto.INT32, (2, 4, 3, 5)),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册