提交 31e8d807 编写于 作者: Q qiaolongfei

optimize code

上级 af1d3f5b
......@@ -17,12 +17,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
size_t GetIndex(const std::vector<int64_t>& rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return static_cast<size_t>(std::distance(rows.begin(), it));
}
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version
......
......@@ -50,6 +50,15 @@ class SelectedRows {
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/**
* get the index of id in rows
*/
int64_t index(int64_t id) const {
auto it = std::find(rows_.begin(), rows_.end(), id);
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_;
......@@ -65,11 +74,6 @@ class SelectedRows {
int64_t height_;
};
/**
* Find the index of value in rows.
*/
size_t GetIndex(const std::vector<int64_t>& rows, int64_t value);
/*
* Serialize/Desiralize SelectedRows to std::ostream
* You can pass ofstream or ostringstream to serilize to file
......
......@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static constexpr int64_t kNoPadding = -1;
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return static_cast<size_t>(std::distance(rows.begin(), it));
}
constexpr int64_t kNoPadding = -1;
template <typename T>
class LookupTableKernel : public framework::OpKernel<T> {
......@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW("table only support LoDTensor and SelectedRows");
PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
}
int64_t *ids;
......@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = getIndex(table_t.rows(), ids[i]);
auto id_index = table_t.index(ids[i]);
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
}
......@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW("table only support LoDTensor and SelectedRows");
PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
}
bool is_sparse = context.Attr<bool>("is_sparse");
......
......@@ -106,7 +106,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < grad.rows().size(); i++) {
PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
"Input rows index should less than height");
size_t id_index = framework::GetIndex(param.rows(), grad.rows()[i]);
int64_t id_index = param.index(grad.rows()[i]);
for (int64_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册