未验证 提交 91f80f79 编写于 作者: Q Qiao Longfei 提交者: GitHub

Topk share lod (#7373)

* add lod tensor ToAbsOffset test

* add share lod to topk op and softmax op
上级 cedd9805
......@@ -115,5 +115,21 @@ TEST(LoD, AppendLoD) {
EXPECT_EQ(origin, expected);
}
TEST(LoD, ToAbsOffset) {
LoD relative_lod;
relative_lod.push_back(std::vector<size_t>({0, 2}));
relative_lod.push_back(std::vector<size_t>({0, 1, 3}));
relative_lod.push_back(std::vector<size_t>({0, 2, 4, 5}));
LoD abs_lod = paddle::framework::ToAbsOffset(relative_lod);
LoD expected;
expected.push_back(std::vector<size_t>({0, 5}));
expected.push_back(std::vector<size_t>({0, 2, 5}));
expected.push_back(std::vector<size_t>({0, 2, 4, 5}));
EXPECT_EQ(abs_lod, expected);
}
} // namespace framework
} // namespace paddle
......@@ -31,6 +31,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
......
......@@ -41,6 +41,8 @@ class TopkOp : public framework::OperatorWithKernel {
dims[dims.size() - 1] = k;
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册