提交 06c95839 编写于 作者: C Channingss

merge paddle/develop

上级 882a1abb
...@@ -454,9 +454,7 @@ class OpSet9(): ...@@ -454,9 +454,7 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
attr = {'axes': axes, 'name': string(node.layer_name)} attr = {'axes': axes, 'name': string(node.layer_name)}
node.fluid_code.add_layer( if len(val_x.out_shapes[0]) == 0:
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
if len(val_x.out_shapes[0]) == 0 and node.layer_name != 'x2paddle_465':
if node.layer_name: if node.layer_name:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
...@@ -620,8 +618,6 @@ class OpSet9(): ...@@ -620,8 +618,6 @@ class OpSet9():
param_attr=None) param_attr=None)
elif axis > 0 and len(indices_shape) <= 1: elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
if val_x.layer_name == 'x2paddle_460':
perm = list(range(3))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm} attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans' name_trans = val_x.layer_name + '_trans'
...@@ -936,20 +932,21 @@ class OpSet9(): ...@@ -936,20 +932,21 @@ class OpSet9():
attr = {} attr = {}
shape_value = _const_weight_or_none(val_shape) shape_value = _const_weight_or_none(val_shape)
shape_dims = len(val_shape.out_shapes[0]) shape_dims = len(val_shape.out_shapes[0])
if shape_value is not None: if shape_value is not None:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs={'x': val_x}, inputs={'x': val_x},
output=node, output=node,
param_attr={'shape': shape_value.tolist()}) param_attr={'shape': shape_value.tolist()})
#elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[ elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
# 0]): 0]):
# node.fluid_code.add_layer( node.fluid_code.add_layer(
# 'reshape', 'reshape',
# inputs={'x': val_x, inputs={'x': val_x,
# 'shape': node.out_shapes[0]}, 'shape': node.out_shapes[0]},
# output=node, output=node,
# param_attr=attr) param_attr=attr)
elif val_shape.dtype == 'int64': elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast' val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
...@@ -1146,11 +1143,7 @@ class OpSet9(): ...@@ -1146,11 +1143,7 @@ class OpSet9():
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y} inputs = {"x": val_x, "y": val_y}
#node.fluid_code.add_layer( if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
# "matmul", inputs=inputs, output=node, param_attr=None)
#return
#if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
if val_x.layer_name == 'x2paddle_592':
y_squeeze = val_y.layer_name + '_squeeze' y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer( node.fluid_code.add_layer(
"squeeze", "squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册