提交 8340c4a6 编写于 作者: C Channingss

fix bug of elementwise_ops

上级 3789090b
...@@ -155,6 +155,7 @@ class OpSet9(object): ...@@ -155,6 +155,7 @@ class OpSet9(object):
return node return node
def elementwise_add(self, op, block): def elementwise_add(self, op, block):
print(op.input('Y'))
axis = op.attr('axis') axis = op.attr('axis')
x_shape = block.var(op.input('X')[0]).shape x_shape = block.var(op.input('X')[0]).shape
y_shape = block.var(op.input('Y')[0]).shape y_shape = block.var(op.input('Y')[0]).shape
...@@ -174,14 +175,14 @@ class OpSet9(object): ...@@ -174,14 +175,14 @@ class OpSet9(object):
inputs=[op.input('X')[0], temp_value], inputs=[op.input('X')[0], temp_value],
outputs=op.output('Out')) outputs=op.output('Out'))
return [shape_node, y_node, node] return [shape_node, y_node, node]
elif len(x_shape) == len(y_shape): elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1):
node = helper.make_node( node = helper.make_node(
'Add', 'Add',
inputs=[op.input('X')[0], op.input('Y')[0]], inputs=[op.input('X')[0], op.input('Y')[0]],
outputs=op.output('Out')) outputs=op.output('Out'))
return node return node
else: else:
raise Excpetion("Unexpected situation happend in elementwise_add") raise Exception("Unexpected situation happend in elementwise_add")
def elementwise_sub(self, op, block): def elementwise_sub(self, op, block):
axis = op.attr('axis') axis = op.attr('axis')
...@@ -203,14 +204,14 @@ class OpSet9(object): ...@@ -203,14 +204,14 @@ class OpSet9(object):
inputs=[op.input('X')[0], temp_value], inputs=[op.input('X')[0], temp_value],
outputs=op.output('Out')) outputs=op.output('Out'))
return [shape_node, y_node, node] return [shape_node, y_node, node]
elif len(x_shape) == len(y_shape): elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1):
node = helper.make_node( node = helper.make_node(
'Sub', 'Sub',
inputs=[op.input('X')[0], op.input('Y')[0]], inputs=[op.input('X')[0], op.input('Y')[0]],
outputs=op.output('Out')) outputs=op.output('Out'))
return node return node
else: else:
raise Excpetion("Unexpected situation happend in elementwise_sub") raise Exception("Unexpected situation happend in elementwise_sub")
def pool2d(self, op, block): def pool2d(self, op, block):
pool_type = { pool_type = {
...@@ -763,14 +764,14 @@ class OpSet9(object): ...@@ -763,14 +764,14 @@ class OpSet9(object):
inputs=[op.input('X')[0], temp_value], inputs=[op.input('X')[0], temp_value],
outputs=op.output('Out')) outputs=op.output('Out'))
return [shape_node, y_node, node] return [shape_node, y_node, node]
elif len(x_shape) == len(y_shape): elif axis == -1 or axis == 0 or axis == (len(x_shape) - 1):
node = helper.make_node( node = helper.make_node(
'Mul', 'Mul',
inputs=[op.input('X')[0], op.input('Y')[0]], inputs=[op.input('X')[0], op.input('Y')[0]],
outputs=op.output('Out')) outputs=op.output('Out'))
return node return node
else: else:
raise Excpetion("Unexpected situation happend in elementwise_add") raise Exception("Unexpected situation happend in elementwise_mul")
return node return node
def feed(self, op, block): def feed(self, op, block):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册