diff --git a/paddle/fluid/operators/merge_selected_rows_op.cc b/paddle/fluid/operators/merge_selected_rows_op.cc index 3c15c839554599104d21a5225c078d41735c4a60..50f44c7fc5ec90420d7c38f0f536ff7adb8f9ec4 100644 --- a/paddle/fluid/operators/merge_selected_rows_op.cc +++ b/paddle/fluid/operators/merge_selected_rows_op.cc @@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel { "Input(X) of MergeSelectedRowsOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of MergeSelectedRowsOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(), + framework::proto::VarType::SELECTED_ROWS, + "Input X only should be SelectedRows."); + PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(), + framework::proto::VarType::SELECTED_ROWS, + "Output Y only should be SelectedRows."); + ctx->ShareDim("X", /*->*/ "Out"); } }; @@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { R"DOC( MergeSelectedRows Operator. -MergeSelectedRows is used to merge the duplicated rows of the input. +MergeSelectedRows is used to merge the duplicated rows of the input. The +output's row has no duplicated, and it's order is incremental. + +Example: + Input: + X.rows is [0, 5, 5, 4, 19] + X.height is 20 + X.value is: + [[1, 1] + [2, 2] + [3, 3] + [4, 4] + [6, 6]] + + Output: + Out.row is [0, 4, 5, 19] + Out.height is 20 + Out.value is: + [[1, 1] + [4, 4] + [5, 5] + [6, 6]] )DOC"); } }; diff --git a/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py index 021b950b3b6245caecab22d476bbb9d6b6b45c5e..6cd02dad577b681b8c452bdb9574df60ffb4f82e 100644 --- a/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py @@ -29,7 +29,7 @@ class TestGetTensorFromSelectedRows(unittest.TestCase): def check_with_place(self, place): scope = core.Scope() - x_rows = [0, 5, 5, 4, 20] + x_rows = [0, 5, 5, 4, 19] height = 20 row_numel = 2 diff --git a/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py index ce64da0478d3997f4889ca942c67e0defac80b45..d2fa344b67ab33a93f92733efd68e896c767bad2 100644 --- a/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py +++ b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py @@ -29,8 +29,8 @@ class TestMergeSelectedRows(unittest.TestCase): def check_with_place(self, place): scope = core.Scope() - x_rows = [0, 5, 5, 4, 20] - out_rows = [0, 4, 5, 20] + x_rows = [0, 5, 5, 4, 19] + out_rows = [0, 4, 5, 19] height = 20 row_numel = 2