From 653fad08f8f7a717c20756eebfc1b4ab860d4618 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 17 Aug 2018 10:52:09 +0800 Subject: [PATCH] Optimize selected rows for dist lookup table with pthread rwlock (#12635) Optimize selected rows for dist lookup table with rwlock --- paddle/fluid/framework/rw_lock.h | 46 ++++++ paddle/fluid/framework/selected_rows.cc | 110 ++++++++------ paddle/fluid/framework/selected_rows.h | 63 ++++---- paddle/fluid/framework/selected_rows_test.cc | 143 +++++++++++++++--- .../operators/distributed/rpc_server_test.cc | 3 +- .../fluid/operators/lookup_sparse_table_op.cc | 53 +------ paddle/fluid/operators/sgd_op.h | 2 +- paddle/fluid/operators/uniform_random_op.cc | 4 +- paddle/fluid/pybind/pybind.cc | 1 + .../unittests/test_lookup_sparse_table_op.py | 57 +++---- .../fluid/tests/unittests/test_sgd_op.py | 1 + 11 files changed, 298 insertions(+), 185 deletions(-) create mode 100644 paddle/fluid/framework/rw_lock.h diff --git a/paddle/fluid/framework/rw_lock.h b/paddle/fluid/framework/rw_lock.h new file mode 100644 index 00000000000..2a4009b765e --- /dev/null +++ b/paddle/fluid/framework/rw_lock.h @@ -0,0 +1,46 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +namespace paddle { +namespace framework { + +struct RWLock { + RWLock() { pthread_rwlock_init(&lock_, nullptr); } + + ~RWLock() { pthread_rwlock_destroy(&lock_); } + + void RDLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_rdlock(&lock_), 0, + "acquire read lock failed"); + } + + void WRLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_wrlock(&lock_), 0, + "acquire write lock failed"); + } + + void UNLock() { + PADDLE_ENFORCE_EQ(pthread_rwlock_unlock(&lock_), 0, "unlock failed"); + } + + private: + pthread_rwlock_t lock_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 06ed87e7e8a..c202b0a5be1 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -120,66 +120,76 @@ bool SelectedRows::HasKey(int64_t key) const { : true; } -std::vector> SelectedRows::Get( - const std::vector& keys, framework::Tensor* value) const { +int64_t SelectedRows::AutoGrownIndex(int64_t key, bool auto_grown) { + rwlock_->RDLock(); + auto iter = id_to_index_.find(key); + if (iter == id_to_index_.end()) { + rwlock_->UNLock(); + if (!auto_grown) { + PADDLE_THROW("key %d not found", key); + } + rwlock_->WRLock(); + auto map_size = id_to_index_.size(); + auto vector_size = rows_.size(); + if (map_size != vector_size) { + rwlock_->UNLock(); + PADDLE_THROW( + "id_to_index_ size %d should have the same size with rows_ %d", + map_size, vector_size); + } + auto write_iter = id_to_index_.find(key); + if (write_iter == id_to_index_.end()) { + size_t row_num = rows_.size(); + if (row_num == value_->dims()[0]) { + rwlock_->UNLock(); + PADDLE_THROW("selected rows is full, then length exceed %d", row_num); + } + // key logic to put a key into id_to_index_ + rows_.push_back(key); + auto index = static_cast(rows_.size() - 1); + id_to_index_[key] = index; + rwlock_->UNLock(); + return index; + } else { + auto index = write_iter->second; + rwlock_->UNLock(); + return index; + } + } else { + auto index = iter->second; + rwlock_->UNLock(); + return index; + } +} + +void SelectedRows::SyncIndex() { + rwlock_->WRLock(); + id_to_index_.clear(); + for (size_t i = 0; i < rows_.size(); ++i) { + id_to_index_[rows_[i]] = i; + } + rwlock_->UNLock(); +} + +void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value, + bool auto_grown) { PADDLE_ENFORCE(value->IsInitialized(), "The value tensor should be initialized."); - std::vector> non_keys_pair; - if (keys.empty()) { + if (ids.numel() == 0) { VLOG(3) << "keys is empty, please check data!"; } else { int64_t value_width = value_->numel() / value_->dims()[0]; PADDLE_ENFORCE_EQ(value_width, value->numel() / value->dims()[0], "output tensor should have the same shape with table " "except the dims[0]."); - - for (size_t i = 0; i < keys.size(); ++i) { - int64_t index = Index(keys[i]); - if (index == -1) { - non_keys_pair.push_back( - std::make_pair(keys[i], static_cast(i))); - } else { - framework::VisitDataType( - framework::ToDataType(value_->type()), - TensorCopyVisitor(value, i * value_width, *value_.get(), - index * value_width, value_width)); - } + for (size_t i = 0; i < ids.numel(); ++i) { + int64_t index = AutoGrownIndex(ids.data()[i], auto_grown); + framework::VisitDataType( + framework::ToDataType(value_->type()), + TensorCopyVisitor(value, i * value_width, *value_.get(), + index * value_width, value_width)); } } - return non_keys_pair; -} - -bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { - PADDLE_ENFORCE(value.IsInitialized(), "The value should be initialized."); - if (value_->IsInitialized()) { - PADDLE_ENFORCE_EQ( - value.type(), value_->type(), - "The type of the value should be same with the original value"); - } - PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), - "The first dim of value should be 1."); - std::lock_guard lock(*auto_grown_mutex_.get()); - auto index = Index(key); - bool is_new_key = false; - if (index == -1) { - rows_.push_back(key); - index = rows_.size() - 1; - is_new_key = true; - // whether need to resize the table - if (static_cast(rows_.size()) > value_->dims()[0]) { - auto dims = value_->dims(); - dims[0] = (dims[0] + 1) << 1; - framework::VisitDataType(framework::ToDataType(value.type()), - ReAllocateVisitor(dims, value_.get())); - } - } - - framework::VisitDataType( - framework::ToDataType(value.type()), - TensorCopyVisitor(value_.get(), - index * value_->numel() / value_->dims()[0], value, - static_cast(0), value.numel())); - return is_new_key; } } // namespace framework diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 7160670ddd2..daf5e95304f 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -17,10 +17,12 @@ limitations under the License. */ #include #include #include // NOLINT +#include #include #include #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/memory/memcpy.h" @@ -48,13 +50,13 @@ class SelectedRows { SelectedRows(const std::vector& rows, const int64_t& height) : rows_(rows), height_(height) { value_.reset(new Tensor()); - auto_grown_mutex_.reset(new std::mutex); + rwlock_.reset(new RWLock); } SelectedRows() { height_ = 0; value_.reset(new Tensor()); - auto_grown_mutex_.reset(new std::mutex); + rwlock_.reset(new RWLock); } platform::Place place() const { return value_->place(); } @@ -74,47 +76,51 @@ class SelectedRows { void set_rows(const Vector& rows) { rows_ = rows; } /* - * @brief wheter has the specified key in the table. + * @brief Get the index of key in rows + * + * @return -1 if the key does not exists. + */ + int64_t Index(int64_t key) const { + auto it = std::find(rows_.begin(), rows_.end(), key); + if (it == rows_.end()) { + PADDLE_THROW("id %s not in table", key); + } + return static_cast(std::distance(rows_.begin(), it)); + } + + /* + * @brief whether has the specified key in the table. * * @return true if the key is exists. */ bool HasKey(int64_t key) const; /* - * @brief Get value by the key list, if the + * @brief Get value by the key list. + * Note!!! this interface is only used when selected_rows is used as + * parameters + * for distribute lookup table. * * @return a list of pair which contains the non-exists key and the index in * the value */ - std::vector> Get(const std::vector& keys, - framework::Tensor* value) const; + void Get(const framework::Tensor& ids, framework::Tensor* value, + bool auto_grown = false); /* - * @brief Set a key-value pair into the table. - * This function will double the value memory if it's not engouth. + * @brief Get the index of the key from id_to_index_ map. If the key not + * exist, + * add the key into id_to_index_. * - * @note: - * 1. The first dim of the value should be 1 - * 2. The value should be initialized and the data type - * should be the same with the table. - * - * @return true if the key is a new one, otherwise false + * Note!!! this interface is only used when selected_rows is used as + * parameters + * for distribute lookup table. * + * @return index of the key. */ - bool Set(int64_t key, const Tensor& value); + int64_t AutoGrownIndex(int64_t key, bool auto_grown); - /* - * @brief Get the index of key in rows - * - * @return -1 if the key does not exists. - */ - int64_t Index(int64_t key) const { - auto it = std::find(rows_.begin(), rows_.end(), key); - if (it == rows_.end()) { - return static_cast(-1); - } - return static_cast(std::distance(rows_.begin(), it)); - } + void SyncIndex(); DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); @@ -127,9 +133,10 @@ class SelectedRows { // SelectedRows are simply concated when adding together. Until a // SelectedRows add a Tensor, will the duplicate rows be handled. Vector rows_; + std::unordered_map id_to_index_; std::unique_ptr value_{nullptr}; int64_t height_; - std::unique_ptr auto_grown_mutex_{nullptr}; + std::unique_ptr rwlock_{nullptr}; }; /* diff --git a/paddle/fluid/framework/selected_rows_test.cc b/paddle/fluid/framework/selected_rows_test.cc index eefcaa5672c..5ca864cfdf7 100644 --- a/paddle/fluid/framework/selected_rows_test.cc +++ b/paddle/fluid/framework/selected_rows_test.cc @@ -9,8 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/framework/selected_rows.h" +#include +#include // NOLINT + #include "gtest/gtest.h" +#include "paddle/fluid/framework/selected_rows.h" namespace paddle { namespace framework { @@ -59,39 +62,129 @@ TEST_F(SelectedRowsTester, SerializeAndDeseralize) { ASSERT_EQ(selected_rows_->GetCompleteDims(), dst_tensor.GetCompleteDims()); } -TEST_F(SelectedRowsTester, SparseTable) { +TEST(SelectedRows, SparseTable) { platform::CPUPlace cpu; SelectedRows table; + + int64_t table_size = 100; + int64_t embedding_width = 8; // initialize a sparse table - table.mutable_value()->Resize(framework::make_ddim({1, 100})); - table.mutable_value()->mutable_data(cpu); - table.mutable_rows()->push_back(1); + table.mutable_value()->Resize( + framework::make_ddim({table_size, embedding_width})); + auto* data = table.mutable_value()->mutable_data(cpu); + for (int64_t i = 0; i < table_size; ++i) { + for (int64_t j = 0; j < embedding_width; ++j) { + data[i * embedding_width + j] = static_cast(i); + } + } + ASSERT_EQ(table.AutoGrownIndex(10, true), 0); + ASSERT_EQ(table.AutoGrownIndex(8, true), 1); + ASSERT_EQ(table.AutoGrownIndex(8, true), 1); + ASSERT_EQ(table.AutoGrownIndex(6, true), 2); + ASSERT_TRUE(table.HasKey(10)); + ASSERT_TRUE(table.HasKey(8)); + ASSERT_TRUE(table.HasKey(6)); + ASSERT_EQ(table.rows().size(), 3); + + framework::Tensor ids; + ids.Resize(framework::make_ddim({4})); + auto* ids_data = ids.mutable_data(cpu); + ids_data[0] = static_cast(6); + ids_data[1] = static_cast(6); + ids_data[2] = static_cast(8); + ids_data[3] = static_cast(10); - int64_t key = 10000; - int64_t non_key = 999; - framework::Tensor value; - value.Resize(framework::make_ddim({1, 100})); - auto ptr = value.mutable_data(cpu); - ptr[0] = static_cast(10); + framework::Tensor get_value; + auto* value_data = get_value.mutable_data( + framework::make_ddim({4, embedding_width}), cpu); + table.Get(ids, &get_value); - ASSERT_EQ(table.rows().size(), static_cast(1)); - ASSERT_EQ(table.HasKey(key), false); + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[0 * embedding_width + j], 2); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[1 * embedding_width + j], 2); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[2 * embedding_width + j], 1); + } + for (int j = 0; j < embedding_width; ++j) { + ASSERT_EQ(value_data[3 * embedding_width + j], 0); + } +} - table.Set(key, value); +void f1(SelectedRows* table, int table_size) { + for (int i = 1000000; i > 0; --i) { + auto id = i % table_size; + int64_t index1 = table->AutoGrownIndex(id, true); + int64_t index2 = table->AutoGrownIndex(id, false); + int64_t index3 = table->AutoGrownIndex(id, true); + ASSERT_EQ(index1, index2); + ASSERT_EQ(index2, index3); + } +} - ASSERT_EQ(table.rows().size(), static_cast(2)); - ASSERT_EQ(table.HasKey(key), true); - // check re-allocate - ASSERT_EQ(table.value().dims()[0], static_cast(4)); +void f2(SelectedRows* table, int table_size) { + for (int i = 0; i < 1000000; ++i) { + auto id = i % table_size; + int64_t index1 = table->AutoGrownIndex(id, true); + int64_t index2 = table->AutoGrownIndex(id, false); + int64_t index3 = table->AutoGrownIndex(id, true); + ASSERT_EQ(index1, index2); + ASSERT_EQ(index2, index3); + } +} - framework::Tensor get_value; - get_value.mutable_data(framework::make_ddim({2, 100}), cpu); - std::vector keys({non_key, key}); - auto non_key_pairs = table.Get(keys, &get_value); +void f3(SelectedRows* table, int table_size) { + clock_t t1 = clock(); + for (int i = 100000; i > 0; --i) { + auto id1 = table->AutoGrownIndex(i % table_size, true); + auto id2 = table->Index(i % table_size); + ASSERT_EQ(id1, id2); + } + clock_t t2 = clock(); + std::cout << "f3 run time:" << t2 - t1 << std::endl; +} + +void f4(SelectedRows* table, int table_size) { + clock_t t1 = clock(); + for (int i = 0; i < 100000; ++i) { + auto id1 = table->AutoGrownIndex(i % table_size, true); + auto id2 = table->Index(i % table_size); + ASSERT_EQ(id1, id2); + } + clock_t t2 = clock(); + std::cout << "f4 run time:" << t2 - t1 << std::endl; +} + +TEST(SelectedRows, MultiThreadAutoIndex) { + platform::CPUPlace cpu; + SelectedRows table; + + int64_t table_size = 100000; + int64_t embedding_width = 8; + // initialize a sparse table + table.mutable_value()->Resize( + framework::make_ddim({table_size, embedding_width})); + auto* data = table.mutable_value()->mutable_data(cpu); + for (int64_t i = 0; i < table_size; ++i) { + for (int64_t j = 0; j < embedding_width; ++j) { + data[i * embedding_width + j] = static_cast(i); + } + } - ASSERT_EQ(get_value.data()[100], static_cast(10)); - ASSERT_EQ(non_key_pairs.size(), static_cast(1)); - ASSERT_EQ(non_key_pairs[0].first, non_key); + std::thread t1(f1, &table, table_size); + std::thread t11(f1, &table, table_size); + std::thread t2(f2, &table, table_size); + std::thread t22(f2, &table, table_size); + t1.join(); + t11.join(); + t2.join(); + t22.join(); + std::thread t3(f3, &table, table_size); + std::thread t4(f4, &table, table_size); + t3.join(); + t4.join(); } } // namespace framework diff --git a/paddle/fluid/operators/distributed/rpc_server_test.cc b/paddle/fluid/operators/distributed/rpc_server_test.cc index b50830c362d..d6176e1443d 100644 --- a/paddle/fluid/operators/distributed/rpc_server_test.cc +++ b/paddle/fluid/operators/distributed/rpc_server_test.cc @@ -78,10 +78,9 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, int64_t rows_numel) { CreateVarsOnScope(scope, place); auto w = scope->Var("w")->GetMutable(); - auto rows = w->mutable_rows(); - for (int64_t i = 0; i < rows_numel; ++i) rows->push_back(i); auto w_value = w->mutable_value(); w_value->Resize({rows_numel, 10}); + for (int64_t i = 0; i < rows_numel; ++i) w->AutoGrownIndex(i, true); auto ptr = w_value->mutable_data(*place); diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index 2ce11e712fb..de3f0990e10 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -17,7 +17,6 @@ limitations under the License. */ #include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { @@ -46,10 +45,6 @@ class LookupSparseTableOp : public framework::OperatorBase { auto out_var = scope.FindVar(Output("Out")); auto w_var = scope.FindVar(Input("W")); auto ids_var = scope.FindVar(Input("Ids")); - unsigned int seed = static_cast(Attr("seed")); - float min = Attr("min"); - float max = Attr("max"); - bool auto_grown_table = Attr("auto_grown_table"); PADDLE_ENFORCE(out_var->IsType(), "The type of Out var should be LodTensor."); @@ -60,46 +55,17 @@ class LookupSparseTableOp : public framework::OperatorBase { auto &ids_t = ids_var->Get(); auto out_t = out_var->GetMutable(); auto w_t = w_var->GetMutable(); - std::vector keys; - keys.resize(ids_t.numel()); - for (int64_t i = 0; i < ids_t.numel(); ++i) { - keys[i] = ids_t.data()[i]; - } // TODO(Yancey1989): support CUDA Place for the sparse table platform::CPUPlace cpu; auto out_shape = w_t->value().dims(); - out_shape[0] = keys.size(); + out_shape[0] = ids_t.numel(); out_t->Resize(out_shape); out_t->mutable_data(cpu, w_t->value().type()); PADDLE_ENFORCE_EQ(framework::ToDataType(w_t->value().type()), framework::proto::VarType::FP32, "The sparse table only support FP32"); - auto non_keys_pair = w_t->Get(keys, out_t); - if (!auto_grown_table) { - PADDLE_ENFORCE_EQ(non_keys_pair.size(), static_cast(0), - "there is some keys does exists in the sparse table."); - } - auto value_shape = w_t->value().dims(); - value_shape[0] = 1; - for (const auto &it : non_keys_pair) { - const auto key = it.first; - const auto index = it.second; - framework::Tensor value; - value.Resize(value_shape); - auto data = value.mutable_data(cpu); - - std::minstd_rand engine; - engine.seed(seed); - std::uniform_real_distribution dist(min, max); - int64_t size = value.numel(); - for (int64_t i = 0; i < size; ++i) { - data[i] = dist(engine); - } - w_t->Set(key, value); - memory::Copy(cpu, out_t->mutable_data(cpu) + index * value.numel(), - cpu, value.data(), value.numel() * sizeof(float)); - } + w_t->Get(ids_t, out_t, true); } }; @@ -121,21 +87,6 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker { "Otherwise the given value indicates padding the output " "with zeros whenever lookup encounters it in Ids.") .SetDefault(kNoPadding); - AddAttr("min", - "(float, default -1.0) " - "Minimum value of uniform random") - .SetDefault(-1.0f); - AddAttr("max", - "(float, default 1.0) " - "Maximum value of uniform random") - .SetDefault(1.0f); - AddAttr("seed", - "(int, default 0) " - "Random seed used for generating samples. " - "0 means use a seed generated by the system." - "Note that if seed is not 0, this operator will always " - "generate the same random numbers every time.") - .SetDefault(0); AddAttr("auto_grown_table", "(bool default false)" "Whether create new value if for nonexistent key.") diff --git a/paddle/fluid/operators/sgd_op.h b/paddle/fluid/operators/sgd_op.h index 2685ce217ee..d8b0165b2a8 100644 --- a/paddle/fluid/operators/sgd_op.h +++ b/paddle/fluid/operators/sgd_op.h @@ -111,7 +111,7 @@ class SGDOpKernel : public framework::OpKernel { for (size_t i = 0; i < grad.rows().size(); i++) { PADDLE_ENFORCE(grad.rows()[i] < grad.height(), "Input rows index should less than height"); - int64_t id_index = param.Index(grad.rows()[i]); + int64_t id_index = param_out->AutoGrownIndex(grad.rows()[i], false); PADDLE_ENFORCE_GE(id_index, static_cast(0), "id should be in the table"); for (int64_t j = 0; j < grad_row_width; j++) { diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index edd1baa4ace..5248767c2ee 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -30,8 +30,10 @@ class CPUUniformRandomKernel : public framework::OpKernel { tensor = out_var->GetMutable(); } else if (out_var->IsType()) { auto shape = ctx.Attr>("shape"); - tensor = out_var->GetMutable()->mutable_value(); + auto* selected_rows = out_var->GetMutable(); + tensor = selected_rows->mutable_value(); tensor->Resize(framework::make_ddim(shape)); + selected_rows->mutable_rows()->reserve(shape[0]); } else { PADDLE_THROW( "uniform_random_op's output only" diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 40ced8e1c78..6c58478b0dd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -249,6 +249,7 @@ PYBIND11_PLUGIN(core) { self.set_rows(new_rows); #endif }) + .def("sync_index", [](SelectedRows &instance) { instance.SyncIndex(); }) .def("rows", [](SelectedRows &self) { auto rows = self.rows(); std::vector new_rows; diff --git a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py index 7f75d0e6e9c..11e5d8b536f 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py @@ -21,36 +21,27 @@ import paddle.fluid.core as core from paddle.fluid.op import Operator -def output_hist(out): - hist, _ = np.histogram(out, range=(-5, 10)) - hist = hist.astype("float32") - hist /= float(out.size) - prob = 0.1 * np.ones((10)) - return hist, prob - - class TestLookupSpraseTable(OpTest): def check_with_place(self, place): scope = core.Scope() - # create and initialize Id Variable - ids = scope.var("Ids").get_tensor() - ids_array = np.array([0, 2, 3, 5, 100]).astype("int64") - ids.set(ids_array, place) - # create and initialize W Variable - rows = [0, 1, 2, 3, 4, 5, 6] - row_numel = 10000 + table_size = 10000 + row_numel = 8 w_selected_rows = scope.var('W').get_selected_rows() - w_selected_rows.set_height(len(rows)) - w_selected_rows.set_rows(rows) - w_array = np.ones((len(rows), row_numel)).astype("float32") - for i in range(len(rows)): + w_selected_rows.set_height(table_size) + w_array = np.ones((table_size, row_numel)).astype("float32") + for i in range(table_size): w_array[i] *= i w_tensor = w_selected_rows.get_tensor() w_tensor.set(w_array, place) + # create and initialize Id Variable + ids = scope.var("Ids").get_tensor() + ids_array1 = np.array([0, 2, 3, 2, 5, 0, 100]).astype("int64") + ids.set(ids_array1, place) + # create Out Variable out_tensor = scope.var('Out').get_tensor() @@ -66,16 +57,28 @@ class TestLookupSpraseTable(OpTest): lookup_table.run(scope, place) # get result from Out - result_array = np.array(out_tensor) + result_array1 = np.array(out_tensor) # all(): return True if all elements of the iterable are true (or if the iterable is empty) - for idx, row in enumerate(ids_array[:-2]): - assert (row == result_array[idx]).all() + assert (result_array1[0] == w_array[0]).all() + assert (result_array1[1] == w_array[1]).all() + assert (result_array1[2] == w_array[2]).all() + assert (result_array1[3] == w_array[1]).all() + assert (result_array1[4] == w_array[3]).all() + assert (result_array1[5] == w_array[0]).all() + assert (result_array1[6] == w_array[4]).all() + + # create and initialize Id Variable + ids = scope.var("Ids").get_tensor() + ids_array2 = np.array([4, 2, 3, 7, 100000]).astype("int64") + ids.set(ids_array2, place) + lookup_table.run(scope, place) - # check the random value - hist, prob = output_hist(result_array[-1]) - self.assertTrue( - np.allclose( - hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + result_array2 = np.array(out_tensor) + assert (result_array2[0] == w_array[5]).all() + assert (result_array2[1] == w_array[1]).all() + assert (result_array2[2] == w_array[2]).all() + assert (result_array2[3] == w_array[6]).all() + assert (result_array2[4] == w_array[7]).all() def test_w_is_selected_rows(self): places = [core.CPUPlace()] diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index c14a83b4bbc..b46e4bfb86b 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -126,6 +126,7 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase): w_selected_rows = scope.var('Param').get_selected_rows() w_selected_rows.set_height(len(param_rows)) w_selected_rows.set_rows(param_rows) + w_selected_rows.sync_index() w_array = np.ones((len(param_rows), row_width)).astype("float32") for i in range(len(param_rows)): w_array[i] *= i -- GitLab