diff --git a/docs/oneflow2onnx/op_list.md b/docs/oneflow2onnx/op_list.md index a0bef9e5944e6584202cb213735bf9b0956f96a2..7911772dd36819e46337a3e0f9993ae253f9b99c 100644 --- a/docs/oneflow2onnx/op_list.md +++ b/docs/oneflow2onnx/op_list.md @@ -22,4 +22,4 @@ | 58 | ReduceSum| 59 | ReduceProd | 60 | ArgMax | 61 | ArgMin | |62 | Reshape | 63 | Squeeze | 64 | Transpose| 65 | Concat | | 66 | Cast | 67 | Identity | 68 | Mul | 69 | PReLU | -| 70 | LeakyReLU| 71 | Constant | +| 70 | LeakyReLU| 71 | Constant | 72 | Flatten | diff --git a/examples/oneflow2onnx/nodes/test_flatten.py b/examples/oneflow2onnx/nodes/test_flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..5d05f358e33b08d3f876a635053450f70a9a450c --- /dev/null +++ b/examples/oneflow2onnx/nodes/test_flatten.py @@ -0,0 +1,27 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +import oneflow.typing as tp +from oneflow_onnx.oneflow2onnx.util import convert_to_onnx_and_check + + +def test_flatten(): + @flow.global_function() + def flatten(x: tp.Numpy.Placeholder((3, 4, 2, 5))): + return flow.flatten(x, start_dim=1, end_dim=-1) + + convert_to_onnx_and_check(flatten) + diff --git a/oneflow_onnx/oneflow2onnx/handlers/array.py b/oneflow_onnx/oneflow2onnx/handlers/array.py index 5d92e9d1e74e1e45205e4420a84467dd61976335..2afc7f3137e6684cdffb264a6a39aab434d87672 100644 --- a/oneflow_onnx/oneflow2onnx/handlers/array.py +++ b/oneflow_onnx/oneflow2onnx/handlers/array.py @@ -112,6 +112,13 @@ class Reshape: node.output_tensor_names[0], output_cast.output_tensor_names[0] ) +@flow_op("flatten", "Flatten") +class Flatten: + @classmethod + def Version_1(cls, ctx, node, **kwargs): + start_dim = node.attrs.get("start_dim", None) + node.attrs["axis"] = start_dim + @flow_op("squeeze", "Squeeze") class Squeeze: