提交 36f461ee 编写于 作者: C Channingss

update

上级 7c5fb3dd
...@@ -617,7 +617,6 @@ class OpSet9(): ...@@ -617,7 +617,6 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance( if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode): val_x, ONNXGraphDataNode):
if indices.dtype != 'int64':
node.fluid_code.add_layer( node.fluid_code.add_layer(
'cast', 'cast',
inputs=indices, inputs=indices,
...@@ -634,7 +633,6 @@ class OpSet9(): ...@@ -634,7 +633,6 @@ class OpSet9():
}) })
else: else:
from functools import reduce from functools import reduce
#indices_shape = [1,7]
reshape_shape = reduce(lambda x, y: x * y, indices_shape) reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape' indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer( node.fluid_code.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册