提交 f3e37fcf 编写于 作者: C Channingss

add cast(int64) for embedding

上级 1eaaf5f6
...@@ -627,6 +627,12 @@ class OpSet9(): ...@@ -627,6 +627,12 @@ class OpSet9():
elif axis == 0 and len(indices_shape) > 1: elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance( if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode): val_x, ONNXGraphDataNode):
if indices.dtype != 'int64':
node.fluid_code.add_layer(
'cast',
inputs=indices,
output=indices,
param_attr={'dtype': string('int64')})
node.fluid_code.add_layer( node.fluid_code.add_layer(
'embedding', 'embedding',
inputs=indices, inputs=indices,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册