提交 bc5d32fe 编写于 作者: C Channingss

fix bug

上级 36f461ee
...@@ -672,7 +672,7 @@ class OpSet9(): ...@@ -672,7 +672,7 @@ class OpSet9():
perm = list(range(len(val_x.out_shapes[0]))) perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:] perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm} attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans' name_trans = val_x.layer_name + '_transpose'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'transpose', 'transpose',
inputs=val_x, inputs=val_x,
...@@ -684,8 +684,12 @@ class OpSet9(): ...@@ -684,8 +684,12 @@ class OpSet9():
'index': indices_reshape}, 'index': indices_reshape},
output=node, output=node,
param_attr=None) param_attr=None)
input_transpose = node.layer_name + '_transpose'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans) 'transpose',
inputs=node,
output=input_transpose,
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0] val_x_shape = val_x.out_shapes[0]
reshaped_shape = [] reshaped_shape = []
for i in perm: for i in perm:
...@@ -694,7 +698,7 @@ class OpSet9(): ...@@ -694,7 +698,7 @@ class OpSet9():
reshaped_shape.append(i) reshaped_shape.append(i)
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs=node, inputs=input_transpose,
output=node, output=node,
param_attr={'shape': reshaped_shape}) param_attr={'shape': reshaped_shape})
...@@ -748,17 +752,21 @@ class OpSet9(): ...@@ -748,17 +752,21 @@ class OpSet9():
} }
else: else:
if starts.dtype != 'int32': if starts.dtype != 'int32':
starts_cast = starts.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=starts, inputs=starts,
output=starts, output=starts_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
attr['starts'] = starts_cast
if ends.dtype != 'int32': if ends.dtype != 'int32':
ends_cast = ens.layer_name + '_cast'
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=ends, inputs=ends,
output=ends, output=ends_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
attr['ends'] = ends_cast
else: else:
starts = node.get_attr('starts') starts = node.get_attr('starts')
ends = node.get_attr('ends') ends = node.get_attr('ends')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册