From ea2a38f451becb57e68288ea7626a06b9f0ff4c9 Mon Sep 17 00:00:00 2001 From: Yinggang Wang Date: Thu, 18 Mar 2021 13:23:46 +0800 Subject: [PATCH] Feat PySize support slicing (#4437) * feat(PySize): support slicing * style(PySize): slice.compute throw error_already_set Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/api/python/framework/shape.cpp | 14 ++++++++++++++ oneflow/python/test/ops/test_size.py | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/oneflow/api/python/framework/shape.cpp b/oneflow/api/python/framework/shape.cpp index 63f1828809..8cc2637f17 100644 --- a/oneflow/api/python/framework/shape.cpp +++ b/oneflow/api/python/framework/shape.cpp @@ -42,6 +42,19 @@ struct ShapeExportUtil final { } } + static std::shared_ptr Slicing(const Shape& shape, const py::slice& slice) { + size_t start, stop, step, slicelength; + if (!slice.compute(shape.dim_vec().size(), &start, &stop, &step, &slicelength)) { + throw py::error_already_set(); + } + DimVector shape_dims; + for (size_t i = 0; i < slicelength; ++i) { + shape_dims.emplace_back(shape.dim_vec().at(start)); + start += step; + } + return std::make_shared(shape_dims); + } + static std::string ToString(const Shape& shape) { std::stringstream ss; int32_t idx = 0; @@ -90,6 +103,7 @@ ONEFLOW_API_PYBIND11_MODULE("", m) { .def("__str__", &ShapeExportUtil::ToString) .def("__repr__", &ShapeExportUtil::ToString) .def("__getitem__", [](const Shape& shape, int idx) { return shape.At(idx); }) + .def("__getitem__", &ShapeExportUtil::Slicing) .def( "__iter__", [](const Shape& shape) { diff --git a/oneflow/python/test/ops/test_size.py b/oneflow/python/test/ops/test_size.py index 30c952ee46..ef7689e82f 100644 --- a/oneflow/python/test/ops/test_size.py +++ b/oneflow/python/test/ops/test_size.py @@ -110,6 +110,14 @@ class TestSize(flow.unittest.TestCase): with test_case.assertRaises(ValueError): size.index(2, start=3) + def test_slicing(test_case): + size = flow.Size([2, 3, 4, 5]) + test_case.assertTrue(size[1:3] == flow.Size((3, 4))) + test_case.assertTrue(size[1:] == flow.Size((3, 4, 5))) + test_case.assertTrue(size[:2] == (2, 3)) + test_case.assertTrue(size[-3:] == flow.Size((3, 4, 5))) + test_case.assertTrue(size[-3:-1] == flow.Size((3, 4))) + if __name__ == "__main__": unittest.main() -- GitLab