未验证 提交 a015a8a3 编写于 作者: C chengduo 提交者: GitHub

Refine merge_selected_rows Doc (#14748)

* add doc for MergeSelectedRows
test=develop

* checkout selected_rows
test=develop
上级 3babc801
......@@ -26,6 +26,13 @@ class MergeSelectedRowsOp : public framework::OperatorWithKernel {
"Input(X) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of MergeSelectedRowsOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("X").front(),
framework::proto::VarType::SELECTED_ROWS,
"Input X only should be SelectedRows.");
PADDLE_ENFORCE_EQ(ctx->GetOutputsVarType("Out").front(),
framework::proto::VarType::SELECTED_ROWS,
"Output Y only should be SelectedRows.");
ctx->ShareDim("X", /*->*/ "Out");
}
};
......@@ -43,7 +50,28 @@ class MergeSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
R"DOC(
MergeSelectedRows Operator.
MergeSelectedRows is used to merge the duplicated rows of the input.
MergeSelectedRows is used to merge the duplicated rows of the input. The
output's row has no duplicated, and it's order is incremental.
Example:
Input:
X.rows is [0, 5, 5, 4, 19]
X.height is 20
X.value is:
[[1, 1]
[2, 2]
[3, 3]
[4, 4]
[6, 6]]
Output:
Out.row is [0, 4, 5, 19]
Out.height is 20
Out.value is:
[[1, 1]
[4, 4]
[5, 5]
[6, 6]]
)DOC");
}
};
......
......@@ -29,7 +29,7 @@ class TestGetTensorFromSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 5, 5, 4, 20]
x_rows = [0, 5, 5, 4, 19]
height = 20
row_numel = 2
......
......@@ -29,8 +29,8 @@ class TestMergeSelectedRows(unittest.TestCase):
def check_with_place(self, place):
scope = core.Scope()
x_rows = [0, 5, 5, 4, 20]
out_rows = [0, 4, 5, 20]
x_rows = [0, 5, 5, 4, 19]
out_rows = [0, 4, 5, 19]
height = 20
row_numel = 2
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册