diff --git a/paddle/pten/core/CMakeLists.txt b/paddle/pten/core/CMakeLists.txt
index 7c8ace2bc7ef4289b05f9fa64f424bb3bd4dc6fd..f6f0e1f3e26ecc56e0d48bad0a22b8054f0664e2 100644
--- a/paddle/pten/core/CMakeLists.txt
+++ b/paddle/pten/core/CMakeLists.txt
@@ -16,7 +16,6 @@ cc_library(lod_utils SRCS lod_utils.cc DEPS enforce mixed_vector)
 cc_library(dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base)
 cc_library(pten_device_context SRCS device_context.cc DEPS tensor_base )
 
-
 cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_tensor)
 
 cc_test(unroll_array_ops_test SRCS unroll_array_ops_test.cc)
@@ -28,6 +27,8 @@ elseif(WITH_ROCM)
   hip_test(dim_test SRCS dim_test.cu DEPS ddim)
 endif()
 
+cc_library(selected_rows SRCS selected_rows.cc DEPS dense_tensor mixed_vector enforce ddim)
+
 # Will remove once we implemented MKLDNN_Tensor
 if(WITH_MKLDNN)
   add_dependencies(dense_tensor mkldnn)
diff --git a/paddle/pten/core/selected_rows.cc b/paddle/pten/core/selected_rows.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6f64602bdcf4d9f70d57a76677a1796b373808ac
--- /dev/null
+++ b/paddle/pten/core/selected_rows.cc
@@ -0,0 +1,228 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/selected_rows.h"
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/framework/data_type.h"
+
+namespace pten {
+
+// Type-dispatched functor: allocates a CPU buffer of shape `dims`, copies the
+// current elements of `tensor` into it (if it holds any memory) and makes
+// `tensor` share the new buffer.
+struct ReAllocateVisitor {
+  ReAllocateVisitor(const pten::framework::DDim& dims,
+                    pten::DenseTensor* tensor)
+      : dims_(dims), tensor_(tensor) {}
+
+  template <typename T>
+  void operator()() const {
+    pten::DenseTensor cpu_tensor;
+    paddle::platform::CPUPlace cpu;
+    T* ptr = cpu_tensor.mutable_data<T>(dims_, cpu);
+    const T* old_ptr =
+        tensor_->memory_size() == 0 ? nullptr : tensor_->data<T>();
+    if (old_ptr != nullptr) {
+      std::copy(old_ptr, old_ptr + tensor_->numel(), ptr);
+    }
+    tensor_->ShareDataWith(cpu_tensor);
+  }
+
+  pten::framework::DDim dims_;
+  pten::DenseTensor* tensor_;
+};
+
+// Type-dispatched functor: copies `size` elements from `src`, starting at
+// element `src_offset`, into `dst`, starting at element `dst_offset`.
+struct TensorCopyVisitor {
+  TensorCopyVisitor(pten::DenseTensor* dst,
+                    int64_t dst_offset,
+                    const pten::DenseTensor src,
+                    int64_t src_offset,
+                    int64_t size)
+      : dst_(dst),
+        dst_offset_(dst_offset),
+        src_(src),
+        src_offset_(src_offset),
+        size_(size) {}
+
+  template <typename T>
+  void apply() const {
+    // TODO(Yancey1989): support other place
+    paddle::platform::CPUPlace cpu;
+    paddle::memory::Copy(cpu,
+                         dst_->mutable_data<T>(cpu) + dst_offset_,
+                         cpu,
+                         src_.data<T>() + src_offset_,
+                         size_ * sizeof(T));
+  }
+
+  pten::DenseTensor* dst_;
+  int64_t dst_offset_;
+  pten::DenseTensor src_;
+  int64_t src_offset_;
+  int64_t size_;
+};
+
+// Type-dispatched functor: writes `value` into `size` consecutive elements of
+// `dst`, starting at element `dst_offset`.
+struct TensorFillVisitor {
+  TensorFillVisitor(pten::DenseTensor* dst,
+                    int64_t dst_offset,
+                    int64_t size,
+                    float value)
+      : dst_(dst), dst_offset_(dst_offset), size_(size), value_(value) {}
+
+  template <typename T>
+  void apply() const {
+    // TODO(qiao): support other place
+    paddle::platform::CPUPlace cpu;
+    auto* tensor_data = dst_->mutable_data<T>(cpu);
+    auto* start = tensor_data + dst_offset_;
+    auto* end = start + size_;
+    // Use the requested fill value; it used to be dropped by the constructor
+    // and the range was always filled with zero.
+    std::fill(start, end, static_cast<T>(value_));
+  }
+
+  pten::DenseTensor* dst_;
+  int64_t dst_offset_;
+  int64_t size_;
+  float value_;
+};
+
+// Linear search over rows_; id_to_index_ is not consulted.
+bool SelectedRows::HasKey(int64_t key) const {
+  return std::find(rows_.begin(), rows_.end(), key) != rows_.end();
+}
+
+// Returns the index registered for `key` in id_to_index_.
+//  - is_test: the map is read without locking; -1 is returned when the key
+//    is missing.
+//  - otherwise: a missing key raises NotFound unless `auto_grown` is true, in
+//    which case the key is appended to rows_ and registered in id_to_index_.
+// Locks are held through the scoped guards from rw_lock.h so that they are
+// released on every exit path, including exceptions.
+int64_t SelectedRows::AutoGrownIndex(int64_t key,
+                                     bool auto_grown,
+                                     bool is_test) {
+  if (is_test) {
+    auto iter = id_to_index_.find(key);
+    return iter == id_to_index_.end() ? -1 : iter->second;
+  }
+
+  {
+    // Lookup under the read lock first; most calls end here.
+    AutoRDLock read_guard(rwlock_.get());
+    auto iter = id_to_index_.find(key);
+    if (iter != id_to_index_.end()) {
+      return iter->second;
+    }
+  }
+
+  PADDLE_ENFORCE_EQ(auto_grown,
+                    true,
+                    paddle::platform::errors::NotFound(
+                        "Input key(%lld) is not found.", key));
+
+  AutoWRLock write_guard(rwlock_.get());
+  auto map_size = id_to_index_.size();
+  auto vector_size = rows_.size();
+  if (map_size != vector_size) {
+    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+        "Row map size(%zu) should be equal to rows size(%zu).",
+        map_size,
+        vector_size));
+  }
+
+  // The read lock was dropped before the write lock was taken, so another
+  // thread may have registered the key in between: look it up again.
+  auto write_iter = id_to_index_.find(key);
+  if (write_iter != id_to_index_.end()) {
+    return write_iter->second;
+  }
+
+  const auto row_num = static_cast<int64_t>(rows_.size());
+  if (row_num == value_->dims()[0]) {
+    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+        "Selected rows is full, then length exceed the length of first "
+        "dimension (%d).",
+        row_num));
+  }
+  // key logic to put a key into id_to_index_
+  rows_.push_back(key);
+  auto index = static_cast<int64_t>(rows_.size() - 1);
+  id_to_index_[key] = index;
+  return index;
+}
+
+// Rebuilds id_to_index_ from rows_ under the write lock. When rows_ holds a
+// row id more than once, the last position is the one kept in the map.
+void SelectedRows::SyncIndex() {
+  AutoWRLock write_guard(rwlock_.get());
+  id_to_index_.clear();
+  for (size_t i = 0; i < rows_.size(); ++i) {
+    id_to_index_[rows_[i]] = i;
+  }
+}
+
+// Copies one table row per entry of `ids` into `value`. Entries for which
+// AutoGrownIndex() returns a negative index are filled with zero instead.
+void SelectedRows::Get(const pten::DenseTensor& ids,
+                       pten::DenseTensor* value,
+                       bool auto_grown,
+                       bool is_test) {
+  PADDLE_ENFORCE_EQ(value->IsInitialized(),
+                    true,
+                    paddle::platform::errors::InvalidArgument(
+                        "The value tensor is not initialized."));
+  if (ids.numel() == 0) {
+    VLOG(3) << "keys is empty, please check data!";
+    return;
+  }
+
+  const int64_t value_width = value_->numel() / value_->dims()[0];
+  PADDLE_ENFORCE_EQ(
+      value_width,
+      value->numel() / value->dims()[0],
+      paddle::platform::errors::InvalidArgument(
+          "Output tensor should have the same shape with table "
+          "except the first dimmension, excepted value width not counting "
+          "the first dimension is %d, actual value width is %d.",
+          value_width,
+          value->numel() / value->dims()[0]));
+
+  const int64_t* ids_data = ids.data<int64_t>();
+  const int64_t ids_num = ids.numel();
+  // 64-bit counter, like the other element counts/offsets in this function.
+  for (int64_t i = 0; i < ids_num; ++i) {
+    const int64_t id = ids_data[i];
+    const int64_t index = AutoGrownIndex(id, auto_grown, is_test);
+    if (index < 0) {
+      VLOG(5) << "id " << id << " not in the table, return 0";
+      paddle::framework::VisitDataType(
+          value_->type(),
+          TensorFillVisitor(value, i * value_width, value_width, 0.0));
+    } else {
+      paddle::framework::VisitDataType(value_->type(),
+                                       TensorCopyVisitor(value,
+                                                         i * value_width,
+                                                         *value_.get(),
+                                                         index * value_width,
+                                                         value_width));
+    }
+  }
+}
+}  // namespace pten
diff --git a/paddle/pten/core/selected_rows.h b/paddle/pten/core/selected_rows.h
new file mode 100644
index 0000000000000000000000000000000000000000..f5be0a906dbdbb5339f995430a95a4be106a4a62
--- /dev/null
+++ b/paddle/pten/core/selected_rows.h
@@ -0,0 +1,164 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#pragma once
+
+#include <algorithm>
+#include <memory>
+#include <mutex>  // NOLINT
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "paddle/pten/common/place.h"
+#include "paddle/pten/core/ddim.h"
+#include "paddle/pten/core/dense_tensor.h"
+#include "paddle/pten/core/utils/rw_lock.h"
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/framework/mixed_vector.h"
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace pten {
+class SelectedRows {
+  /*
+   * @brief We can use the SelectedRows structure to reproduce a sparse table.
+   *  A sparse table is a key-value structure whose key is an `int64_t` row id
+   *  kept in rows_ and whose value is the matching row of the value tensor.
+   *  You can use the following interfaces to operate the sparse table; more
+   *  detail is given in the comment of each interface:
+   *
+   *  HasKey(key), whether the sparse table has the specified key.
+   *  Index(key), position of the key in rows(); throws if it is absent.
+   *  AutoGrownIndex(key, ...), look a key up and optionally register it.
+   *  Get(keys, value*), copy the rows of the given keys into the given value
+   *  tensor, one row per key.
+   */
+ public:
+  SelectedRows(const std::vector<int64_t>& rows, const int64_t& height)
+      : rows_(rows),
+        value_(new pten::DenseTensor()),
+        height_(height),
+        rwlock_(new RWLock) {}
+
+  SelectedRows()
+      : value_(new pten::DenseTensor()), height_(0), rwlock_(new RWLock) {}
+
+  const pten::Place& place() const { return value_->place(); }
+
+  const pten::DenseTensor& value() const { return *value_; }
+
+  pten::DenseTensor* mutable_value() { return value_.get(); }
+
+  int64_t height() const { return height_; }
+
+  void set_height(int64_t height) { height_ = height; }
+
+  const paddle::framework::Vector<int64_t>& rows() const { return rows_; }
+
+  paddle::framework::Vector<int64_t>* mutable_rows() { return &rows_; }
+
+  void set_rows(const paddle::framework::Vector<int64_t>& rows) {
+    rows_ = rows;
+  }
+
+  /*
+   * @brief Get the index of key in rows (linear search over rows_).
+   *
+   * @return position of the first occurrence of the key. Throws NotFound if
+   * the key does not exist.
+   */
+  int64_t Index(int64_t key) const {
+    auto it = std::find(rows_.begin(), rows_.end(), key);
+    if (it == rows_.end()) {
+      PADDLE_THROW(paddle::platform::errors::NotFound(
+          "Input id (%lld) is not in current rows table.", key));
+    }
+    return static_cast<int64_t>(std::distance(rows_.begin(), it));
+  }
+
+  /*
+   * @brief whether has the specified key in the table.
+   *
+   * @return true if the key exists.
+   */
+  bool HasKey(int64_t key) const;
+
+  /*
+   * @brief Get value by the key list.
+   * Note!!! this interface is only used when selected_rows is used as
+   * parameters for distribute lookup table.
+   *
+   * @param ids         int64_t keys to look up.
+   * @param value       output tensor; row i receives the row found for ids[i]
+   *                    and is zero-filled when the key has no index.
+   * @param auto_grown  forwarded to AutoGrownIndex.
+   * @param is_test     forwarded to AutoGrownIndex.
+   */
+  void Get(const pten::DenseTensor& ids,
+           pten::DenseTensor* value,
+           bool auto_grown = false,
+           bool is_test = false);
+
+  /*
+   * @brief Get the index of the key from id_to_index_ map. If the key does
+   * not exist and auto_grown is true, add the key into rows_ and
+   * id_to_index_; if auto_grown is false, throw NotFound. When is_test is
+   * true the map is only read and -1 is returned for a missing key.
+   *
+   * Note!!! this interface is only used when selected_rows is used as
+   * parameters for distribute lookup table.
+   *
+   * @return index of the key.
+   */
+  int64_t AutoGrownIndex(int64_t key, bool auto_grown, bool is_test = false);
+
+  /*
+   * @brief Get the index of the key from id_to_index_ map.
+   *
+   * @return -1 if the key is not in the map.
+   */
+  inline int64_t GetIndexFromId(int64_t key) const {
+    auto iter = id_to_index_.find(key);
+    return iter == id_to_index_.end() ? -1 : iter->second;
+  }
+
+  /*
+   * @brief Rebuild the id_to_index_ map from the current rows_.
+   */
+  void SyncIndex();
+
+  /*
+   * @brief Get value dims with the first dimension replaced by height_.
+   */
+  pten::framework::DDim GetCompleteDims() const {
+    std::vector<int64_t> dims = vectorize(value_->dims());
+    dims[0] = height_;
+    return pten::framework::make_ddim(dims);
+  }
+
+ private:
+  // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
+  // SelectedRows are simply concatenated when adding together. Until a
+  // SelectedRows add a Tensor, will the duplicate rows be handled.
+  paddle::framework::Vector<int64_t> rows_;
+  std::unordered_map<int64_t, int64_t>
+      id_to_index_;  // should not be used when rows_ has duplicate member
+  std::unique_ptr<pten::DenseTensor> value_{nullptr};
+  int64_t height_;  // height indicates the underlying tensor's height
+  std::unique_ptr<RWLock> rwlock_{nullptr};
+};
+
+}  // namespace pten
diff --git a/paddle/pten/core/utils/rw_lock.h b/paddle/pten/core/utils/rw_lock.h
new file mode 100644
index 0000000000000000000000000000000000000000..7bd190c901bc640b657231d937f8b92126b21532
--- /dev/null
+++ b/paddle/pten/core/utils/rw_lock.h
@@ -0,0 +1,105 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#pragma once
+
+#if !defined(_WIN32)
+#include <pthread.h>
+#else
+#include <mutex>  // NOLINT
+#endif  // !_WIN32
+
+// See Note [ Why still include the fluid headers? ]
+#include "paddle/fluid/platform/enforce.h"
+
+namespace pten {
+
+#if !defined(_WIN32)
+// Reader/writer lock backed by pthread_rwlock_t. Copying is disabled because
+// POSIX leaves the use of a copied rwlock object undefined.
+struct RWLock {
+  RWLock() { pthread_rwlock_init(&lock_, nullptr); }
+  ~RWLock() { pthread_rwlock_destroy(&lock_); }
+
+  RWLock(const RWLock&) = delete;
+  RWLock& operator=(const RWLock&) = delete;
+
+  inline void RDLock() {
+    PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_),
+                      0,
+                      paddle::platform::errors::External(
+                          "The pthread failed to acquire read lock."));
+  }
+
+  inline void WRLock() {
+    PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_),
+                      0,
+                      paddle::platform::errors::External(
+                          "The pthread failed to acquire write lock."));
+  }
+
+  inline void UNLock() {
+    PADDLE_ENFORCE_EQ(
+        pthread_rwlock_unlock(&lock_),
+        0,
+        paddle::platform::errors::External("The pthread failed to unlock."));
+  }
+
+ private:
+  pthread_rwlock_t lock_;
+};
+// TODO(paddle-dev): Support RWLock for WIN32 for correctness.
+#else
+// https://stackoverflow.com/questions/7125250/making-pthread-rwlock-wrlock-recursive
+// On Windows RDLock() and WRLock() both take one exclusive std::mutex.
+struct RWLock {
+  // FIXME(minqiyang): use mutex here to do fake lock
+  inline void RDLock() { mutex_.lock(); }
+
+  inline void WRLock() { mutex_.lock(); }
+
+  inline void UNLock() { mutex_.unlock(); }
+
+ private:
+  std::mutex mutex_;
+};
+#endif
+
+// Scoped guard: WRLock() on construction, UNLock() on destruction.
+class AutoWRLock {
+ public:
+  explicit AutoWRLock(RWLock* rw_lock) : lock_(rw_lock) { lock_->WRLock(); }
+  ~AutoWRLock() { lock_->UNLock(); }
+
+  AutoWRLock(const AutoWRLock&) = delete;
+  AutoWRLock& operator=(const AutoWRLock&) = delete;
+
+ private:
+  RWLock* lock_;
+};
+
+// Scoped guard: RDLock() on construction, UNLock() on destruction.
+class AutoRDLock {
+ public:
+  explicit AutoRDLock(RWLock* rw_lock) : lock_(rw_lock) { lock_->RDLock(); }
+  ~AutoRDLock() { lock_->UNLock(); }
+
+  AutoRDLock(const AutoRDLock&) = delete;
+  AutoRDLock& operator=(const AutoRDLock&) = delete;
+
+ private:
+  RWLock* lock_;
+};
+
+}  // namespace pten
diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt
index 117d6a29252c1b286153649daf8c6c775dcb7a2a..363a57f036b9bb39fef9007e80073f83f3045bea 100644
--- a/paddle/pten/tests/core/CMakeLists.txt
+++ b/paddle/pten/tests/core/CMakeLists.txt
@@ -4,3 +4,10 @@ cc_test(test_type_info SRCS test_type_info.cc)
 cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
 cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
 cc_test(test_pten_device_context SRCS test_device_context.cc DEPS pten_context cpu_context)
