未验证 提交 4be2b0a3 编写于 作者: H Houjiang Chen 提交者: GitHub

Refactor oneflow.Size (#6645)

* Refactor oneflow.Size

* refine

* add pybind11 caster

* Support Shape cast

* refine

* fix size index

* Include the size header wherever a C++ Shape needs to be exported to Python.
上级 b4ec60f7
......@@ -17,6 +17,7 @@ limitations under the License.
#include <pybind11/functional.h>
#include <pybind11/stl.h>
#include <functional>
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/instructions_builder.h"
#include "oneflow/core/framework/tensor.h"
......
......@@ -16,6 +16,7 @@ limitations under the License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/parallel_conf_util.h"
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <vector>
#include <pybind11/pybind11.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/shape.h"
namespace py = pybind11;
namespace oneflow {
namespace {
// Static helpers backing the pybind11 binding that exposes `Shape` to Python as `Size`.
struct ShapeExportUtil final {
  // Builds a Shape from a Python tuple; every element is cast to int64_t
  // (pybind11 throws if an element is not convertible).
  static Maybe<Shape> MakeShape(const py::tuple& py_shape) {
    DimVector shape_dims;
    for (const auto& dim : py_shape) { shape_dims.emplace_back(dim.cast<int64_t>()); }
    return std::make_shared<Shape>(shape_dims);
  }

  // Constructor entry point: accepts an existing Shape, a tuple or a list.
  // Throws py::type_error for any other argument type.
  static std::shared_ptr<Shape> ApiMakeShape(const py::object& py_obj) {
    if (py::isinstance<Shape>(py_obj)) {
      return std::make_shared<Shape>(py_obj.cast<Shape>().dim_vec());
    } else if (py::isinstance<py::tuple>(py_obj)) {
      return MakeShape(py_obj.cast<py::tuple>()).GetPtrOrThrow();
    } else if (py::isinstance<py::list>(py_obj)) {
      return MakeShape(py::tuple(py_obj.cast<py::list>())).GetPtrOrThrow();
    } else {
      throw py::type_error("Input must be Tuple, List or oneflow.Size");
    }
  }

  // Integer indexing with Python-style negative indices.
  // Returns int64_t so that large dimensions are not truncated to int.
  // Throws py::index_error when idx is outside [-len, len).
  static int64_t GetItem(const Shape& shape, int idx) {
    const int len = static_cast<int>(shape.dim_vec().size());
    if (idx < -len || idx >= len) { throw py::index_error("oneflow.Size index out of range"); }
    if (idx < 0) { idx += len; }
    return shape.At(idx);
  }

  // Slice indexing: returns a new Shape holding the dimensions selected by `slice`.
  static std::shared_ptr<Shape> Slicing(const Shape& shape, const py::slice& slice) {
    size_t start, stop, step, slicelength;
    if (!slice.compute(shape.dim_vec().size(), &start, &stop, &step, &slicelength)) {
      throw py::error_already_set();
    }
    DimVector shape_dims;
    for (size_t i = 0; i < slicelength; ++i) {
      shape_dims.emplace_back(shape.dim_vec().at(start));
      start += step;
    }
    return std::make_shared<Shape>(shape_dims);
  }

  // Renders the shape as "oneflow.Size([d0, d1, ...])".
  static std::string ToString(const Shape& shape) {
    std::stringstream ss;
    ss << "oneflow.Size([";
    bool is_first = true;
    for (int64_t dim : shape.dim_vec()) {
      if (!is_first) { ss << ", "; }
      ss << dim;
      is_first = false;
    }
    ss << "])";
    return ss.str();
  }

  // Returns the position of the first dimension equal to `value` within [start, end).
  // Bounds follow Python's tuple.index: negative bounds count from the end and both bounds are
  // clamped to [0, len], so no out-of-range iterator is ever formed.
  // Throws std::invalid_argument when the value is not found in the range.
  static int GetIndexOrError(const Shape& shape, int64_t value, int start = 0,
                             int end = SHAPE_MAX_AXIS_SIZE) {
    const int len = static_cast<int>(shape.dim_vec().size());
    if (start < 0) {
      start += len;
      if (start < 0) { start = 0; }
    }
    if (end < 0) {
      end += len;
      if (end < 0) { end = 0; }
    }
    if (end > len) { end = len; }
    if (start < end) {
      const auto first = shape.dim_vec().begin() + start;
      const auto last = shape.dim_vec().begin() + end;
      const auto it = std::find(first, last, value);
      if (it != last) { return static_cast<int>(std::distance(shape.dim_vec().begin(), it)); }
    }
    throw std::invalid_argument("tuple.index(x): x not in tuple");
  }

  // Equality against another Shape or a tuple; any other Python type compares unequal.
  static bool IsEqual(const Shape& shape, const py::object& py_obj) {
    std::shared_ptr<Shape> other;
    if (py::isinstance<Shape>(py_obj)) {
      other = std::make_shared<Shape>(py_obj.cast<Shape>());
    } else if (py::isinstance<py::tuple>(py_obj)) {
      other = ApiMakeShape(py_obj.cast<py::tuple>());
    } else {
      return false;
    }
    if (shape.NumAxes() != other->NumAxes()) { return false; }
    for (int i = 0; i < shape.NumAxes(); i++) {
      if (shape.At(i) != other->At(i)) { return false; }
    }
    return true;
  }
};
} // namespace
// Exposes `Shape` to Python under the name "Size", wiring the tuple-like protocol methods to
// the ShapeExportUtil helpers. The two __getitem__ overloads are registered in the same order
// as before (integer index first, then slice).
ONEFLOW_API_PYBIND11_MODULE("", m) {
  py::class_<Shape, std::shared_ptr<Shape>> size_class(m, "Size");

  size_class.def(py::init(&ShapeExportUtil::ApiMakeShape));
  size_class.def("__str__", &ShapeExportUtil::ToString);
  size_class.def("__repr__", &ShapeExportUtil::ToString);
  size_class.def("__getitem__", &ShapeExportUtil::GetItem);
  size_class.def("__getitem__", &ShapeExportUtil::Slicing);

  // The iterator walks the Shape's own dim vector, so the Shape is kept alive for as long as
  // the returned iterator exists.
  size_class.def(
      "__iter__",
      [](const Shape& shape) {
        return py::make_iterator(shape.dim_vec().begin(), shape.dim_vec().end());
      },
      py::keep_alive<0, 1>());

  size_class.def("__len__", [](const Shape& shape) { return shape.NumAxes(); });
  size_class.def("__eq__", &ShapeExportUtil::IsEqual);
  size_class.def("numel", [](const Shape& shape) { return shape.elem_cnt(); });
  size_class.def("count", [](const Shape& shape, int64_t value) {
    return std::count(shape.dim_vec().begin(), shape.dim_vec().end(), value);
  });
  size_class.def("index", &ShapeExportUtil::GetIndexOrError, py::arg(), py::arg("start") = 0,
                 py::arg("end") = SHAPE_MAX_AXIS_SIZE);
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include "oneflow/api/python/functional/common.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/shape.h"
namespace py = pybind11;
namespace oneflow {
using one::functional::PyObjectPtr;
// tp_repr slot: renders the items of `self` as "oneflow.Size([d0, d1, ...])".
// Returns a new unicode object, or nullptr with a Python exception set when the length cannot
// be read or an item cannot be converted to a C long long (e.g. it overflows).
static PyObject* TensorSize_repr(TensorSize* self) {
  PyObject* tuple = (PyObject*)self;
  const Py_ssize_t size = PyTuple_Size(tuple);
  if (size < 0) { return nullptr; }
  std::stringstream ss;
  ss << "oneflow.Size([";
  for (Py_ssize_t i = 0; i < size; ++i) {
    const int64_t dim = PyLong_AsLongLong(PyTuple_GET_ITEM(tuple, i));
    // -1 is both a legal value and the error sentinel; only the pending exception tells them
    // apart. Returning a string while an exception is pending would be reported as a SystemError.
    if (dim == -1 && PyErr_Occurred()) { return nullptr; }
    if (i != 0) { ss << ", "; }
    ss << dim;
  }
  ss << "])";
  return PyUnicode_FromString(ss.str().c_str());
}
// tp_new slot: delegates construction to tuple's tp_new, then requires every item of the
// resulting tuple to pass PyLong_Check. On the first offending item a TypeError naming its
// position and type is raised and nullptr is returned.
static PyObject* TensorSize_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
  PyObjectPtr created(PyTuple_Type.tp_new(type, args, kwargs));
  if (!created.get()) { return created.release(); }

  for (int pos = 0; pos < PyTuple_Size(created.get()); ++pos) {
    PyObject* element = PyTuple_GET_ITEM(created.get(), pos);
    if (PyLong_Check(element)) { continue; }
    return PyErr_Format(PyExc_TypeError,
                        "oneflow.Size() takes an iterable of 'int', but item '%d' is '%s'", pos,
                        Py_TYPE(element)->tp_name);
  }
  return created.release();
}
// Length slot (used for both sq_length and mp_length): forwards to tuple's sq_length.
static Py_ssize_t TensorSize_length(TensorSize* self) {
  const PySequenceMethods* tuple_sequence = PyTuple_Type.tp_as_sequence;
  PyObject* as_object = (PyObject*)self;
  return tuple_sequence->sq_length(as_object);
}
// sq_concat slot: concatenates via tuple's sq_concat and, when the result is a tuple, re-wraps
// it as oneflow.Size through TensorSize_new (which re-validates the items).
// Returns nullptr with a Python exception set on failure.
static PyObject* TensorSize_concat(TensorSize* self, PyObject* other) {
  PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_concat((PyObject*)self, other));
  if (!result.get()) { return nullptr; }
  if (!PyTuple_Check(result.get())) { return result.release(); }
  PyObjectPtr args(PyTuple_Pack(1, result.get()));
  // PyTuple_Pack returns NULL on allocation failure; a NULL args tuple must not reach tp_new.
  if (!args.get()) { return nullptr; }
  return TensorSize_new(&TensorSize_Type, args.get(), nullptr);
}
// sq_repeat slot: repeats via tuple's sq_repeat and, when the result is a tuple, re-wraps it as
// oneflow.Size through TensorSize_new.
// Returns nullptr with a Python exception set on failure.
static PyObject* TensorSize_repeat(TensorSize* self, Py_ssize_t n) {
  PyObjectPtr result(PyTuple_Type.tp_as_sequence->sq_repeat((PyObject*)self, n));
  if (!result.get()) { return nullptr; }
  if (!PyTuple_Check(result.get())) { return result.release(); }
  PyObjectPtr args(PyTuple_Pack(1, result.get()));
  // PyTuple_Pack returns NULL on allocation failure; a NULL args tuple must not reach tp_new.
  if (!args.get()) { return nullptr; }
  return TensorSize_new(&TensorSize_Type, args.get(), nullptr);
}
// sq_item slot: forwards integer indexing to tuple's sq_item.
static PyObject* TensorSize_item(TensorSize* self, Py_ssize_t i) {
  const PySequenceMethods* tuple_sequence = PyTuple_Type.tp_as_sequence;
  PyObject* as_object = (PyObject*)self;
  return tuple_sequence->sq_item(as_object, i);
}
// sq_contains slot: forwards the membership test (`el in size`) to tuple's sq_contains.
static int TensorSize_contains(TensorSize* self, PyObject* el) {
  const PySequenceMethods* tuple_sequence = PyTuple_Type.tp_as_sequence;
  PyObject* as_object = (PyObject*)self;
  return tuple_sequence->sq_contains(as_object, el);
}
// Sequence-protocol slot table for oneflow.Size. Entries are positional, so their order must
// match the field order of PySequenceMethods. Length, item access and containment forward to
// the tuple implementation; concat and repeat re-wrap tuple results as oneflow.Size.
// Fields after sq_contains are not listed and are therefore zero-initialized.
static PySequenceMethods TensorSize_as_sequence = {
(lenfunc)TensorSize_length, /* sq_length */
(binaryfunc)TensorSize_concat, /* sq_concat */
(ssizeargfunc)TensorSize_repeat, /* sq_repeat */
(ssizeargfunc)TensorSize_item, /* sq_item */
0, /* sq_slice */
0, /* sq_ass_item */
0, /* sq_ass_slice */
(objobjproc)TensorSize_contains, /* sq_contains */
};
// mp_subscript slot: subscripts via tuple's mp_subscript. A tuple result (e.g. from slicing) is
// re-wrapped as oneflow.Size through TensorSize_new; any other result is returned unchanged.
// Returns nullptr with a Python exception set on failure.
static PyObject* TensorSize_subscript(TensorSize* self, PyObject* item) {
  PyObjectPtr result(PyTuple_Type.tp_as_mapping->mp_subscript((PyObject*)self, item));
  if (!result.get()) { return nullptr; }
  if (!PyTuple_Check(result.get())) { return result.release(); }
  PyObjectPtr args(PyTuple_Pack(1, result.get()));
  // PyTuple_Pack returns NULL on allocation failure; a NULL args tuple must not reach tp_new.
  if (!args.get()) { return nullptr; }
  return TensorSize_new(&TensorSize_Type, args.get(), nullptr);
}
// Mapping-protocol slot table for oneflow.Size. Entries are positional: length reuses the
// sequence length function, subscripting goes through TensorSize_subscript, and item
// assignment is left unset (0).
static PyMappingMethods TensorSize_as_mapping = {
(lenfunc)TensorSize_length, /* mp_length */
(binaryfunc)TensorSize_subscript, /* mp_subscript */
0, /* mp_ass_subscript */
};
// Implements `Size.numel()`: returns the product of all items as a Python int (1 for an empty
// size). `args` is unused because the method is registered as METH_NOARGS.
// Returns nullptr with a Python exception set when the length cannot be read or an item cannot
// be converted to a C long long.
static PyObject* TensorSize_numel(PyObject* self, PyObject* args) {
  const Py_ssize_t ndim = PyTuple_Size(self);
  if (ndim < 0) { return nullptr; }
  int64_t numel = 1;
  for (Py_ssize_t i = 0; i < ndim; ++i) {
    const int64_t dim = PyLong_AsLongLong(PyTuple_GET_ITEM(self, i));
    // -1 doubles as the error sentinel; check the pending exception so that a conversion
    // failure is not silently folded into the product.
    if (dim == -1 && PyErr_Occurred()) { return nullptr; }
    numel *= dim;
  }
  return PyLong_FromLongLong(numel);
}
// Method table for oneflow.Size: a single `numel` method taking no arguments (METH_NOARGS),
// followed by the {NULL} entry that terminates the table.
static PyMethodDef TensorSize_methods[] = {
{"numel", (PyCFunction)TensorSize_numel, METH_NOARGS, NULL}, {NULL}};
// Static type object for `oneflow.Size`. It derives from tuple (tp_base = &PyTuple_Type) and
// overrides only repr, the sequence and mapping slot tables, the method table and tp_new.
// Entries are positional, so their order must match the field order of PyTypeObject.
// NOTE(review): the slots left NULL/0 here (dealloc, hash, richcompare, iter, ...) are expected
// to be filled in from tuple when PyType_Ready runs at module init — confirm.
PyTypeObject TensorSize_Type = {
PyVarObject_HEAD_INIT(NULL, 0) "oneflow.Size", /* tp_name */
sizeof(TensorSize), /* tp_basicsize */
0, /* tp_itemsize */
NULL, /* tp_dealloc */
0, /* tp_vectorcall_offset */
NULL, /* tp_getattr */
NULL, /* tp_setattr */
NULL, /* tp_reserved */
(reprfunc)TensorSize_repr, /* tp_repr */
NULL, /* tp_as_number */
&TensorSize_as_sequence, /* tp_as_sequence */
&TensorSize_as_mapping, /* tp_as_mapping */
NULL, /* tp_hash */
NULL, /* tp_call */
NULL, /* tp_str */
NULL, /* tp_getattro */
NULL, /* tp_setattro */
NULL, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
NULL, /* tp_doc */
NULL, /* tp_traverse */
NULL, /* tp_clear */
NULL, /* tp_richcompare */
0, /* tp_weaklistoffset */
NULL, /* tp_iter */
NULL, /* tp_iternext */
TensorSize_methods, /* tp_methods */
NULL, /* tp_members */
NULL, /* tp_getset */
&PyTuple_Type, /* tp_base */
NULL, /* tp_dict */
NULL, /* tp_descr_get */
NULL, /* tp_descr_set */
0, /* tp_dictoffset */
NULL, /* tp_init */
NULL, /* tp_alloc */
TensorSize_new, /* tp_new */
NULL, /* tp_free */
};
// Returns 1 when `p` is non-null and its type is exactly TensorSize_Type, otherwise 0.
// Subclasses of oneflow.Size do not match.
int TensorSize_Check(PyObject* p) {
  if (p == nullptr) { return 0; }
  return Py_TYPE(p) == &TensorSize_Type ? 1 : 0;
}
// Allocates a oneflow.Size object with room for `len` items through the type's tp_alloc.
// The items are not populated here; the caller is expected to fill them.
PyObject* TensorSize_New(Py_ssize_t len) {
  PyTypeObject* const size_type = &TensorSize_Type;
  return size_type->tp_alloc(size_type, len);
}
// Creates a oneflow.Size whose items are the dimensions of `size`.
// Returns a new reference, or nullptr with a Python exception set when the allocation of the
// object or of any of its int items fails.
PyObject* TensorSize_NewFromShape(const Shape& size) {
  const int64_t num_axes = size.NumAxes();
  PyObjectPtr self(TensorSize_New(num_axes));
  if (!self.get()) { return nullptr; }
  for (int64_t i = 0; i < num_axes; ++i) {
    PyObject* dim = PyLong_FromLongLong(size.At(i));
    // Never hand out a tuple with NULL slots: bail out on the first failed item.
    // NOTE(review): relies on PyObjectPtr dropping the partially filled object on scope exit,
    // as the early error return in TensorSize_new also does — confirm.
    if (!dim) { return nullptr; }
    PyTuple_SET_ITEM(self.get(), i, dim);
  }
  return self.release();
}
// Converts a oneflow.Size object into a Shape by reading each item as a long long.
// When `self` is not exactly a oneflow.Size, a TypeError is set and an empty Shape is returned.
Shape TensorSize_AsShape(PyObject* self) {
  if (TensorSize_Check(self)) {
    const int ndim = TensorSize_length((TensorSize*)self);
    DimVector dims(ndim);
    for (int axis = 0; axis < ndim; ++axis) {
      PyObject* entry = PyTuple_GET_ITEM(self, axis);
      dims[axis] = PyLong_AsLongLong(entry);
    }
    return Shape(std::move(dims));
  }
  PyErr_Format(PyExc_TypeError, "can only convert TensorSize(not \"%s\") to Shape",
               Py_TYPE(self)->tp_name);
  return Shape();
}
// Finalizes the oneflow.Size type object and publishes it on the module as "Size".
// On failure the function returns early, leaving the Python exception set.
ONEFLOW_API_PYBIND11_MODULE("", m) {
  if (PyType_Ready(&TensorSize_Type) < 0) { return; }
  Py_INCREF(&TensorSize_Type);
  if (PyModule_AddObject(m.ptr(), "Size", (PyObject*)&TensorSize_Type) < 0) {
    // PyModule_AddObject steals the reference only on success, so give back the one taken
    // above when it fails.
    Py_DECREF(&TensorSize_Type);
    return;
  }
}
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
#define ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
#include <type_traits>
#include <Python.h>
#include <pybind11/pybind11.h>
#include "oneflow/core/common/shape.h"
namespace oneflow {
// Instance layout of oneflow.Size: exactly a tuple object, with no additional fields.
struct TensorSize {
  PyTupleObject ob_base;
};
// Type object backing Python's `oneflow.Size` (a tuple subclass).
extern PyTypeObject TensorSize_Type;
// Returns non-zero when `p` is non-null and its type is exactly TensorSize_Type.
int TensorSize_Check(PyObject* p);
// Allocates a oneflow.Size with room for `len` items; the items are left for the caller to set.
PyObject* TensorSize_New(Py_ssize_t len);
// Creates a oneflow.Size holding the dimensions of `size`.
PyObject* TensorSize_NewFromShape(const Shape& size);
// Converts a oneflow.Size into a Shape; sets a TypeError and returns an empty Shape when
// `self` is not a oneflow.Size.
Shape TensorSize_AsShape(PyObject* self);
} // namespace oneflow
PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
// pybind11 object wrapper for Python `oneflow.Size` instances, with a tuple-like accessor and
// iterator interface built on pybind11's tuple accessor/iterator types.
class shape : public object {
public:
// Declares the usual pybind11 object boilerplate: the type check is oneflow::TensorSize_Check,
// and objects of other types are converted through raw_shape.
PYBIND11_OBJECT_CVT(shape, object, oneflow::TensorSize_Check, raw_shape)
// Allocates a new oneflow.Size with `size` slots; fails hard if the allocation returns null.
explicit shape(size_t size = 0) : object(oneflow::TensorSize_New((ssize_t)size), stolen_t{}) {
if (!m_ptr) pybind11_fail("Could not allocate tensor size object!");
}
// Number of items, as reported by PyTuple_Size.
size_t size() const { return (size_t)PyTuple_Size(m_ptr); }
bool empty() const { return size() == 0; }
// Positional access by index, and generic item access by an arbitrary Python key.
detail::tuple_accessor operator[](size_t index) const { return {*this, index}; }
detail::item_accessor operator[](handle h) const { return object::operator[](h); }
// Iteration over the items from position 0 to the tuple size.
detail::tuple_iterator begin() const { return {*this, 0}; }
detail::tuple_iterator end() const { return {*this, PyTuple_GET_SIZE(m_ptr)}; }
private:
// Conversion hook: an existing oneflow.Size is returned with an extra reference; anything
// else is passed as the single argument to a call of the oneflow.Size type object.
static PyObject* raw_shape(PyObject* op) {
if (oneflow::TensorSize_Check(op)) return handle(op).inc_ref().ptr();
return PyObject_CallFunctionObjArgs((PyObject*)&oneflow::TensorSize_Type, op, NULL);
}
};
PYBIND11_NAMESPACE_BEGIN(detail)
template<typename T>
struct shape_type_caster {
public:
bool load(handle src, bool convert) {
value_ = nullptr;
if (src && src.is_none()) { return true; }
if (!oneflow::TensorSize_Check(src.ptr())) { return false; }
value_ = std::make_shared<T>(oneflow::TensorSize_AsShape(src.ptr()));
return true;
}
template<typename U>
static handle cast(U&& src, return_value_policy /*policy*/, handle /*parent*/) {
return cast_impl(std::forward<U>(src));
}
template<typename U>
static handle cast(U* src, return_value_policy policy, handle parent) {
if (!src) { return none().release(); }
return cast(*src, policy, parent);
}
operator T*() { return value_.get(); }
operator T&() { return *value_; }
operator T&&() && { return std::move(*value_); }
operator std::shared_ptr<T>*() { return &value_; }
operator std::shared_ptr<T>&() { return value_; }
operator std::shared_ptr<T>&&() && { return std::move(value_); }
static constexpr auto name = _("shape");
template<typename U>
using cast_op_type = pybind11::detail::cast_op_type<std::shared_ptr<T>>;
private:
static handle cast_impl(const oneflow::Shape& src) {
return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(src)).release();
}
static handle cast_impl(const std::shared_ptr<const oneflow::Shape>& src) {
return reinterpret_steal<shape>(oneflow::TensorSize_NewFromShape(*src)).release();
}
protected:
std::shared_ptr<T> value_;
};
// Route oneflow::Shape, by value and by shared_ptr (mutable or const), through
// shape_type_caster so that these types cross the Python boundary as oneflow.Size objects.
template<>
struct type_caster<oneflow::Shape> : shape_type_caster<oneflow::Shape> {};

template<>
struct type_caster<std::shared_ptr<oneflow::Shape>> : shape_type_caster<oneflow::Shape> {};

template<>
struct type_caster<std::shared_ptr<const oneflow::Shape>>
    : shape_type_caster<const oneflow::Shape> {};
PYBIND11_NAMESPACE_END(detail)
PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
#endif // ONEFLOW_API_PYTHON_FRAMEWORK_SIZE_H_
......@@ -19,6 +19,7 @@ limitations under the License.
#include <pybind11/numpy.h>
#include "oneflow/api/python/framework/throw.h"
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/api/python/ofblob/ofblob.e.h"
#include "oneflow/api/python/utils/tensor_utils.h"
......
......@@ -109,16 +109,6 @@ Maybe<Symbol<DType>> PyUnpackDType(PyObject* obj) {
return *py::cast<Symbol<DType>*>(handle);
}
// Shape
// Returns true when `obj` is an instance of the Python type bound to Shape.
bool PyShapeCheck(PyObject* obj) {
  return py::isinstance<Shape>(py::reinterpret_borrow<py::object>(obj));
}
// Casts `obj` to a shared Shape through pybind11.
// The borrowed handle is now the object actually cast (it was previously created but unused),
// matching how the neighbouring unpack helper casts its handle.
Maybe<Shape> PyUnpackShape(PyObject* obj) {
  auto handle = py::reinterpret_borrow<py::object>(obj);
  return py::cast<std::shared_ptr<Shape>>(handle);
}
// Generator
bool PyGeneratorCheck(PyObject* obj) {
auto handle = py::reinterpret_borrow<py::object>(obj);
......
......@@ -127,10 +127,6 @@ Maybe<TensorTuple> PyUnpackTensorTuple(PyObject* obj);
bool PyDTypeCheck(PyObject* obj);
Maybe<Symbol<DType>> PyUnpackDType(PyObject* obj);
// Shape
bool PyShapeCheck(PyObject* obj);
Maybe<Shape> PyUnpackShape(PyObject* obj);
// Generator
bool PyGeneratorCheck(PyObject* obj);
Maybe<Generator> PyUnpackGenerator(PyObject* obj);
......
......@@ -107,7 +107,6 @@ Maybe<Symbol<DType>> PythonArg::ObjectAs<Symbol<DType>>() const {
// Converts the wrapped Python object into a Shape. An object that passes PyShapeCheck is
// unpacked directly; otherwise it is unpacked as a sequence of int64 values and its elements
// become the dimensions of a new Shape. A failure to unpack the sequence propagates via JUST.
template<>
Maybe<Shape> PythonArg::ObjectAs<Shape>() const {
if (PyShapeCheck(object_)) { return PyUnpackShape(object_); }
const auto& shape = JUST(PyUnpackLongSequence<int64_t>(object_));
return std::make_shared<Shape>(DimVector(shape->begin(), shape->end()));
}
......@@ -193,7 +192,7 @@ Maybe<bool> PythonArg::TypeCheck(ValueType type) const {
case kTENSOR_REF: return PyTensorCheck(object_);
case kTENSOR_TUPLE: return PyTensorTupleCheck(object_) || PyTensorSequenceCheck(object_);
case kDTYPE: return PyDTypeCheck(object_);
case kSHAPE: return PyShapeCheck(object_) || PyLongSequenceCheck(object_);
case kSHAPE: return PyLongSequenceCheck(object_);
case kGENERATOR:
case kGENERATOR_REF: return PyGeneratorCheck(object_);
case kTENSOR_INDEX: return PyTensorIndexCheck(object_);
......
......@@ -16,6 +16,7 @@ limitations under the License.
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/operators.h>
#include "oneflow/api/python/framework/size.h"
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/control/global_process_ctx.h"
#include "oneflow/core/common/symbol.h"
......
......@@ -100,17 +100,17 @@ class TestSize(flow.unittest.TestCase):
def test_index(test_case):
size = flow.Size((2, 3, 2, 4, 4))
test_case.assertEqual(size.index(2), 0)
test_case.assertEqual(size.index(2, start=0), 0)
test_case.assertEqual(size.index(2, start=0, end=20), 0)
test_case.assertEqual(size.index(2, start=1, end=20), 2)
test_case.assertEqual(size.index(2, 0), 0)
test_case.assertEqual(size.index(2, 0, 20), 0)
test_case.assertEqual(size.index(2, 1, 20), 2)
test_case.assertEqual(size.index(4), 3)
test_case.assertEqual(size.index(4, start=4), 4)
test_case.assertEqual(size.index(4, 4), 4)
with test_case.assertRaises(ValueError):
size.index(4, start=0, end=3)
size.index(4, 0, 3)
with test_case.assertRaises(ValueError):
size.index(5)
with test_case.assertRaises(ValueError):
size.index(2, start=3)
size.index(2, 3)
def test_slicing(test_case):
size = flow.Size([2, 3, 4, 5])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册