diff --git a/oneflow_onnx/oneflow2onnx/handlers/array.py b/oneflow_onnx/oneflow2onnx/handlers/array.py index 199c0ea1cb5385f865807cb094916a4ae950d8ee..0185939f59bbbcbccda73b5c4b9ac0b7f6379eb0 100644 --- a/oneflow_onnx/oneflow2onnx/handlers/array.py +++ b/oneflow_onnx/oneflow2onnx/handlers/array.py @@ -121,12 +121,26 @@ class Flatten: assert dtype == 1, f"onnx opset version 1/9 only support float32 data_type!" assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" node.attrs["axis"] = start_dim + + @classmethod + def Version_9(cls, ctx, node, **kwargs): + start_dim = node.attrs.get("start_dim", 1) + dtype = ctx.get_dtype(node.input_tensor_names[0]) + assert dtype == 1, f"onnx opset version 1/9 only support float32 data_type!" + assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" + node.attrs["axis"] = start_dim @classmethod def Version_11(cls, ctx, node, **kwargs): start_dim = node.attrs.get("start_dim", 1) assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" node.attrs["axis"] = start_dim + + @classmethod + def Version_13(cls, ctx, node, **kwargs): + start_dim = node.attrs.get("start_dim", 1) + assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" + node.attrs["axis"] = start_dim @flow_op("squeeze", "Squeeze") class Squeeze: