提交 38190066 编写于 作者: C Channingss

modify embeding weight

上级 a955e6d0
...@@ -721,11 +721,11 @@ class OpSet9(): ...@@ -721,11 +721,11 @@ class OpSet9():
op_name = name_generator("embedding", self.nn_name2id) op_name = name_generator("embedding", self.nn_name2id)
output_name = node.name output_name = node.name
layer_outputs = [op_name, output_name] layer_outputs = [op_name, output_name]
self.weights['.weight'] = _const_weight_or_none(val_x.name)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.nn.Embedding', 'paddle.nn.Embedding',
inputs={"x": indices_cast}, inputs={"x": indices_cast},
outputs=layer_outputs, outputs=layer_outputs,
weight_attr=string(val_x.name),
num_embeddings=val_x.out_shapes[0][0], num_embeddings=val_x.out_shapes[0][0],
embedding_dim=val_x.out_shapes[0][1]) embedding_dim=val_x.out_shapes[0][1])
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册