diff --git a/models/recall/word2vec/model.py b/models/recall/word2vec/model.py index 4b19513d627722310ee4fb4d39fbdb9b80b327bf..8775e08bf094b22ff9070d5d7f6b08cab9a3d873 100755 --- a/models/recall/word2vec/model.py +++ b/models/recall/word2vec/model.py @@ -209,7 +209,7 @@ class Model(ModelBase): emb_all_label_l2 = fluid.layers.l2_normalize(x=emb_all_label, axis=1) dist = fluid.layers.matmul( x=target, y=emb_all_label_l2, transpose_y=True) - values, pred_idx = fluid.layers.topk(input=dist, 1) + values, pred_idx = fluid.layers.topk(input=dist, k=1) label = fluid.layers.expand( fluid.layers.unsqueeze( inputs[3], axes=[1]), expand_times=[1, 1])