提交 06c95839 编写于 作者: C Channingss

merge paddle/develop

上级 882a1abb
......@@ -454,9 +454,7 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
attr = {'axes': axes, 'name': string(node.layer_name)}
node.fluid_code.add_layer(
'unsqueeze', inputs=val_x, output=node, param_attr=attr)
if len(val_x.out_shapes[0]) == 0 and node.layer_name != 'x2paddle_465':
if len(val_x.out_shapes[0]) == 0:
if node.layer_name:
node.fluid_code.add_layer(
'reshape',
......@@ -620,8 +618,6 @@ class OpSet9():
param_attr=None)
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
if val_x.layer_name == 'x2paddle_460':
perm = list(range(3))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
......@@ -936,20 +932,21 @@ class OpSet9():
attr = {}
shape_value = _const_weight_or_none(val_shape)
shape_dims = len(val_shape.out_shapes[0])
if shape_value is not None:
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x},
output=node,
param_attr={'shape': shape_value.tolist()})
#elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
# 0]):
# node.fluid_code.add_layer(
# 'reshape',
# inputs={'x': val_x,
# 'shape': node.out_shapes[0]},
# output=node,
# param_attr=attr)
elif len(node.out_shapes[0]) > 0 and _is_static_shape(node.out_shapes[
0]):
node.fluid_code.add_layer(
'reshape',
inputs={'x': val_x,
'shape': node.out_shapes[0]},
output=node,
param_attr=attr)
elif val_shape.dtype == 'int64':
val_shape_cast = val_shape.layer_name + '_cast'
node.fluid_code.add_layer(
......@@ -1146,11 +1143,7 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs = {"x": val_x, "y": val_y}
#node.fluid_code.add_layer(
# "matmul", inputs=inputs, output=node, param_attr=None)
#return
#if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
if val_x.layer_name == 'x2paddle_592':
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.layer_name + '_squeeze'
node.fluid_code.add_layer(
"squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册