提交 a6c820c7 编写于 作者: W wjj19950828

fixed for ci

上级 a229bff2
...@@ -911,6 +911,13 @@ class OpSet9(): ...@@ -911,6 +911,13 @@ class OpSet9():
'index': indices.name}, 'index': indices.name},
outputs=[node.name], outputs=[node.name],
axis=axis) axis=axis)
# deal with indice is scalar(0D) Tensor
if isinstance(indices_values, int) and len(val_x_shape) > 1:
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': node.name},
outputs=[node.name],
axis=[axis])
else: else:
# if val_x is DataNode, convert gather to embedding # if val_x is DataNode, convert gather to embedding
if axis == 0 and isinstance(val_x, ONNXGraphDataNode): if axis == 0 and isinstance(val_x, ONNXGraphDataNode):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册