提交 31e8d807 编写于 作者: Q qiaolongfei

optimize code

上级 af1d3f5b
...@@ -17,12 +17,6 @@ limitations under the License. */ ...@@ -17,12 +17,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
size_t GetIndex(const std::vector<int64_t>& rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return static_cast<size_t>(std::distance(rows.begin(), it));
}
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows, void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) { const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version { // the 1st field, uint32_t version
......
...@@ -50,6 +50,15 @@ class SelectedRows { ...@@ -50,6 +50,15 @@ class SelectedRows {
void set_rows(const Vector<int64_t>& rows) { rows_ = rows; } void set_rows(const Vector<int64_t>& rows) { rows_ = rows; }
/**
* get the index of id in rows
*/
int64_t index(int64_t id) const {
auto it = std::find(rows_.begin(), rows_.end(), id);
PADDLE_ENFORCE(it != rows_.end(), "id should be in rows");
return static_cast<int64_t>(std::distance(rows_.begin(), it));
}
DDim GetCompleteDims() const { DDim GetCompleteDims() const {
std::vector<int64_t> dims = vectorize(value_->dims()); std::vector<int64_t> dims = vectorize(value_->dims());
dims[0] = height_; dims[0] = height_;
...@@ -65,11 +74,6 @@ class SelectedRows { ...@@ -65,11 +74,6 @@ class SelectedRows {
int64_t height_; int64_t height_;
}; };
/**
* Find the index of value in rows.
*/
size_t GetIndex(const std::vector<int64_t>& rows, int64_t value);
/* /*
* Serialize/Desiralize SelectedRows to std::ostream * Serialize/Desiralize SelectedRows to std::ostream
* You can pass ofstream or ostringstream to serilize to file * You can pass ofstream or ostringstream to serilize to file
......
...@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor; ...@@ -30,13 +30,7 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
static constexpr int64_t kNoPadding = -1; constexpr int64_t kNoPadding = -1;
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return static_cast<size_t>(std::distance(rows.begin(), it));
}
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
...@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -55,7 +49,9 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW("table only support LoDTensor and SelectedRows"); PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
} }
int64_t *ids; int64_t *ids;
...@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -107,7 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
memset(output + i * row_width, 0, row_width * sizeof(T)); memset(output + i * row_width, 0, row_width * sizeof(T));
} else { } else {
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = getIndex(table_t.rows(), ids[i]); auto id_index = table_t.index(ids[i]);
memcpy(output + i * row_width, table + id_index * row_width, memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T)); row_width * sizeof(T));
} }
...@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -128,7 +124,9 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
auto *table_t = context.Input<SelectedRows>("W"); auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims(); table_dim = table_t->value().dims();
} else { } else {
PADDLE_THROW("table only support LoDTensor and SelectedRows"); PADDLE_THROW(
"The parameter W of a LookupTable "
"must be either LoDTensor or SelectedRows");
} }
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
......
...@@ -106,7 +106,7 @@ class SGDOpKernel : public framework::OpKernel<T> { ...@@ -106,7 +106,7 @@ class SGDOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < grad.rows().size(); i++) { for (size_t i = 0; i < grad.rows().size(); i++) {
PADDLE_ENFORCE(grad.rows()[i] < grad.height(), PADDLE_ENFORCE(grad.rows()[i] < grad.height(),
"Input rows index should less than height"); "Input rows index should less than height");
size_t id_index = framework::GetIndex(param.rows(), grad.rows()[i]); int64_t id_index = param.index(grad.rows()[i]);
for (int64_t j = 0; j < grad_row_width; j++) { for (int64_t j = 0; j < grad_row_width; j++) {
out_data[id_index * grad_row_width + j] -= out_data[id_index * grad_row_width + j] -=
lr[0] * grad_data[i * grad_row_width + j]; lr[0] * grad_data[i * grad_row_width + j];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册