提交 36f461ee 编写于 作者: C Channingss

update

上级 7c5fb3dd
...@@ -617,12 +617,11 @@ class OpSet9(): ...@@ -617,12 +617,11 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance( if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode): val_x, ONNXGraphDataNode):
if indices.dtype != 'int64': node.fluid_code.add_layer(
node.fluid_code.add_layer( 'cast',
'cast', inputs=indices,
inputs=indices, output=indices,
output=indices, param_attr={'dtype': string('int64')})
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'embedding', 'embedding',
inputs=indices, inputs=indices,
...@@ -634,7 +633,6 @@ class OpSet9(): ...@@ -634,7 +633,6 @@ class OpSet9():
}) })
else: else:
from functools import reduce from functools import reduce
#indices_shape = [1,7]
reshape_shape = reduce(lambda x, y: x * y, indices_shape) reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape' indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer( node.fluid_code.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册