提交 36f461ee 编写于 作者: C Channingss

update

上级 7c5fb3dd
......@@ -617,12 +617,11 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
if indices.dtype != 'int64':
node.fluid_code.add_layer(
'cast',
inputs=indices,
output=indices,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'cast',
inputs=indices,
output=indices,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer(
'embedding',
inputs=indices,
......@@ -634,7 +633,6 @@ class OpSet9():
})
else:
from functools import reduce
#indices_shape = [1,7]
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.layer_name + '_shape'
node.fluid_code.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册