+cc_test(selected_rows_test SRCS test_selected_rows.cc DEPS selected_rows)
+if(WITH_TESTING AND TEST selected_rows_test)
+    set_tests_properties(selected_rows_test PROPERTIES TIMEOUT 120)
+endif()
+if (NOT WIN32)
+cc_test(test_rw_lock SRCS test_rw_lock.cc)
+endif (NOT WIN32)
diff --git a/paddle/pten/tests/core/test_rw_lock.cc b/paddle/pten/tests/core/test_rw_lock.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5cd81fa76b40e9b88fd19ce342d0a7fc0280ca2c
--- /dev/null
+++ b/paddle/pten/tests/core/test_rw_lock.cc
@@ -0,0 +1,83 @@
+/* Copyright (c) 2022 PaddlePaddle Authors.
All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/pten/core/utils/rw_lock.h"
+
+#include <gtest/gtest.h>  // NOLINT
+#include <thread>         // NOLINT
+
+namespace pten {
+namespace tests {
+
+void f1(pten::RWLock *lock) {  // reader: takes and releases the read lock
+  lock->RDLock();
+  lock->UNLock();
+}
+
+TEST(RWLOCK, read_read) {
+  pten::RWLock lock;
+  lock.RDLock();
+  std::thread t1(f1, &lock);
+  std::thread t2(f1, &lock);
+  t1.join();
+  t2.join();
+  lock.UNLock();
+}
+
+void f2(pten::RWLock *lock, std::vector<int> *result) {
+  lock->RDLock();
+  EXPECT_EQ(result->size(), 0UL);  // non-fatal: the lock must be released
+  lock->UNLock();
+}
+
+void f3(pten::RWLock *lock, std::vector<int> *result) {
+  lock->WRLock();
+  result->push_back(1);
+  lock->UNLock();
+}
+
+TEST(RWLOCK, read_write) {
+  pten::RWLock lock;
+  std::vector<int> result;
+  // Hold the read lock: f2 (reader) can get in, f3 (writer) has to wait.
+  lock.RDLock();
+  std::thread t1(f2, &lock, &result);
+  t1.join();
+  std::thread t2(f3, &lock, &result);
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+  EXPECT_EQ(result.size(), 0UL);  // non-fatal: t2 still has to be joined
+  lock.UNLock();
+  t2.join();
+  ASSERT_EQ(result.size(), 1UL);
+}
+
+void f4(pten::RWLock *lock, std::vector<int> *result) {
+  lock->RDLock();
+  EXPECT_EQ(result->size(), 1UL);  // non-fatal: the lock must be released
+  lock->UNLock();
+}
+
+TEST(RWLOCK, write_read) {
+  pten::RWLock lock;
+  std::vector<int> result;
+  // Hold the write lock: f4 (reader) must not get in until it is released.
+  lock.WRLock();
+  std::thread t1(f4, &lock, &result);
+  std::this_thread::sleep_for(std::chrono::seconds(1));
+  result.push_back(1);
+  lock.UNLock();
+  t1.join();
+}
+}  // namespace tests
+}  // namespace pten
diff --git a/paddle/pten/tests/core/test_selected_rows.cc
b/paddle/pten/tests/core/test_selected_rows.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81c7ff4a838a702a1c32df2fb4a7f082b1b39f3b
--- /dev/null
+++ b/paddle/pten/tests/core/test_selected_rows.cc
@@ -0,0 +1,192 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <time.h>
+#include <thread>  // NOLINT
+
+#include "gtest/gtest.h"
+#include "paddle/pten/core/selected_rows.h"
+
+namespace pten {
+namespace tests {
+// Fixture: a SelectedRows with rows {0, 4, 7}, height 10 and a 3 x 100 float
+// value tensor whose elements are 0, 1, 2, ... in storage order.
+class SelectedRowsTester : public ::testing::Test {
+ public:
+  void SetUp() override {
+    const std::vector<int64_t> row_ids{0, 4, 7};
+    const int64_t table_height = 10;
+    const int64_t row_width = 100;
+    selected_rows_.reset(new SelectedRows(row_ids, table_height));
+
+    pten::DenseTensor* tensor = selected_rows_->mutable_value();
+    float* buffer = tensor->mutable_data<float>(
+        pten::framework::make_ddim(
+            {static_cast<int64_t>(row_ids.size()), row_width}),
+        place_);
+    const int64_t total = tensor->numel();
+    for (int64_t pos = 0; pos < total; ++pos) {
+      buffer[pos] = static_cast<float>(pos);
+    }
+  }
+
+ protected:
+  pten::CPUPlace place_;
+  std::unique_ptr<SelectedRows> selected_rows_{nullptr};
+};
+
+TEST_F(SelectedRowsTester, height) { ASSERT_EQ(selected_rows_->height(), 10); }
+
+TEST_F(SelectedRowsTester, dims) {
+  const auto expected = pten::framework::make_ddim({3, 100});
+  ASSERT_EQ(selected_rows_->value().dims(), expected);
+}
+
+TEST_F(SelectedRowsTester, complete_dims) {
+  const auto expected = pten::framework::make_ddim({10, 100});
+  ASSERT_EQ(selected_rows_->GetCompleteDims(), expected);
+}
+
+TEST(SelectedRows, SparseTable) {
+  pten::CPUPlace cpu;
+  SelectedRows table;
+
+  const int64_t table_size = 100;
+  const int64_t embedding_width = 8;
+  // initialize a sparse table: every element of row r holds the value r
+  table.mutable_value()->Resize(
+      pten::framework::make_ddim({table_size, embedding_width}));
+  float* table_data = table.mutable_value()->mutable_data<float>(cpu);
+  for (int64_t pos = 0; pos < table_size * embedding_width; ++pos) {
+    table_data[pos] = static_cast<float>(pos / embedding_width);
+  }
+
+  // new keys get consecutive indices, a repeated key keeps its index
+  ASSERT_EQ(table.AutoGrownIndex(10, true, false), 0);
+  ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
+  ASSERT_EQ(table.AutoGrownIndex(8, true, false), 1);
+  ASSERT_EQ(table.AutoGrownIndex(6, true, false), 2);
+  for (int64_t key = 11; key < 20; ++key) {
+    ASSERT_EQ(table.AutoGrownIndex(key, true, true), -1);
+    ASSERT_TRUE(!table.HasKey(key));
+  }
+  ASSERT_TRUE(table.HasKey(10));
+  ASSERT_TRUE(table.HasKey(8));
+  ASSERT_TRUE(table.HasKey(6));
+  ASSERT_EQ(table.rows().size(), 3UL);
+
+  // look up keys {6, 6, 8, 10}; they were registered at indices {2, 2, 1, 0}
+  const int64_t lookup_keys[] = {6, 6, 8, 10};
+  const int expected_rows[] = {2, 2, 1, 0};
+
+  pten::DenseTensor ids;
+  ids.Resize(pten::framework::make_ddim({4}));
+  int64_t* ids_data = ids.mutable_data<int64_t>(cpu);
+  for (int k = 0; k < 4; ++k) {
+    ids_data[k] = lookup_keys[k];
+  }
+
+  pten::DenseTensor get_value;
+  float* value_data = get_value.mutable_data<float>(
+      pten::framework::make_ddim({4, embedding_width}), cpu);
+  table.Get(ids, &get_value);
+
+  for (int row = 0; row < 4; ++row) {
+    for (int col = 0; col < embedding_width; ++col) {
+      ASSERT_EQ(value_data[row * embedding_width + col], expected_rows[row]);
+    }
+  }
+}
+
+// Worker: walks the ids downwards and checks that repeated lookups of one id
+// (with and without growing allowed) return the same index.
+void f1(SelectedRows* table, int table_size) {
+  for (int i = 1000000; i > 0; --i) {
+    const auto key = i % table_size;
+    const int64_t grown = table->AutoGrownIndex(key, true);
+    const int64_t found = table->AutoGrownIndex(key, false);
+    const int64_t regrown = table->AutoGrownIndex(key, true);
+    ASSERT_EQ(grown, found);
+    ASSERT_EQ(found, regrown);
+  }
+}
+
+// Worker: same checks as f1, with the ids walked upwards.
+void f2(SelectedRows* table, int table_size) {
+  for (int i = 0; i < 1000000; ++i) {
+    const auto key = i % table_size;
+    const int64_t grown = table->AutoGrownIndex(key, true);
+    const int64_t found = table->AutoGrownIndex(key, false);
+    const int64_t regrown = table->AutoGrownIndex(key, true);
+    ASSERT_EQ(grown, found);
+    ASSERT_EQ(found, regrown);
+  }
+}
+
+// Worker: compares AutoGrownIndex() with Index() for ids walked downwards and
+// prints the elapsed clock() ticks.
+void f3(SelectedRows* table, int table_size) {
+  const clock_t begin = clock();
+  for (int i = 100000; i > 0; --i) {
+    const auto grown = table->AutoGrownIndex(i % table_size, true);
+    const auto located = table->Index(i % table_size);
+    ASSERT_EQ(grown, located);
+  }
+  const clock_t finish = clock();
+  std::cout << "f3 run time:" << finish - begin << std::endl;
+}
+
+// Worker: same comparison as f3, with the ids walked upwards.
+void f4(SelectedRows* table, int table_size) {
+  const clock_t begin = clock();
+  for (int i = 0; i < 100000; ++i) {
+    const auto grown = table->AutoGrownIndex(i % table_size, true);
+    const auto located = table->Index(i % table_size);
+    ASSERT_EQ(grown, located);
+  }
+  const clock_t finish = clock();
+  std::cout << "f4 run time:" << finish - begin << std::endl;
+}
+
+TEST(SelectedRows, MultiThreadAutoIndex) {
+  pten::CPUPlace cpu;
+  SelectedRows table;
+
+  const int64_t table_size = 100000;
+  const int64_t embedding_width = 8;
+  // initialize a sparse table: every element of row r holds the value r
+  table.mutable_value()->Resize(
+      pten::framework::make_ddim({table_size, embedding_width}));
+  float* table_data = table.mutable_value()->mutable_data<float>(cpu);
+  for (int64_t pos = 0; pos < table_size * embedding_width; ++pos) {
+    table_data[pos] = static_cast<float>(pos / embedding_width);
+  }
+
+  // phase 1: four workers grow the index concurrently
+  std::thread t1(f1, &table, table_size);
+  std::thread t11(f1, &table, table_size);
+  std::thread t2(f2, &table, table_size);
+  std::thread t22(f2, &table, table_size);
+  t1.join();
+  t11.join();
+  t2.join();
+  t22.join();
+  // phase 2: two workers cross-check AutoGrownIndex() against Index()
+  std::thread t3(f3, &table, table_size);
+  std::thread t4(f4, &table, table_size);
+  t3.join();
+  t4.join();
+}
+}  // namespace tests
+}  // namespace pten