diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index 249896993b5961bbc45e33edd12b563fddfe83da..88fa59c5fb99adb058c4d6e1c2cc6ac519470942 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -49,6 +49,7 @@ class LookupSparseTableOp : public framework::OperatorBase { unsigned int seed = static_cast(Attr("seed")); float min = Attr("min"); float max = Attr("max"); + bool auto_grown_table = Attr("auto_grown_table"); PADDLE_ENFORCE(out_var->IsType(), "The type of Out var should be LodTensor."); @@ -71,8 +72,11 @@ class LookupSparseTableOp : public framework::OperatorBase { PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()), framework::proto::VarType::FP32, "The sparse table only support FP32"); - auto non_keys_pair = w_t->Get(keys, out_t); + if (!auto_grown_table) { + PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast(0), + "there is some keys does exists in the sparse table."); + } auto value_shape = w_t->value().dims(); value_shape[0] = 1; for (const auto &it : non_keys_pair) { @@ -130,6 +134,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker { "Note that if seed is not 0, this operator will always " "generate the same random numbers every time.") .SetDefault(0); + AddAttr("auto_grown_table", + "(bool default false)" + "Whether create new value if for nonexistent key.") + .SetDefault(true); AddComment(R"DOC( Lookup Sprase Tablel Operator.