You need to sign in or sign up before continuing.
未验证 提交 37d9a72e 编写于 作者: Q Qiao Longfei 提交者: GitHub

Merge pull request #9575 from jacquesqiao/lookup_table_support_SelectedRows_as_parameter

Lookup table support selected rows as parameter
...@@ -10,6 +10,9 @@ See the License for the specific language governing permissions and ...@@ -10,6 +10,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -52,7 +55,7 @@ class SelectedRows { ...@@ -52,7 +55,7 @@ class SelectedRows {
private: private:
// Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here. // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here.
// SelectedRows are simplely concated when adding together. Until a // SelectedRows are simply concated when adding together. Until a
// SelectedRows add a Tensor, will the duplicate rows be handled. // SelectedRows add a Tensor, will the duplicate rows be handled.
Vector<int64_t> rows_; Vector<int64_t> rows_;
std::unique_ptr<Tensor> value_{nullptr}; std::unique_ptr<Tensor> value_{nullptr};
......
...@@ -18,6 +18,22 @@ limitations under the License. */ ...@@ -18,6 +18,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static inline framework::OpKernelType ExpectedKernelType(
const framework::ExecutionContext& ctx) {
auto* table_var = ctx.InputVar("W");
if (table_var->IsType<LoDTensor>()) {
return framework::OpKernelType(
framework::ToDataType(table_var->Get<LoDTensor>().type()),
ctx.device_context());
} else if (table_var->IsType<SelectedRows>()) {
return framework::OpKernelType(
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
ctx.device_context());
} else {
PADDLE_THROW("W should be LoDTensor or SelectedRows");
}
}
class LookupTableOp : public framework::OperatorWithKernel { class LookupTableOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -51,9 +67,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -51,9 +67,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return ExpectedKernelType(ctx);
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
} }
}; };
...@@ -84,7 +98,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -84,7 +98,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"If the value is -1, it makes no effect to lookup. " "If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output " "Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids.") "with zeros whenever lookup encounters it in Ids.")
.SetDefault(-1); .SetDefault(kNoPadding);
AddComment(R"DOC( AddComment(R"DOC(
Lookup Table Operator. Lookup Table Operator.
...@@ -124,9 +138,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -124,9 +138,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return ExpectedKernelType(ctx);
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
} }
}; };
......
...@@ -14,6 +14,9 @@ limitations under the License. */ ...@@ -14,6 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -25,16 +28,37 @@ namespace operators { ...@@ -25,16 +28,37 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
static constexpr int64_t kNoPadding = -1;
inline size_t getIndex(const std::vector<int64_t> &rows, int64_t value) {
auto it = std::find(rows.begin(), rows.end(), value);
PADDLE_ENFORCE(it != rows.end(), "id should be in rows");
return static_cast<size_t>(std::distance(rows.begin(), it));
}
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel<T> { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto* table_t = context.Input<LoDTensor>("W"); auto *table_var = context.InputVar("W");
auto* ids_var = context.InputVar("Ids"); auto *ids_var = context.InputVar("Ids");
Tensor* output_t = context.Output<Tensor>("Out"); Tensor *output_t = context.Output<Tensor>("Out");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
DDim table_dim;
int64_t* ids; if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW("table only support LoDTensor and SelectedRows");
}
int64_t *ids;
int64_t ids_numel; int64_t ids_numel;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type // The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
...@@ -42,39 +66,50 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -42,39 +66,50 @@ class LookupTableKernel : public framework::OpKernel<T> {
// when Ids's type is SelectedRows, the rows of Ids contains the // when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W. // ids to be looked up in W.
if (ids_var->IsType<LoDTensor>()) { if (ids_var->IsType<LoDTensor>()) {
auto* ids_t = context.Input<LoDTensor>("Ids"); auto *ids_t = context.Input<LoDTensor>("Ids");
ids = const_cast<int64_t*>(ids_t->data<int64_t>()); ids = const_cast<int64_t *>(ids_t->data<int64_t>());
ids_numel = ids_t->numel(); ids_numel = ids_t->numel();
} else if (ids_var->IsType<SelectedRows>()) { } else if (ids_var->IsType<SelectedRows>()) {
auto* ids_t = context.Input<SelectedRows>("Ids"); auto *ids_t = context.Input<SelectedRows>("Ids");
ids = const_cast<int64_t*>(ids_t->rows().data()); ids = const_cast<int64_t *>(ids_t->rows().data());
ids_numel = ids_t->rows().size(); ids_numel = ids_t->rows().size();
output_t->Resize({ids_numel, table_t->dims()[1]}); output_t->Resize({ids_numel, table_dim[1]});
} else { } else {
PADDLE_THROW("Unsupported Variable Type of Ids"); PADDLE_THROW("Unsupported Variable Type of Ids");
} }
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
int N = table_t->dims()[0]; auto *table = table_t->data<T>();
int D = table_t->dims()[1]; auto *output = output_t->mutable_data<T>(context.GetPlace());
auto* table = table_t->data<T>();
auto* output = output_t->mutable_data<T>(context.GetPlace());
if (padding_idx == -1) {
for (int64_t i = 0; i < ids_numel; ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
PADDLE_ENFORCE_LT(ids[i], N); if (padding_idx != kNoPadding && ids[i] == padding_idx) {
PADDLE_ENFORCE_GE(ids[i], 0); memset(output + i * row_width, 0, row_width * sizeof(T));
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); } else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
} }
} else { } else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) { for (int64_t i = 0; i < ids_numel; ++i) {
if (ids[i] == padding_idx) { if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * D, 0, D * sizeof(T)); memset(output + i * row_width, 0, row_width * sizeof(T));
} else { } else {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0); PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T)); auto id_index = getIndex(table_t.rows(), ids[i]);
memcpy(output + i * row_width, table + id_index * row_width,
row_width * sizeof(T));
} }
} }
} }
...@@ -84,17 +119,27 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -84,17 +119,27 @@ class LookupTableKernel : public framework::OpKernel<T> {
template <typename T> template <typename T>
class LookupTableGradKernel : public framework::OpKernel<T> { class LookupTableGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *table_var = context.InputVar("W");
DDim table_dim;
if (table_var->IsType<LoDTensor>()) {
table_dim = context.Input<LoDTensor>("W")->dims();
} else if (table_var->IsType<SelectedRows>()) {
auto *table_t = context.Input<SelectedRows>("W");
table_dim = table_t->value().dims();
} else {
PADDLE_THROW("table only support LoDTensor and SelectedRows");
}
bool is_sparse = context.Attr<bool>("is_sparse"); bool is_sparse = context.Attr<bool>("is_sparse");
// Since paddings are not trainable and fixed in forward, the gradient of // Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward. // paddings makes no sense and we don't deal with it in backward.
if (is_sparse) { if (is_sparse) {
auto* ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto* table = context.Input<LoDTensor>("W"); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
auto* ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims(); auto ids_dim = ids->dims();
framework::Vector<int64_t> new_rows; framework::Vector<int64_t> new_rows;
...@@ -104,31 +149,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> { ...@@ -104,31 +149,30 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
} }
d_table->set_rows(new_rows); d_table->set_rows(new_rows);
auto* d_table_value = d_table->mutable_value(); auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_dim[0], table->dims()[1]}); d_table_value->Resize({ids_dim[0], table_dim[1]});
d_table_value->mutable_data<T>(context.GetPlace()); d_table_value->mutable_data<T>(context.GetPlace());
d_table->set_height(table->dims()[0]); d_table->set_height(table_dim[0]);
auto* d_output_data = d_output->data<T>(); auto *d_output_data = d_output->data<T>();
auto* d_table_data = d_table_value->data<T>(); auto *d_table_data = d_table_value->data<T>();
PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims()); PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel()); memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
} else { } else {
auto* ids = context.Input<LoDTensor>("Ids"); auto *ids = context.Input<LoDTensor>("Ids");
auto* d_output = context.Input<LoDTensor>(framework::GradVarName("Out")); auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* d_table = context.Output<LoDTensor>(framework::GradVarName("W")); auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
auto* table = context.Input<LoDTensor>("W");
auto* ids_data = ids->data<int64_t>(); auto *ids_data = ids->data<int64_t>();
auto ids_dim = ids->dims(); auto ids_dim = ids->dims();
int N = table->dims()[0]; int N = table_dim[0];
int D = d_output->dims()[1]; int D = d_output->dims()[1];
auto* d_output_data = d_output->data<T>(); auto *d_output_data = d_output->data<T>();
auto* d_table_data = d_table->mutable_data<T>(context.GetPlace()); auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
memset(d_table_data, 0, d_table->numel() * sizeof(T)); memset(d_table_data, 0, d_table->numel() * sizeof(T));
......
...@@ -96,5 +96,47 @@ class TestLookupTableIdsIsSelectedRows(OpTest): ...@@ -96,5 +96,47 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
self.check_with_place(place) self.check_with_place(place)
class TestLookupTableWIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Id Variable
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
ids_tensor.set(ids_array, place)
# create and initialize W Variable
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
ids_tensor = w_selected_rows.get_tensor()
ids_tensor.set(w_array, place)
# create Out Variable
Out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array):
assert (row[0] == result_array[idx]).all()
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册