未验证 提交 02c3121e 编写于 作者: W WJJ1995 提交者: GitHub

Support fmod=1 in Mod OP (#921)

* fixed Gemm bug

* re-lint

* fixed typo error

* support fmod=1
上级 ce96b0f5
......@@ -42,7 +42,7 @@ class TestModConvert(OPConvertAutoScanTest):
input_data[abs(input_data) < 1.0] = 1.0
return input_data
input_dtype = draw(st.sampled_from(["int32", "int64"]))
input_dtype = draw(st.sampled_from(["float32", "int32", "int64"]))
config = {
"op_names": ["Mod"],
......
......@@ -103,15 +103,52 @@ class OpSet10(OpSet9):
# Support Mod op Since opset version >= 10
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
input_dtype = str(val_x.dtype)
assert "int" in input_dtype, 'Now only support int32 or int64 dtype'
fmod = node.get_attr('fmod', 0)
assert fmod == 0, 'Now only support fmod == 0'
x_dtype = str(val_x.dtype)
y_dtype = str(val_y.dtype)
if "int" in x_dtype and "int" in y_dtype and fmod == 1:
fmod = 0
if fmod == 0:
self.paddle_graph.add_layer(
'paddle.mod',
inputs={"x": val_x.name,
"y": val_y.name},
outputs=[node.name])
else:
# Step1:trunc_div(a, b) = sign(a / b) * floor(abs(a / b))
self.paddle_graph.add_layer(
'paddle.divide',
inputs={"x": val_x.name,
"y": val_y.name},
outputs=[node.name + "_divide"])
self.paddle_graph.add_layer(
'paddle.sign',
inputs={"x": node.name + "_divide"},
outputs=[node.name + "_sign"])
self.paddle_graph.add_layer(
'paddle.abs',
inputs={"x": node.name + "_divide"},
outputs=[node.name + "_abs"])
self.paddle_graph.add_layer(
'paddle.floor',
inputs={"x": node.name + "_abs"},
outputs=[node.name + "_floor"])
self.paddle_graph.add_layer(
'paddle.multiply',
inputs={"x": node.name + "_sign",
"y": node.name + "_floor"},
outputs=[node.name + "_trunc_div"])
# Step2:result = a - trunc_div(a, b) * b
self.paddle_graph.add_layer(
'paddle.multiply',
inputs={"x": node.name + "_trunc_div",
"y": val_y.name},
outputs=[node.name + "_trunc_div"])
self.paddle_graph.add_layer(
'paddle.subtract',
inputs={"x": val_x.name,
"y": node.name + "_trunc_div"},
outputs=[node.name])
@print_mapping_info
def IsInf(self, node):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册