diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 162b0355b58b1dd5769da32de820cdb2b5319fd1..d9d6b7dd67f1c6e4bbd6a4e1a8f0843d4cb93c05 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -17,12 +17,6 @@ limitations under the License. */ namespace paddle { namespace framework { -size_t GetIndex(const std::vector& rows, int64_t value) { - auto it = std::find(rows.begin(), rows.end(), value); - PADDLE_ENFORCE(it != rows.end(), "id should be in rows"); - return static_cast(std::distance(rows.begin(), it)); -} - void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, const platform::DeviceContext& dev_ctx) { { // the 1st field, uint32_t version diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index 2979d408701976599c9da95e3bed6a3448b27942..8e2d9470d3954e0f66c74828a8d8292c2875a8f4 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -50,6 +50,15 @@ class SelectedRows { void set_rows(const Vector& rows) { rows_ = rows; } + /** + * get the index of id in rows + */ + int64_t index(int64_t id) const { + auto it = std::find(rows_.begin(), rows_.end(), id); + PADDLE_ENFORCE(it != rows_.end(), "id should be in rows"); + return static_cast(std::distance(rows_.begin(), it)); + } + DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); dims[0] = height_; @@ -65,11 +74,6 @@ class SelectedRows { int64_t height_; }; -/** - * Find the index of value in rows. - */ -size_t GetIndex(const std::vector& rows, int64_t value); - /* * Serialize/Desiralize SelectedRows to std::ostream * You can pass ofstream or ostringstream to serilize to file diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index fff5edda62d4b115605a4cab35ed5457b4db5f21..cb088c267bcc028ff11583cd73de5ca1722a9b69 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; -static constexpr int64_t kNoPadding = -1; - -inline size_t getIndex(const std::vector &rows, int64_t value) { - auto it = std::find(rows.begin(), rows.end(), value); - PADDLE_ENFORCE(it != rows.end(), "id should be in rows"); - return static_cast(std::distance(rows.begin(), it)); -} +constexpr int64_t kNoPadding = -1; template class LookupTableKernel : public framework::OpKernel { @@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel { auto *table_t = context.Input("W"); table_dim = table_t->value().dims(); } else { - PADDLE_THROW("table only support LoDTensor and SelectedRows"); + PADDLE_THROW( + "The parameter W of a LookupTable " + "must be either LoDTensor or SelectedRows"); } int64_t *ids; @@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel { memset(output + i * row_width, 0, row_width * sizeof(T)); } else { PADDLE_ENFORCE_GE(ids[i], 0); - auto id_index = getIndex(table_t.rows(), ids[i]); + auto id_index = table_t.index(ids[i]); memcpy(output + i * row_width, table + id_index * row_width, row_width * sizeof(T)); } @@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel { auto *table_t = context.Input("W"); table_dim = table_t->value().dims(); } else { - PADDLE_THROW("table only support LoDTensor and SelectedRows"); + PADDLE_THROW( + "The parameter W of a LookupTable " + "must be either LoDTensor or SelectedRows"); } bool is_sparse = context.Attr("is_sparse"); diff --git a/paddle/fluid/operators/sgd_op.h b/paddle/fluid/operators/sgd_op.h index 237cd2f8122466458e8356bb6d462a816ed1593a..8d2bdf75903b4958e14605781f65c5a214cb5300 100644 --- a/paddle/fluid/operators/sgd_op.h +++ b/paddle/fluid/operators/sgd_op.h @@ -106,7 +106,7 @@ class SGDOpKernel : public framework::OpKernel { for (size_t i = 0; i < grad.rows().size(); i++) { PADDLE_ENFORCE(grad.rows()[i] < grad.height(), "Input rows index should less than height"); - size_t id_index = framework::GetIndex(param.rows(), grad.rows()[i]); + int64_t id_index = param.index(grad.rows()[i]); for (int64_t j = 0; j < grad_row_width; j++) { out_data[id_index * grad_row_width + j] -= lr[0] * grad_data[i * grad_row_width + j];