From 02c3121ea7bbd31f28b1039e738e05fd3860bf10 Mon Sep 17 00:00:00 2001
From: WJJ1995
Date: Tue, 13 Dec 2022 15:29:28 +0800
Subject: [PATCH] Support fmod=1 in Mod OP (#921)

* fixed Gemm bug

* re-lint

* fixed typo error

* support fmod=1
---
 tests/onnx/test_auto_scan_mod.py          |  2 +-
 x2paddle/op_mapper/onnx2paddle/opset10.py | 53 +++++++++++++++++++----
 2 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/tests/onnx/test_auto_scan_mod.py b/tests/onnx/test_auto_scan_mod.py
index 201a943..6a355c6 100644
--- a/tests/onnx/test_auto_scan_mod.py
+++ b/tests/onnx/test_auto_scan_mod.py
@@ -42,7 +42,7 @@ class TestModConvert(OPConvertAutoScanTest):
             input_data[abs(input_data) < 1.0] = 1.0
             return input_data
 
-        input_dtype = draw(st.sampled_from(["int32", "int64"]))
+        input_dtype = draw(st.sampled_from(["float32", "int32", "int64"]))
 
         config = {
             "op_names": ["Mod"],
diff --git a/x2paddle/op_mapper/onnx2paddle/opset10.py b/x2paddle/op_mapper/onnx2paddle/opset10.py
index 9310eb3..d299dc9 100644
--- a/x2paddle/op_mapper/onnx2paddle/opset10.py
+++ b/x2paddle/op_mapper/onnx2paddle/opset10.py
@@ -103,15 +103,52 @@ class OpSet10(OpSet9):
         # Support Mod op Since opset version >= 10
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_input_node(node, idx=1, copy=True)
-        input_dtype = str(val_x.dtype)
-        assert "int" in input_dtype, 'Now only support int32 or int64 dtype'
         fmod = node.get_attr('fmod', 0)
-        assert fmod == 0, 'Now only support fmod == 0'
-        self.paddle_graph.add_layer(
-            'paddle.mod',
-            inputs={"x": val_x.name,
-                    "y": val_y.name},
-            outputs=[node.name])
+        x_dtype = str(val_x.dtype)
+        y_dtype = str(val_y.dtype)
+        if "int" in x_dtype and "int" in y_dtype and fmod == 1:
+            fmod = 0
+        if fmod == 0:
+            self.paddle_graph.add_layer(
+                'paddle.mod',
+                inputs={"x": val_x.name,
+                        "y": val_y.name},
+                outputs=[node.name])
+        else:
+            # Step1: trunc_div(a, b) = sign(a / b) * floor(abs(a / b))
+            self.paddle_graph.add_layer(
+                'paddle.divide',
+                inputs={"x": val_x.name,
+                        "y": val_y.name},
+                outputs=[node.name + "_divide"])
+            self.paddle_graph.add_layer(
+                'paddle.sign',
+                inputs={"x": node.name + "_divide"},
+                outputs=[node.name + "_sign"])
+            self.paddle_graph.add_layer(
+                'paddle.abs',
+                inputs={"x": node.name + "_divide"},
+                outputs=[node.name + "_abs"])
+            self.paddle_graph.add_layer(
+                'paddle.floor',
+                inputs={"x": node.name + "_abs"},
+                outputs=[node.name + "_floor"])
+            self.paddle_graph.add_layer(
+                'paddle.multiply',
+                inputs={"x": node.name + "_sign",
+                        "y": node.name + "_floor"},
+                outputs=[node.name + "_trunc_div"])
+            # Step2: result = a - trunc_div(a, b) * b
+            self.paddle_graph.add_layer(
+                'paddle.multiply',
+                inputs={"x": node.name + "_trunc_div",
+                        "y": val_y.name},
+                outputs=[node.name + "_trunc_div"])
+            self.paddle_graph.add_layer(
+                'paddle.subtract',
+                inputs={"x": val_x.name,
+                        "y": node.name + "_trunc_div"},
+                outputs=[node.name])
 
     @print_mapping_info
     def IsInf(self, node):
--
GitLab
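
Editor's note: the fmod == 1 branch above emits a graph implementing the decomposition
fmod(a, b) = a - trunc_div(a, b) * b, with trunc_div(a, b) = sign(a / b) * floor(abs(a / b)).
The sketch below is not part of the patch; it is a minimal standalone check, assuming
PaddlePaddle 2.x and NumPy are installed, that mirrors the same operator sequence
(divide, sign, abs, floor, multiply, subtract) as plain tensor code and compares it
against numpy.fmod. The helper name fmod_via_trunc_div is illustrative, not an
x2paddle API.

    # Standalone sketch of the fmod == 1 decomposition (assumes paddle 2.x).
    import numpy as np
    import paddle

    def fmod_via_trunc_div(x, y):
        # Step1: trunc_div(a, b) = sign(a / b) * floor(abs(a / b))
        div = paddle.divide(x, y)
        trunc_div = paddle.multiply(paddle.sign(div), paddle.floor(paddle.abs(div)))
        # Step2: result = a - trunc_div(a, b) * b
        return paddle.subtract(x, paddle.multiply(trunc_div, y))

    x = paddle.to_tensor([5.5, -5.5, 7.0, -7.0], dtype="float32")
    y = paddle.to_tensor([2.0, 2.0, -3.0, 3.0], dtype="float32")
    print(fmod_via_trunc_div(x, y).numpy())  # [ 1.5 -1.5  1.  -1. ]
    print(np.fmod(x.numpy(), y.numpy()))     # [ 1.5 -1.5  1.  -1. ]

Using truncated division makes the result take the sign of the dividend, matching
C-style fmod (ONNX fmod=1), whereas paddle.mod uses floored division and follows the
sign of the divisor. The integer fast path in the patch (rewriting fmod=1 to
paddle.mod) is therefore exact only when the two operands share a sign, where
truncated and floored division coincide.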