From a21c51646edb52b8edd3ee2909bef87af9710f77 Mon Sep 17 00:00:00 2001 From: BBuf <1182563586@qq.com> Date: Thu, 24 Jun 2021 14:14:45 +0800 Subject: [PATCH] fix commnets --- oneflow_onnx/oneflow2onnx/handlers/array.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/oneflow_onnx/oneflow2onnx/handlers/array.py b/oneflow_onnx/oneflow2onnx/handlers/array.py index 199c0ea..0185939 100644 --- a/oneflow_onnx/oneflow2onnx/handlers/array.py +++ b/oneflow_onnx/oneflow2onnx/handlers/array.py @@ -121,12 +121,26 @@ class Flatten: assert dtype == 1, f"onnx opset version 1/9 only support float32 data_type!" assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" node.attrs["axis"] = start_dim + + @classmethod + def Version_9(cls, ctx, node, **kwargs): + start_dim = node.attrs.get("start_dim", 1) + dtype = ctx.get_dtype(node.input_tensor_names[0]) + assert dtype == 1, f"onnx opset version 1/9 only support float32 data_type!" + assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" + node.attrs["axis"] = start_dim @classmethod def Version_11(cls, ctx, node, **kwargs): start_dim = node.attrs.get("start_dim", 1) assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" node.attrs["axis"] = start_dim + + @classmethod + def Version_13(cls, ctx, node, **kwargs): + start_dim = node.attrs.get("start_dim", 1) + assert start_dim >= 0, f"oneflow flatten can't support neagetive dim now!" + node.attrs["axis"] = start_dim @flow_op("squeeze", "Squeeze") class Squeeze: -- GitLab