From a015a8a39df1902d43ec97251f13f0b60d68bd6a Mon Sep 17 00:00:00 2001 From: chengduo Date: Thu, 20 Dec 2018 06:15:25 -0600 Subject: [PATCH] Refine merge_selected_rows Doc (#14748) * add doc for MergeSelectedRows test=develop * checkout selected_rows test=develop --- .../fluid/operators/merge_selected_rows_op.cc | 30 ++++++++++++++++++- .../test_get_tensor_from_selected_rows_op.py | 2 +- .../unittests/test_merge_selectedrows_op.py | 4 +-- 3 files changed, 32 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/merge_selected_rows_op.cc b/paddle/fluid/operators/merge_selected_rows_op.cc index 3c15c839554..50f44c7fc5e 100644 --- a/paddle/fluid/operators/merge_selected_rows_op.cc +++ b/paddle/fluid/operators/merge_selected_rows_op.cc @@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel { "Input(X) of MergeSelectedRowsOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of MergeSelectedRowsOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(), + framework::proto::VarType::SELECTED_ROWS, + "Input X only should be SelectedRows."); + PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(), + framework::proto::VarType::SELECTED_ROWS, + "Output Y only should be SelectedRows."); + ctx->ShareDim("X", /*->*/ "Out"); } }; @@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker { R"DOC( MergeSelectedRows Operator. -MergeSelectedRows is used to merge the duplicated rows of the input. +MergeSelectedRows is used to merge the duplicated rows of the input. The +output's row has no duplicated, and it's order is incremental. + +Example: + Input: + X.rows is [0, 5, 5, 4, 19] + X.height is 20 + X.value is: + [[1, 1] + [2, 2] + [3, 3] + [4, 4] + [6, 6]] + + Output: + Out.row is [0, 4, 5, 19] + Out.height is 20 + Out.value is: + [[1, 1] + [4, 4] + [5, 5] + [6, 6]] )DOC"); } }; diff --git a/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py index 021b950b3b6..6cd02dad577 100644 --- a/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py +++ b/python/paddle/fluid/tests/unittests/test_get_tensor_from_selected_rows_op.py @@ -29,7 +29,7 @@ class TestGetTensorFromSelectedRows(unittest.TestCase): def check_with_place(self, place): scope = core.Scope() - x_rows = [0, 5, 5, 4, 20] + x_rows = [0, 5, 5, 4, 19] height = 20 row_numel = 2 diff --git a/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py index ce64da0478d..d2fa344b67a 100644 --- a/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py +++ b/python/paddle/fluid/tests/unittests/test_merge_selectedrows_op.py @@ -29,8 +29,8 @@ class TestMergeSelectedRows(unittest.TestCase): def check_with_place(self, place): scope = core.Scope() - x_rows = [0, 5, 5, 4, 20] - out_rows = [0, 4, 5, 20] + x_rows = [0, 5, 5, 4, 19] + out_rows = [0, 4, 5, 19] height = 20 row_numel = 2 -- GitLab