提交 f0d02885 编写于 作者: W wjj19950828

fixed mod op

上级 ebf440ac
...@@ -42,8 +42,7 @@ class TestModConvert(OPConvertAutoScanTest): ...@@ -42,8 +42,7 @@ class TestModConvert(OPConvertAutoScanTest):
input_data[abs(input_data) < 1.0] = 1.0 input_data[abs(input_data) < 1.0] = 1.0
return input_data return input_data
input_dtype = draw( input_dtype = draw(st.sampled_from(["int32", "int64"]))
st.sampled_from(["int32", "int64", "float32", "float64"]))
config = { config = {
"op_names": ["Mod"], "op_names": ["Mod"],
......
...@@ -46,8 +46,6 @@ def _get_same_padding(in_size, kernel_size, stride, autopad): ...@@ -46,8 +46,6 @@ def _get_same_padding(in_size, kernel_size, stride, autopad):
class OpSet10(OpSet9): class OpSet10(OpSet9):
def __init__(self, decoder, paddle_graph): def __init__(self, decoder, paddle_graph):
super(OpSet10, self).__init__(decoder, paddle_graph) super(OpSet10, self).__init__(decoder, paddle_graph)
# Support Mod op Since opset version >= 10
self.elementwise_ops.update({"Mod": "paddle.mod"})
@print_mapping_info @print_mapping_info
def AveragePool(self, node): def AveragePool(self, node):
...@@ -99,3 +97,18 @@ class OpSet10(OpSet9): ...@@ -99,3 +97,18 @@ class OpSet10(OpSet9):
inputs={'x': val_x if isinstance(val_x, str) else val_x.name}, inputs={'x': val_x if isinstance(val_x, str) else val_x.name},
outputs=layer_outputs, outputs=layer_outputs,
**layer_attrs) **layer_attrs)
@print_mapping_info
def Mod(self, node):
# Support Mod op Since opset version >= 10
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
input_dtype = str(val_x.dtype)
assert "int" in input_dtype, 'Now only support int32 or int64 dtype'
fmod = node.get_attr('fmod', 0)
assert fmod == 0, 'Now only support fmod == 0'
self.paddle_graph.add_layer(
'paddle.mod',
inputs={"x": val_x.name,
"y": val_y.name},
outputs=[node.name])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册