未验证 提交 aa25d5d4 编写于 作者: J Jason 提交者: GitHub

Merge pull request #332 from Channingss/reshape_bug

fix bug: Reshape attr:shape may come from Gather by a scalar indices
...@@ -814,11 +814,13 @@ class OpSet9(): ...@@ -814,11 +814,13 @@ class OpSet9():
inputs=val_shape, inputs=val_shape,
output=val_shape_cast, output=val_shape_cast,
param_attr={'dtype': string('int32')}) param_attr={'dtype': string('int32')})
node.fluid_code.add_layer( # shape may be [], come form Gather by scalar indices
'reshape', if len(val_shape.out_shapes[0]) > 0:
inputs=val_shape_cast, node.fluid_code.add_layer(
output=val_shape_cast, 'reshape',
param_attr={'shape': val_shape.out_shapes[0]}) inputs=val_shape_cast,
output=val_shape_cast,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs={'x': val_x, inputs={'x': val_x,
...@@ -826,11 +828,13 @@ class OpSet9(): ...@@ -826,11 +828,13 @@ class OpSet9():
output=node, output=node,
param_attr=attr) param_attr=attr)
else: else:
node.fluid_code.add_layer( # shape may be [], come form Gather by scalar indices
'reshape', if len(val_shape.out_shapes[0]) > 0:
inputs=val_shape, node.fluid_code.add_layer(
output=val_shape, 'reshape',
param_attr={'shape': val_shape.out_shapes[0]}) inputs=val_shape,
output=val_shape,
param_attr={'shape': val_shape.out_shapes[0]})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'reshape', 'reshape',
inputs={'x': val_x, inputs={'x': val_x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册