提交 bc5d32fe 编写于 作者: C Channingss

fix bug

上级 36f461ee
......@@ -672,7 +672,7 @@ class OpSet9():
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
attr_trans = {'perm': perm}
name_trans = val_x.layer_name + '_trans'
name_trans = val_x.layer_name + '_transpose'
node.fluid_code.add_layer(
'transpose',
inputs=val_x,
......@@ -684,8 +684,12 @@ class OpSet9():
'index': indices_reshape},
output=node,
param_attr=None)
input_transpose = node.layer_name + '_transpose'
node.fluid_code.add_layer(
'transpose', inputs=node, output=node, param_attr=attr_trans)
'transpose',
inputs=node,
output=input_transpose,
param_attr=attr_trans)
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
......@@ -694,7 +698,7 @@ class OpSet9():
reshaped_shape.append(i)
node.fluid_code.add_layer(
'reshape',
inputs=node,
inputs=input_transpose,
output=node,
param_attr={'shape': reshaped_shape})
......@@ -748,17 +752,21 @@ class OpSet9():
}
else:
if starts.dtype != 'int32':
starts_cast = starts.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=starts,
output=starts,
output=starts_cast,
param_attr={'dtype': string('int32')})
attr['starts'] = starts_cast
if ends.dtype != 'int32':
ends_cast = ens.layer_name + '_cast'
node.fluid_code.add_layer(
'cast',
inputs=ends,
output=ends,
output=ends_cast,
param_attr={'dtype': string('int32')})
attr['ends'] = ends_cast
else:
starts = node.get_attr('starts')
ends = node.get_attr('ends')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册