未验证 提交 ea2a38f4 编写于 作者: Y Yinggang Wang 提交者: GitHub

Feat PySize support slicing (#4437)

* feat(PySize): support slicing

* style(PySize): slice.compute throw error_already_set
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 81c982d6
......@@ -42,6 +42,19 @@ struct ShapeExportUtil final {
}
}
static std::shared_ptr<Shape> Slicing(const Shape& shape, const py::slice& slice) {
size_t start, stop, step, slicelength;
if (!slice.compute(shape.dim_vec().size(), &start, &stop, &step, &slicelength)) {
throw py::error_already_set();
}
DimVector shape_dims;
for (size_t i = 0; i < slicelength; ++i) {
shape_dims.emplace_back(shape.dim_vec().at(start));
start += step;
}
return std::make_shared<Shape>(shape_dims);
}
static std::string ToString(const Shape& shape) {
std::stringstream ss;
int32_t idx = 0;
......@@ -90,6 +103,7 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
.def("__str__", &ShapeExportUtil::ToString)
.def("__repr__", &ShapeExportUtil::ToString)
.def("__getitem__", [](const Shape& shape, int idx) { return shape.At(idx); })
.def("__getitem__", &ShapeExportUtil::Slicing)
.def(
"__iter__",
[](const Shape& shape) {
......
......@@ -110,6 +110,14 @@ class TestSize(flow.unittest.TestCase):
with test_case.assertRaises(ValueError):
size.index(2, start=3)
def test_slicing(test_case):
size = flow.Size([2, 3, 4, 5])
test_case.assertTrue(size[1:3] == flow.Size((3, 4)))
test_case.assertTrue(size[1:] == flow.Size((3, 4, 5)))
test_case.assertTrue(size[:2] == (2, 3))
test_case.assertTrue(size[-3:] == flow.Size((3, 4, 5)))
test_case.assertTrue(size[-3:-1] == flow.Size((3, 4)))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册