提交 9e7e3739 编写于 作者: C Channingss

fix bug: Reshape attr:shape may come from Gather by a scalar indices

上级 9b8cd312
...@@ -814,6 +814,8 @@ class OpSet9(): ...@@ -814,6 +814,8 @@ class OpSet9():
inputs=val_shape, inputs=val_shape,
output=val_shape_cast, output=val_shape_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs=val_shape_cast, inputs=val_shape_cast,
...@@ -826,6 +828,8 @@ class OpSet9(): ...@@ -826,6 +828,8 @@ class OpSet9():
output=node, output=node,
param_attr=attr) param_attr=attr)
else: else:
# shape may be [], come form Gather by scalar indices
if len(val_shape.out_shapes[0]) > 0:
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs=val_shape, inputs=val_shape,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册