Commit 8d205c85 authored by Qiao Longfei

add is_test for lookup_sparse_table

Parent 9a6e2392
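The diff below threads an is_test flag from the lookup_sparse_table operator down to SelectedRows::Get and SelectedRows::AutoGrownIndex: in test mode an unknown id is no longer grown into the table; it is reported as missing (index -1) and its output row is zero-filled. As a quick illustration, here is a minimal gtest-style sketch (not part of this commit; the table setup mirrors the SparseTable test further down):

// Illustrative sketch only; it mirrors the setup in the SelectedRows
// SparseTable test and exercises the new is_test argument.
#include "gtest/gtest.h"

#include "paddle/fluid/framework/selected_rows.h"

namespace paddle {
namespace framework {

TEST(SelectedRows, IsTestLookupSketch) {
  platform::CPUPlace cpu;
  SelectedRows table;

  // Back the sparse table with a 100 x 8 FP32 value tensor.
  table.mutable_value()->Resize(framework::make_ddim({100, 8}));
  table.mutable_value()->mutable_data<float>(cpu);

  // Training path (is_test == false): an unknown key is grown into the
  // table and gets the next free row index.
  ASSERT_EQ(table.AutoGrownIndex(10, /*auto_grown=*/true, /*is_test=*/false),
            0);
  ASSERT_TRUE(table.HasKey(10));

  // Test path (is_test == true): an unknown key is NOT grown; -1 marks it
  // as missing, and SelectedRows::Get() zero-fills its output row.
  ASSERT_EQ(table.AutoGrownIndex(11, /*auto_grown=*/true, /*is_test=*/true),
            -1);
  ASSERT_TRUE(!table.HasKey(11));
}

}  // namespace framework
}  // namespace paddle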
@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
   int64_t size_;
 };
 
+struct TensorFillVisitor {
+  TensorFillVisitor(framework::Tensor* dst, int64_t dst_offset, int64_t size,
+                    float value)
+      : dst_(dst), dst_offset_(dst_offset), size_(size) {}
+
+  template <typename T>
+  void apply() const {
+    // TODO(Yancey1989): support other place
+    platform::CPUPlace cpu;
+    auto* tensor_data = dst_->mutable_data<T>(cpu);
+    auto* start = tensor_data + dst_offset_;
+    auto* end = start + size_;
+    std::fill(start, end, static_cast<T>(0.0));
+  }
+
+  framework::Tensor* dst_;
+  int64_t dst_offset_;
+  int64_t size_;
+};
+
 void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
                        const platform::DeviceContext& dev_ctx) {
   {  // the 1st field, uint32_t version
@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
              : true;
 }
 
-int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
+int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
+                                     bool is_test) {
+  if (is_test) {
+    auto iter = id_to_index_.find(key);
+    if (iter == id_to_index_.end()) {
+      return -1;
+    } else {
+      return iter->second;
+    }
+  }
+
   rwlock_->RDLock();
   auto iter = id_to_index_.find(key);
   if (iter == id_to_index_.end()) {
@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
 }
 
 void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
-                       bool auto_grown) {
+                       bool auto_grown, bool is_test) {
   PADDLE_ENFORCE(value->IsInitialized(),
                  "The value tensor should be initialized.");
   if (ids.numel() == 0) {
@@ -183,13 +213,20 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
                       "output tensor should have the same shape with table "
                       "except the dims[0].");
     for (int i = 0; i < ids.numel(); ++i) {
-      int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
-      framework::VisitDataType(
-          framework::ToDataType(value_->type()),
-          TensorCopyVisitor(value, i * value_width, *value_.get(),
-                            index * value_width, value_width));
+      int64_t index =
+          AutoGrownIndex(ids.data<int64_t>()[i], auto_grown, is_test);
+      if (index < 0) {
+        framework::VisitDataType(
+            framework::ToDataType(value_->type()),
+            TensorFillVisitor(value, i * value_width, value_width, 0.0));
+      } else {
+        framework::VisitDataType(
+            framework::ToDataType(value_->type()),
+            TensorCopyVisitor(value, i * value_width, *value_.get(),
+                              index * value_width, value_width));
+      }
     }
   }
 }
 
 }  // namespace framework
......
@@ -105,7 +105,7 @@ class SelectedRows {
    * the value
    */
   void Get(const framework::Tensor& ids, framework::Tensor* value,
-           bool auto_grown = false);
+           bool auto_grown = false, bool is_test = false);
 
   /*
    * @brief Get the index of the key from id_to_index_ map. If the key not
@@ -118,7 +118,7 @@
    *
    * @return index of the key.
    */
-  int64_t AutoGrownIndex(int64_t key, bool auto_grown);
+  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
 
   void SyncIndex();
......
@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
       data[i * embedding_width + j] = static_cast<float>(i);
     }
   }
-  ASSERT_EQ(table.AutoGrownIndex(10, true), 0);
-  ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
-  ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
-  ASSERT_EQ(table.AutoGrownIndex(6, true), 2);
+  ASSERT_EQ(table.AutoGrownIndex(10, true, false), 0);
+  ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
+  ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
+  ASSERT_EQ(table.AutoGrownIndex(6, true, false), 2);
+  for (int64_t i = 11; i < 20; i++) {
+    ASSERT_EQ(table.AutoGrownIndex(i, true, true), -1);
+    ASSERT_TRUE(!table.HasKey(i));
+  }
   ASSERT_TRUE(table.HasKey(10));
   ASSERT_TRUE(table.HasKey(8));
   ASSERT_TRUE(table.HasKey(6));
......
@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
     auto out_var = scope.FindVar(Output("Out"));
     auto w_var = scope.FindVar(Input("W"));
     auto ids_var = scope.FindVar(Input("Ids"));
+    auto is_test = Attr<bool>("is_test");
 
     PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
                    "The type of Out var should be LodTensor.");
@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
     PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
                       framework::proto::VarType::FP32,
                       "The sparse table only support FP32");
-    w_t->Get(ids_t, out_t, true);
+    w_t->Get(ids_t, out_t, true, is_test);
   }
 };
 
@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
              "(bool default false)"
             "Whether create new value if for nonexistent key.")
         .SetDefault(true);
+    AddAttr<bool>("is_test",
+                  "In test mode, lookup_sparse_table will "
+                  "return a default value for unknown id")
+        .SetDefault(false);
     AddComment(R"DOC(
 Lookup Sprase Tablel Operator.
......
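The operator itself only reads the new attribute and forwards it to SelectedRows::Get, so the end-to-end effect can be sketched directly against the SelectedRows API. A rough sketch follows (illustrative only; the sizes and ids are assumptions for the example, not taken from the diff):

// Illustrative sketch, not part of the commit: it shows the effect of
// forwarding is_test to SelectedRows::Get().
#include "paddle/fluid/framework/selected_rows.h"

namespace paddle {
namespace framework {

void IsTestGetSketch() {
  platform::CPUPlace cpu;

  SelectedRows table;
  table.mutable_value()->Resize(framework::make_ddim({100, 8}));
  table.mutable_value()->mutable_data<float>(cpu);

  // Grow one id during "training".
  table.AutoGrownIndex(10, /*auto_grown=*/true, /*is_test=*/false);

  // Look up one known id and one unknown id.
  Tensor ids;
  ids.Resize(framework::make_ddim({2}));
  auto* ids_data = ids.mutable_data<int64_t>(cpu);
  ids_data[0] = 10;   // present in the table
  ids_data[1] = 999;  // absent from the table

  Tensor out;
  out.Resize(framework::make_ddim({2, 8}));
  out.mutable_data<float>(cpu);

  // With is_test == true the table is left untouched: row 0 of `out` is
  // copied from the table via TensorCopyVisitor, while row 1 is zero-filled
  // by the new TensorFillVisitor because id 999 resolves to index -1.
  table.Get(ids, &out, /*auto_grown=*/true, /*is_test=*/true);
}

}  // namespace framework
}  // namespace paddle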