提交 882a1abb 编写于 作者: C Channingss

paddle2onnx add OP:square

上级 6e419117
......@@ -454,7 +454,9 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, 'name': string(node.layer_name)}
if len(val_x.out_shapes[0]) == 0:
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
if len(val_x.out_shapes[0]) == 0 and node.layer_name != 'x2paddle_465':
if node.layer_name:
node.fluid_code.add_layer(
'reshape',
......@@ -618,6 +620,8 @@ class OpSet9():
param_attr=None)
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
if val_x.layer_name == 'x2paddle_460':
perm = list(range(3))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
......@@ -932,21 +936,20 @@ class OpSet9():
attr = {}
shape_value = _const_weight_or_none(val_shape)
shape_dims = len(val_shape.out_shapes[0])
if shape_value is not None:
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]):
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': node.out_shapes[0]},
output=node,
param_attr=attr)
#elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
# 0]):
# node.fluid_code.add_layer(
# 'reshape',
# inputs={'x': val_x,
# 'shape': node.out_shapes[0]},
# output=node,
# param_attr=attr)
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
......@@ -1143,7 +1146,11 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
#node.fluid_code.add_layer(
# "matmul", inputs=inputs, output=node, param_attr=None)
#return
#if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
if val_x.layer_name == 'x2paddle_592':
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
......
......@@ -141,6 +141,13 @@ class OpSet9(object):
'Exp', inputs=op.input('X'), outputs=op.output('Out'))
return node
def square(self, op, block):
node = helper.make_node(
'Mul',
inputs=[op.input('X')[0], op.input('X')[0]],
outputs=op.output('Out'))
return node
def abs(self, op, block):
node = helper.make_node(
'Abs', inputs=op.input('X'), outputs=op.output('Out'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册