Unverified commit 1e887794, authored by Negin Raoof, committed by GitHub

Fix for inverse spec (#2706)

* Fixed spec for input rank

* Full-rank enforce

* Updated spec

* Updated spec for behavior

* Type constrain updated

* Refactor desc

* Shape inference changes

* Fail for invalid shapes

* float type only

* Typo
Parent 58a2d86c
@@ -14445,8 +14445,10 @@ This version of the operator has been available since version 12 of the default
 Calculates inverse of a square matrix or batches of square matrices.
 Inverse takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
-and the inner-most 2 dimensions form square matrices.
-The output is a tensor of shape `[*, M, M]`, containing the individual inverses of all input submatrices.
+and the inner-most 2 dimensions form square matrices. These matrices must be invertible (full-rank).
+The behavior where one of the matrices is not invertible is undefined. The implementation can choose
+to throw an error or output (garbage) results as is. The output is a tensor of shape `[*, M, M]`,
+containing the individual inverses of all input submatrices.
#### Version
@@ -14456,21 +14458,21 @@ This version of the operator has been available since version 12 of the default
 <dl>
 <dt><tt>X</tt> : T</dt>
-<dd>Input tensor</dd>
+<dd>Input tensor. Every matrix in the batch must be invertible.</dd>
 </dl>
 
 #### Outputs
 
 <dl>
 <dt><tt>Y</tt> : T</dt>
-<dd>Output tensor of the same type as input.</dd>
+<dd>Output tensor of the same type and shape as the input tensor.</dd>
 </dl>
 
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double)</dt>
-<dd>Constrain input and output types to all numerical tensor types.</dd>
+<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
 </dl>
### <a name="LessOrEqual-12"></a>**LessOrEqual-12**</a>
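The revised documentation above matches the batched-inverse semantics of numpy's `np.linalg.inv`. A minimal sketch of the documented behavior (illustrative only, not the ONNX reference implementation):

```python
import numpy as np

# A batch of 3 invertible 4x4 matrices: shape [*, M, M] with * = (3,), M = 4.
# Adding a multiple of the identity makes each matrix diagonally dominant,
# hence invertible (full-rank), as the revised spec requires.
x = np.random.rand(3, 4, 4) + 4.0 * np.eye(4)

# np.linalg.inv inverts each innermost M x M submatrix independently.
y = np.linalg.inv(x)

# The output has the same type and shape as the input.
assert y.dtype == x.dtype and y.shape == x.shape

# Each product x[i] @ y[i] recovers the identity up to floating-point error.
assert np.allclose(x @ y, np.broadcast_to(np.eye(4), x.shape), atol=1e-6)
```

For a singular matrix the spec leaves the result undefined; numpy happens to raise `numpy.linalg.LinAlgError` in that case, which is one of the behaviors (throwing an error) the new wording permits.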
@@ -7368,8 +7368,10 @@ expect(node, inputs=[x, s, bias], outputs=[y],
 Calculates inverse of a square matrix or batches of square matrices.
 Inverse takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
-and the inner-most 2 dimensions form square matrices.
-The output is a tensor of shape `[*, M, M]`, containing the individual inverses of all input submatrices.
+and the inner-most 2 dimensions form square matrices. These matrices must be invertible (full-rank).
+The behavior where one of the matrices is not invertible is undefined. The implementation can choose
+to throw an error or output (garbage) results as is. The output is a tensor of shape `[*, M, M]`,
+containing the individual inverses of all input submatrices.
#### Version
@@ -7379,21 +7381,21 @@ This version of the operator has been available since version 12 of the default
 <dl>
 <dt><tt>X</tt> : T</dt>
-<dd>Input tensor</dd>
+<dd>Input tensor. Every matrix in the batch must be invertible.</dd>
 </dl>
 
 #### Outputs
 
 <dl>
 <dt><tt>Y</tt> : T</dt>
-<dd>Output tensor of the same type as input.</dd>
+<dd>Output tensor of the same type and shape as the input tensor.</dd>
 </dl>
 
 #### Type Constraints
 
 <dl>
-<dt><tt>T</tt> : tensor(uint8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(int8), tensor(int16), tensor(int32), tensor(int64), tensor(float16), tensor(float), tensor(double)</dt>
-<dd>Constrain input and output types to all numerical tensor types.</dd>
+<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double)</dt>
+<dd>Constrain input and output types to float tensors.</dd>
 </dl>
@@ -2297,8 +2297,10 @@ ONNX_OPERATOR_SET_SCHEMA(
 static const char* Inverse_ver12_doc = R"DOC(
 Calculates inverse of a square matrix or batches of square matrices.
 Inverse takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions,
-and the inner-most 2 dimensions form square matrices.
-The output is a tensor of shape `[*, M, M]`, containing the individual inverses of all input submatrices.
+and the inner-most 2 dimensions form square matrices. These matrices must be invertible (full-rank).
+The behavior where one of the matrices is not invertible is undefined. The implementation can choose
+to throw an error or output (garbage) results as is. The output is a tensor of shape `[*, M, M]`,
+containing the individual inverses of all input submatrices.
)DOC";
ONNX_OPERATOR_SET_SCHEMA(
@@ -2306,12 +2308,14 @@ ONNX_OPERATOR_SET_SCHEMA(
     12,
     OpSchema()
         .SetDoc(Inverse_ver12_doc)
-        .Input(0, "X", "Input tensor", "T")
-        .Output(0, "Y", "Output tensor of the same type as input.", "T")
+        .Input(0, "X", "Input tensor. Every matrix in the batch must be invertible.", "T")
+        .Output(0, "Y", "Output tensor of the same type and shape as the input tensor.", "T")
         .TypeConstraint(
             "T",
-            OpSchema::all_numeric_types(),
-            "Constrain input and output types to all numerical tensor types.")
+            {"tensor(float16)",
+             "tensor(float)",
+             "tensor(double)"},
+            "Constrain input and output types to float tensors.")
         .TypeAndShapeInferenceFunction([](InferenceContext& ctx) {
           // Type inference
           propagateElemTypeFromInputToOutput(ctx, 0, 0);
......
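The inference function is truncated above; beyond propagating the element type, the PR commits note that inference now "fails for invalid shapes". A hypothetical Python rendering of that shape rule (function and message names are illustrative, not from the ONNX codebase):

```python
def infer_inverse_shape(input_shape):
    # Inverse requires rank >= 2: the input shape must be [*, M, M].
    if len(input_shape) < 2:
        raise ValueError("Inverse: input must have rank >= 2")
    # The innermost 2 dimensions must form square matrices.
    if input_shape[-1] != input_shape[-2]:
        raise ValueError("Inverse: innermost 2 dimensions must be square")
    # The inferred output shape equals the input shape.
    return tuple(input_shape)
```

This mirrors the test cases added below: batched shapes such as `(3, 4, 4)` and rank-2 shapes such as `(5, 5)` pass through unchanged, while a non-square innermost pair is rejected.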
@@ -1284,6 +1284,25 @@ class TestShapeInference(unittest.TestCase):
             [])
         self._assert_inferred(graph, [make_tensor_value_info('z', TensorProto.FLOAT, (4, 5, 6))])
 
+    def test_inverse_float(self):  # type: () -> None
+        graph = self._make_graph(
+            [('x', TensorProto.FLOAT16, (3, 4, 4))],
+            [make_node('Inverse', ['x'], 'y')],
+            [])
+        self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT16, (3, 4, 4))])
+
+        graph = self._make_graph(
+            [('x', TensorProto.FLOAT, (2, 5, 5))],
+            [make_node('Inverse', ['x'], 'y')],
+            [])
+        self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.FLOAT, (2, 5, 5))])
+
+        graph = self._make_graph(
+            [('x', TensorProto.DOUBLE, (5, 5))],
+            [make_node('Inverse', ['x'], 'y')],
+            [])
+        self._assert_inferred(graph, [make_tensor_value_info('y', TensorProto.DOUBLE, (5, 5))])
+
     def test_logsoftmax_2d(self):  # type: () -> None
         graph = self._make_graph(
             [('x', TensorProto.FLOAT, (4, 5))],