提交 8d205c85 编写于 作者: Q Qiao Longfei

add is_test for lookup_sparse_table

上级 9a6e2392
......@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
int64_t size_;
};
struct TensorFillVisitor {
TensorFillVisitor(framework::Tensor* dst, int64_t dst_offset, int64_t size,
float value)
: dst_(dst), dst_offset_(dst_offset), size_(size) {}
template <typename T>
void apply() const {
// TODO(Yancey1989): support other place
platform::CPUPlace cpu;
auto* tensor_data = dst_->mutable_data<T>(cpu);
auto* start = tensor_data + dst_offset_;
auto* end = start + size_;
std::fill(start, end, static_cast<T>(0.0));
}
framework::Tensor* dst_;
int64_t dst_offset_;
int64_t size_;
};
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version
......@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
: true;
}
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) {
int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown,
bool is_test) {
if (is_test) {
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
return -1;
} else {
return iter->second;
}
}
rwlock_->RDLock();
auto iter = id_to_index_.find(key);
if (iter == id_to_index_.end()) {
......@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
}
void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown) {
bool auto_grown, bool is_test) {
PADDLE_ENFORCE(value->IsInitialized(),
"The value tensor should be initialized.");
if (ids.numel() == 0) {
......@@ -183,11 +213,18 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
"output tensor should have the same shape with table "
"except the dims[0].");
for (int i = 0; i < ids.numel(); ++i) {
int64_t index = AutoGrownIndex(ids.data<int64_t>()[i], auto_grown);
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
int64_t index =
AutoGrownIndex(ids.data<int64_t>()[i], auto_grown, is_test);
if (index < 0) {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorFillVisitor(value, i * value_width, value_width, 0.0));
} else {
framework::VisitDataType(
framework::ToDataType(value_->type()),
TensorCopyVisitor(value, i * value_width, *value_.get(),
index * value_width, value_width));
}
}
}
}
......
......@@ -105,7 +105,7 @@ class SelectedRows {
* the value
*/
void Get(const framework::Tensor& ids, framework::Tensor* value,
bool auto_grown = false);
bool auto_grown = false, bool is_test = false);
/*
* @brief Get the index of the key from id_to_index_ map. If the key not
......@@ -118,7 +118,7 @@ class SelectedRows {
*
* @return index of the key.
*/
int64_t AutoGrownIndex(int64_t key, bool auto_grown);
int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
void SyncIndex();
......
......@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
data[i * embedding_width + j] = static_cast<float>(i);
}
}
ASSERT_EQ(table.AutoGrownIndex(10, true), 0);
ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
ASSERT_EQ(table.AutoGrownIndex(8, true), 1);
ASSERT_EQ(table.AutoGrownIndex(6, true), 2);
ASSERT_EQ(table.AutoGrownIndex(10, true, false), 0);
ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
ASSERT_EQ(table.AutoGrownIndex(6, true, false), 2);
for (int64_t i = 11; i < 20; i++) {
ASSERT_EQ(table.AutoGrownIndex(i, true, true), -1);
ASSERT_TRUE(!table.HasKey(i));
}
ASSERT_TRUE(table.HasKey(10));
ASSERT_TRUE(table.HasKey(8));
ASSERT_TRUE(table.HasKey(6));
......
......@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto out_var = scope.FindVar(Output("Out"));
auto w_var = scope.FindVar(Input("W"));
auto ids_var = scope.FindVar(Input("Ids"));
auto is_test = Attr<bool>("is_test");
PADDLE_ENFORCE(out_var->IsType<framework::LoDTensor>(),
"The type of Out var should be LodTensor.");
......@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()),
framework::proto::VarType::FP32,
"The sparse table only support FP32");
w_t->Get(ids_t, out_t, true);
w_t->Get(ids_t, out_t, true, is_test);
}
};
......@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool default false)"
"Whether create new value if for nonexistent key.")
.SetDefault(true);
AddAttr<bool>("is_test",
"In test mode, lookup_sparse_table will "
"return a default value for unknown id")
.SetDefault(false);
AddComment(R"DOC(
Lookup Sprase Tablel Operator.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册