未验证 提交 85192b55 编写于 作者: J Jason 提交者: GitHub

Merge pull request #301 from Channingss/develop_paddle1.8

[ONNX] add Greater, ReduceMax
......@@ -107,6 +107,10 @@ class OpSet9():
'reduce_min', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
'ReduceMax': [
'reduce_max', ['X'], ['Out'], dict(
axes='dim', keepdims='keep_dim'), dict(keep_dim=1)
],
#active function
'Relu': ['relu', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
......@@ -131,10 +135,7 @@ class OpSet9():
'Abs': ['abs', ['X'], ['Out']],
}
default_ioa_constraint = {
'Gather':
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported')],
}
default_ioa_constraint = {}
def __init__(self, decoder):
super(OpSet9, self).__init__()
......@@ -1082,6 +1083,17 @@ class OpSet9():
output=node,
param_attr=None)
@print_mapping_info
def Greater(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
node.fluid_code.add_layer(
"greater_than",
inputs={'x': val_x,
'y': val_y},
output=node,
param_attr=None)
@print_mapping_info
def Where(self, node):
condition = self.graph.get_input_node(node, idx=0, copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册