提交 7a6ffb62 编写于 作者: Q qiaolongfei

add TestLookupTableWIsSelectedRows

上级 a94e2574
...@@ -18,6 +18,22 @@ limitations under the License. */ ...@@ -18,6 +18,22 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static inline framework::OpKernelType ExpectedKernelType(
const framework::ExecutionContext& ctx) {
auto* table_var = ctx.InputVar("W");
if (table_var->IsType<LoDTensor>()) {
return framework::OpKernelType(
framework::ToDataType(table_var->Get<LoDTensor>().type()),
ctx.device_context());
} else if (table_var->IsType<SelectedRows>()) {
return framework::OpKernelType(
framework::ToDataType(table_var->Get<SelectedRows>().value().type()),
ctx.device_context());
} else {
PADDLE_THROW("W should be LoDTensor or SelectedRows");
}
}
class LookupTableOp : public framework::OperatorWithKernel { class LookupTableOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -51,9 +67,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -51,9 +67,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return ExpectedKernelType(ctx);
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
} }
}; };
...@@ -124,9 +138,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -124,9 +138,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return ExpectedKernelType(ctx);
framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
ctx.device_context());
} }
}; };
......
...@@ -96,5 +96,47 @@ class TestLookupTableIdsIsSelectedRows(OpTest): ...@@ -96,5 +96,47 @@ class TestLookupTableIdsIsSelectedRows(OpTest):
self.check_with_place(place) self.check_with_place(place)
class TestLookupTableWIsSelectedRows(OpTest):
def check_with_place(self, place):
scope = core.Scope()
# create and initialize Id Variable
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.array([[0], [4], [3], [5]]).astype("int64")
ids_tensor.set(ids_array, place)
# create and initialize W Variable
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
ids_tensor = w_selected_rows.get_tensor()
ids_tensor.set(w_array, place)
# create Out Variable
Out_tensor = scope.var('Out').get_tensor()
# create and run lookup_table operator
lookup_table = Operator("lookup_table", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(Out_tensor)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array):
assert (row[0] == result_array[idx]).all()
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册