提交 e400b7ff 编写于 作者: M Megvii Engine Team

perf(imperative): enable memory forwarding for imperative

GitOrigin-RevId: 7c1993979c051a1b01e168eefaa03a0386ddc7bc
上级 84d1a440
......@@ -285,7 +285,8 @@ struct TensorLayout : public TensorShape {
* stride
*/
void add_axis_cont_inplace(size_t axis) {
add_axis_inplace(axis, 1, stride[axis] * shape[axis]);
ptrdiff_t stride_ = axis < ndim ? stride[axis] * shape[axis] : 1;
add_axis_inplace(axis, 1, stride_);
}
/*!
......
......@@ -382,7 +382,7 @@ bool TensorLayout::eq_layout(const TensorLayout& rhs) const {
MEGDNN_STATIC_ASSERT(MAX_NDIM == 7, "please update the code");
auto ax = [](size_t shape0, size_t shape1, ptrdiff_t stride0, ptrdiff_t stride1) {
return (shape0 == shape1) & ((shape0 == 1) | (stride0 == stride1));
return (shape0 == shape1) & ((shape0 <= 1) | (stride0 == stride1));
};
if (ndim == rhs.ndim) {
size_t eq = 0;
......
......@@ -13,7 +13,8 @@
using namespace megdnn;
const std::shared_ptr<Handle>& megdnn::inplace_cpu_handle(int debug_level) {
MGE_WIN_DECLSPEC_FUC const std::shared_ptr<Handle>& megdnn::inplace_cpu_handle(
int debug_level) {
auto make = [](int deb_level) {
megcoreDeviceHandle_t dev_handle;
megcoreCreateDeviceHandle(&dev_handle, megcorePlatformCPU);
......
......@@ -32,6 +32,7 @@
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./tensor_utils.h"
#include "./transformation.h"
#include <object.h>
......@@ -549,557 +550,6 @@ CompNode _get_device(PyObject* const* args, size_t nargs) {
return cn;
}
bool is_scalar(PyObject* tensor) {
if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
auto var = py::handle(tensor).cast<PySymbolVar*>();
return var->is_scalar;
}
auto* tw = TensorWrapper::try_cast(tensor);
if (tw) {
return tw->m_tensor->is_scalar();
}
return PyArray_CheckAnyScalar(tensor);
}
bool is_bool_list(PyObject* arg) {
if (!PyList_Check(arg)) {
return false;
}
size_t sz = PyList_Size(arg);
if (!sz) {
return false;
}
for (size_t i = 0; i < sz; ++i) {
PyObject* handle = PyList_GetItem(arg, i);
if (!PyBool_Check(handle)) {
return false;
}
}
return true;
}
bool is_bool_dtype(PyObject* args) {
if (!PyObject_HasAttrString(args, "dtype"))
return false;
PyObject* dobj = PyObject_GetAttrString(args, "dtype");
PyArray_Descr* dtype;
PyArray_DescrConverter(dobj, &dtype);
bool ret = (dtype->kind == 'b');
Py_XDECREF(dtype);
Py_XDECREF(dobj);
return ret;
}
py::object _Const(
py::handle value, py::handle dtype, py::handle device, py::handle ref) {
py::object val = py::reinterpret_borrow<py::object>(value);
if (PyArray_Check(value.ptr())) {
py::tuple strides =
py::reinterpret_borrow<py::tuple>(getattr(value, "strides"));
bool need_squeeze = false;
for (size_t i = 0; i < strides.size(); ++i) {
if (strides[i].cast<ptrdiff_t>() == 0) {
need_squeeze = true;
}
}
if (need_squeeze) {
val = py::reinterpret_borrow<py::array>(value);
val = val.attr("squeeze")();
val = val.attr("reshape")(val.attr("shape"));
}
}
if (py::isinstance<PySymbolVar>(ref)) {
auto ref_var = ref.cast<PySymbolVar*>();
auto* graph = ref_var->m_node->owner_graph();
auto cn = device.cast<CompNode>();
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
auto typeobj = ref.get_type();
return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
py::tuple _make_shape_tuple(py::handle shape) {
py::list orig;
py::list ret(0);
auto solve_one = [&](py::handle val) {
if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
py::object np = getattr(val, "numpy")();
PyArrayObject* arr = (PyArrayObject*)np.ptr();
PyObject* maybe_list = PyArray_ToList(arr);
if (PyList_Check(maybe_list)) {
py::list may = py::reinterpret_steal<py::list>(maybe_list);
for (size_t i = 0; i < may.size(); ++i) {
ret.append(may[i]);
}
} else {
mgb_assert(PyLong_Check(maybe_list));
ret.append(PyLong_AsLong(maybe_list));
Py_XDECREF(maybe_list);
}
} else if (PyArray_Check(val.ptr())) {
ret.append(PyArray_PyIntAsInt(val.ptr()));
} else {
ret.append(PyLong_AsLong(val.ptr()));
}
};
if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) {
orig = py::reinterpret_steal<py::list>(
PyArray_ToList((PyArrayObject*)shape.ptr()));
for (size_t i = 0; i < orig.size(); ++i) {
solve_one(orig[i]);
}
} else if (PyList_Check(shape.ptr())) {
orig = py::reinterpret_borrow<py::list>(shape);
for (size_t i = 0; i < orig.size(); ++i) {
solve_one(orig[i]);
}
} else if (PyTuple_Check(shape.ptr())) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
for (size_t i = 0; i < tup.size(); ++i) {
solve_one(tup[i]);
}
} else {
solve_one(shape);
}
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
}
py::object _get_index(py::object tensor, py::object src) {
if (!TensorWrapper::try_cast(tensor.ptr()) &&
!py::isinstance<PySymbolVar>(tensor)) {
auto get_const = [&](mgb::DType dtype) -> py::object {
return _Const(tensor, py::cast(dtype), src.attr("device"), src);
};
if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
tensor = get_const(dtype::Bool());
} else {
tensor = get_const(dtype::Int32());
}
if (!is_bool_dtype(tensor.ptr())) {
return tensor;
}
} else {
if (!is_bool_dtype(tensor.ptr())) {
return tensor;
}
}
static std::shared_ptr<OpDef> op = CondTake::make();
std::vector<PyObject*> p;
p.resize(3);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
p[2] = tensor.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[1];
}
py::tuple _try_cond_take(py::handle tensor, py::handle index) {
if (!hasattr(index, "dtype") || !hasattr(index, "shape")) {
return py::tuple();
}
if (!is_bool_dtype(index.ptr()) ||
_make_shape_tuple(getattr(index, "shape"))
.not_equal(_make_shape_tuple(getattr(tensor, "shape")))) {
return py::tuple();
}
py::object iobj;
if (PyArray_Check(index.ptr())) {
iobj =
_Const(index, py::cast((mgb::DType)dtype::Bool()),
getattr(tensor, "device"), tensor);
} else {
iobj = py::reinterpret_borrow<py::object>(index);
}
static std::shared_ptr<OpDef> op = CondTake::make();
std::vector<PyObject*> p;
p.resize(3);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
p[2] = iobj.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret;
}
py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
size_t tuple_size = tuple_val.size();
size_t ndim_sum = 0, cur_sum = 0;
int pos = -1;
bool has_unknown_ndim_bool_index = false;
for (size_t i = 0; i < tuple_size; ++i) {
py::object handle = tuple_val[i];
if (handle.ptr() == Py_Ellipsis) {
pos = static_cast<int>(i);
for (size_t j = 0; j < i; ++j) {
py::object t = tuple_val[j];
if (t.ptr() == Py_Ellipsis) {
throw py::index_error("only one ellipsis is allowed.");
}
}
} else {
size_t ndim_incr = 1;
if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
hasattr(handle, "ndim")) {
py::object ndim = getattr(handle, "ndim");
if (PyLong_Check(ndim.ptr())) {
ndim_incr = PyLong_AsLong(ndim.ptr());
} else {
has_unknown_ndim_bool_index = true;
}
}
cur_sum += ndim_incr;
}
}
if (pos == -1) {
return tuple_val;
} else {
if (has_unknown_ndim_bool_index) {
throw py::index_error(
"does not support bool index with unknown shape when using "
"Ellipsis.");
}
try {
ndim_sum = getattr(tensor, "ndim").cast<size_t>();
} catch (py::error_already_set& err) {
throw py::index_error(
"does not support Ellipsis when tensor's ndim is unknown.");
}
py::tuple ret(ndim_sum - cur_sum + tuple_size - 1);
size_t idx = 0;
for (size_t i = 0; i < tuple_size; ++i) {
if (i == pos) {
for (size_t j = cur_sum; j < ndim_sum; ++j) {
ret[idx++] = PySlice_New(NULL, NULL, NULL);
}
} else {
ret[idx++] = tuple_val[i];
}
}
return ret;
}
}
py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
py::list new_tuple_val(0);
size_t offset = 0;
size_t tdim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::handle k = tuple_val[i];
if (is_bool_dtype(k.ptr())) {
size_t ndim = getattr(k, "ndim").cast<size_t>();
if (ndim > 1) {
py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape")));
for (size_t j = 0; j < ndim; ++j) {
if (cur_shape[tdim + j - offset].cast<size_t>() !=
ishape[j].cast<size_t>()) {
std::string msg =
"boolean index did not match tensor along dimension " +
std::to_string(tdim + j) + "; dimension is " +
std::to_string(
cur_shape[tdim + j - offset].cast<size_t>()) +
" but corresponding boolean dimension is " +
std::to_string(ishape[j].cast<size_t>());
throw py::index_error(msg.c_str());
}
}
py::object new_k = getattr(k, "reshape")(-1);
py::object kshape = getattr(new_k, "shape");
py::list new_shape(0);
PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr);
bool is_sym = (sym == Py_True);
Py_XDECREF(sym);
if (is_sym) {
py::object tshape = getattr(tensor, "shape");
for (size_t j = 0; j < i; ++j) {
new_shape.append(tshape[py::int_(j)]);
}
new_shape.append(kshape[py::int_(0)]);
for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
new_shape.append(cur_shape[j]);
}
py::tuple args = py::make_tuple(new_shape);
PyObject* shape_tensor =
PyObject_CallObject(cpp_astensor1d, args.ptr());
py::object reshape_func = getattr(tensor, "reshape");
Py_INCREF(shape_tensor);
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, shape_tensor);
PyObject* new_tensor =
PyObject_CallObject(reshape_func.ptr(), Args);
Py_XDECREF(Args);
tensor = py::reinterpret_steal<py::object>(new_tensor);
cur_shape = _make_shape_tuple(py::handle(shape_tensor));
Py_XDECREF(shape_tensor);
} else {
for (size_t j = 0; j < i; ++j) {
new_shape.append(cur_shape[j]);
}
new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
new_shape.append(cur_shape[j]);
}
cur_shape = new_shape;
tensor = getattr(tensor, "reshape")(cur_shape);
}
offset++;
tdim += ndim;
}
new_tuple_val.append(k);
} else {
new_tuple_val.append(k);
tdim++;
}
}
return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
}
py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
py::tuple tuple_val;
if (py::isinstance<py::tuple>(idx_hdl)) {
tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
} else {
tuple_val = py::make_tuple(idx_hdl);
}
bool use_subtensor = true;
bool need_remove_ellipsis = false;
bool need_expand_bool_dim = false;
size_t idx_ndim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object k = tuple_val[i];
if (k.ptr() == Py_None) {
throw py::index_error("newaxis is not allowed here");
} else if (k.ptr() == Py_Ellipsis) {
need_remove_ellipsis = true;
} else {
if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
size_t ndim = getattr(k, "ndim").cast<size_t>();
idx_ndim += ndim;
if (ndim > 1) {
need_expand_bool_dim = true;
}
} else {
idx_ndim++;
}
}
}
try {
size_t inp_ndim = getattr(inp, "ndim").cast<size_t>();
if (idx_ndim > inp_ndim) {
std::string msg = "too many indices for tensor: tensor is " +
std::to_string(inp_ndim) + "-dimensional, but " +
std::to_string(idx_ndim) + " were indexed";
throw py::index_error(msg.c_str());
}
} catch (py::error_already_set& err) {
; // ignore
}
if (need_remove_ellipsis) {
tuple_val = _remove_ellipsis(inp, tuple_val);
}
if (need_expand_bool_dim) {
py::object shape = getattr(inp, "shape");
if (shape.ptr() != Py_None) {
py::tuple ret = _expand_bool_dim(inp, tuple_val);
inp = ret[0];
tuple_val = ret[1];
}
}
py::list items;
py::list tensors;
int cur_axis = -1;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object handle = tuple_val[i];
cur_axis++;
if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
use_subtensor = false;
}
py::list item;
item.append(cur_axis);
auto push = [&](PyObject* v) {
if (v == Py_None) {
item.append(false);
} else {
item.append(true);
tensors.append(_get_index(py::reinterpret_borrow<py::object>(v), inp));
}
};
if (PySlice_Check(handle.ptr())) {
PySliceObject* s = (PySliceObject*)handle.ptr();
if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) {
continue;
}
push(s->start);
push(s->stop);
push(s->step);
item.append(false);
} else {
for (size_t j = 0; j < 3; j++)
item.append(false);
push(handle.ptr());
}
items.append(item);
}
return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
}
py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
if (try_res.size() == 2) {
return try_res[0];
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
static std::shared_ptr<OpDef> op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> p;
p.resize(tensors.size() + 2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
p[i + 2] = tensors[i].ptr();
}
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
py::object org_shape = getattr(inp_hdl, "shape");
py::object val = py::reinterpret_borrow<py::object>(val_hdl);
if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
val =
_Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
inp_hdl);
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
static std::shared_ptr<OpDef> op, set_op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> p;
p.resize(tensors.size() + 2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
p[i + 2] = tensors[i].ptr();
}
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
py::object tmp_result = ret[0];
try {
py::object value_tuple_shape = val.attr("_tuple_shape");
py::object tmp_result_tuple_shape = tmp_result.attr("_tuple_shape");
py::tuple value_shape = py::reinterpret_borrow<py::tuple>(value_tuple_shape);
py::tuple tmp_result_shape =
py::reinterpret_borrow<py::tuple>(tmp_result_tuple_shape);
for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
size_t ts =
tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
if (vs != 1 && vs != ts) {
std::string lhs = "", rhs = "";
for (size_t j = 0; j < tmp_result_shape.size(); ++j) {
lhs += std::to_string(tmp_result_shape[j].cast<size_t>());
if (j)
lhs += ",";
}
for (size_t j = 0; j < value_shape.size(); ++j) {
rhs += std::to_string(value_shape[j].cast<size_t>());
if (j)
rhs += ",";
}
throw py::value_error(
"cannot copy tensor with shape (" + rhs +
") to subtensor with shape (" + lhs + ")");
}
}
} catch (py::error_already_set& err) {
;
}
py::object broadcast_func = getattr(val, "_broadcast");
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr());
PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args);
Py_XDECREF(Args);
val = py::reinterpret_steal<py::object>(new_val);
if (up[3].cast<bool>()) {
set_op = SetSubtensor::make(cpp_items);
} else {
set_op = IndexingSetMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> q;
q.resize(tensors.size() + 3);
py::object Set_Op = py::cast(set_op);
q[0] = Set_Op.ptr();
q[1] = tensor.ptr();
q[2] = val.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
q[i + 3] = tensors[i].ptr();
}
py::tuple result =
py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
py::object res = result[0];
if (up[4].cast<bool>()) {
py::object reshape_func = getattr(res, "reshape");
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, org_shape.release().ptr());
PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args);
Py_XDECREF(Args);
res = py::reinterpret_steal<py::object>(new_tensor);
}
return res;
}
// Returns the dtype that would result from performing an arithmetic
// operation on the provided input tensors and scalars.
PyObject* dtype_promotion(PyObject* self, PyObject* const* args, size_t nargs) {
......@@ -1126,30 +576,6 @@ PyObject* get_device(PyObject* self, PyObject* const* args, size_t nargs) {
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _setitem_cpp(
py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
#ifdef METH_FASTCALL
#define MGE_PY_INTERFACE(NAME, FUNC) \
{ #NAME, (PyCFunction)FUNC, METH_FASTCALL, nullptr }
......
......@@ -38,6 +38,8 @@ namespace mgb::imperative::python {
extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type;
extern PyObject* cpp_use_symbolic_shape;
extern PyObject* cpp_astensor1d;
struct Tensor : NonCopyableObj {
private:
......
/**
* \file imperative/python/src/tensor.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/common.h"
#include "megbrain/dtype.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/profiler.h"
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/lazy.h"
#include "megbrain/imperative/transformations/scalar.h"
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
#include "./common.h"
#include "./grad.h"
#include "./graph_rt.h"
#include "./helper.h"
#include "./module_trace.h"
#include "./numpy_dtypes.h"
#include "./tensor.h"
#include "./tensor_utils.h"
#include "./transformation.h"
#include <object.h>
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <pybind11/pytypes.h>
#include <pyerrors.h>
#include <range/v3/all.hpp>
#include <string>
#include <unordered_map>
#include "../../src/impl/mgb_cg_impl.h"
namespace py = pybind11;
namespace views = ranges::views;
namespace mgb::imperative::python {
bool is_scalar(PyObject* tensor) {
if (py::isinstance<PySymbolVar>(py::handle(tensor))) {
auto var = py::handle(tensor).cast<PySymbolVar*>();
return var->is_scalar;
}
auto* tw = TensorWrapper::try_cast(tensor);
if (tw) {
return tw->m_tensor->is_scalar();
}
return PyArray_CheckAnyScalar(tensor);
}
bool is_bool_list(PyObject* arg) {
if (!PyList_Check(arg)) {
return false;
}
size_t sz = PyList_Size(arg);
if (!sz) {
return false;
}
for (size_t i = 0; i < sz; ++i) {
PyObject* handle = PyList_GetItem(arg, i);
if (!PyBool_Check(handle)) {
return false;
}
}
return true;
}
bool is_bool_dtype(PyObject* args) {
if (!PyObject_HasAttrString(args, "dtype"))
return false;
PyObject* dobj = PyObject_GetAttrString(args, "dtype");
PyArray_Descr* dtype;
PyArray_DescrConverter(dobj, &dtype);
bool ret = (dtype->kind == 'b');
Py_XDECREF(dtype);
Py_XDECREF(dobj);
return ret;
}
py::object _Const(
py::handle value, py::handle dtype, py::handle device, py::handle ref) {
py::object val = py::reinterpret_borrow<py::object>(value);
if (PyArray_Check(value.ptr())) {
py::tuple strides =
py::reinterpret_borrow<py::tuple>(getattr(value, "strides"));
bool need_squeeze = false;
for (size_t i = 0; i < strides.size(); ++i) {
if (strides[i].cast<ptrdiff_t>() == 0) {
need_squeeze = true;
}
}
if (need_squeeze) {
val = py::reinterpret_borrow<py::array>(value);
val = val.attr("squeeze")();
val = val.attr("reshape")(val.attr("shape"));
}
}
if (py::isinstance<PySymbolVar>(ref)) {
auto ref_var = ref.cast<PySymbolVar*>();
auto* graph = ref_var->m_node->owner_graph();
auto cn = device.cast<CompNode>();
OperatorNodeConfig config(cn);
auto hv = npy::np2tensor(
val.ptr(), npy::Meth::borrow(cn), dtype.cast<mgb::DType>());
auto typeobj = ref.get_type();
return typeobj(opr::ImmutableTensor::make(*graph, hv, config).node());
}
py::tuple tup = py::make_tuple(val, dtype, device, true, false, py::none());
return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
py::tuple _make_shape_tuple(py::handle shape) {
py::list orig;
py::list ret(0);
auto solve_one = [&](py::handle val) {
if (TensorWrapper::try_cast(val.ptr()) || py::isinstance<PySymbolVar>(val)) {
py::object np = getattr(val, "numpy")();
PyArrayObject* arr = (PyArrayObject*)np.ptr();
PyObject* maybe_list = PyArray_ToList(arr);
if (PyList_Check(maybe_list)) {
py::list may = py::reinterpret_steal<py::list>(maybe_list);
for (size_t i = 0; i < may.size(); ++i) {
ret.append(may[i]);
}
} else {
mgb_assert(PyLong_Check(maybe_list));
ret.append(PyLong_AsLong(maybe_list));
Py_XDECREF(maybe_list);
}
} else if (PyArray_Check(val.ptr())) {
ret.append(PyArray_PyIntAsInt(val.ptr()));
} else {
ret.append(PyLong_AsLong(val.ptr()));
}
};
if (PyArray_Check(shape.ptr()) && !PyArray_CheckAnyScalar(shape.ptr())) {
orig = py::reinterpret_steal<py::list>(
PyArray_ToList((PyArrayObject*)shape.ptr()));
for (size_t i = 0; i < orig.size(); ++i) {
solve_one(orig[i]);
}
} else if (PyList_Check(shape.ptr())) {
orig = py::reinterpret_borrow<py::list>(shape);
for (size_t i = 0; i < orig.size(); ++i) {
solve_one(orig[i]);
}
} else if (PyTuple_Check(shape.ptr())) {
py::tuple tup = py::reinterpret_borrow<py::tuple>(shape);
for (size_t i = 0; i < tup.size(); ++i) {
solve_one(tup[i]);
}
} else {
solve_one(shape);
}
return py::reinterpret_steal<py::tuple>(PyList_AsTuple(ret.ptr()));
}
py::object _get_index(py::object tensor, py::object src) {
if (!TensorWrapper::try_cast(tensor.ptr()) &&
!py::isinstance<PySymbolVar>(tensor)) {
auto get_const = [&](mgb::DType dtype) -> py::object {
return _Const(tensor, py::cast(dtype), src.attr("device"), src);
};
if (is_bool_list(tensor.ptr()) || is_bool_dtype(tensor.ptr())) {
tensor = get_const(dtype::Bool());
} else {
tensor = get_const(dtype::Int32());
}
if (!is_bool_dtype(tensor.ptr())) {
return tensor;
}
} else {
if (!is_bool_dtype(tensor.ptr())) {
return tensor;
}
}
static std::shared_ptr<OpDef> op = CondTake::make();
std::vector<PyObject*> p;
p.resize(3);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
p[2] = tensor.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[1];
}
py::tuple _try_cond_take(py::handle tensor, py::handle index) {
if (!hasattr(index, "dtype") || !hasattr(index, "shape")) {
return py::tuple();
}
if (!is_bool_dtype(index.ptr()) ||
_make_shape_tuple(getattr(index, "shape"))
.not_equal(_make_shape_tuple(getattr(tensor, "shape")))) {
return py::tuple();
}
py::object iobj;
if (PyArray_Check(index.ptr())) {
iobj =
_Const(index, py::cast((mgb::DType)dtype::Bool()),
getattr(tensor, "device"), tensor);
} else {
iobj = py::reinterpret_borrow<py::object>(index);
}
static std::shared_ptr<OpDef> op = CondTake::make();
std::vector<PyObject*> p;
p.resize(3);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
p[2] = iobj.ptr();
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret;
}
py::tuple _remove_ellipsis(py::object tensor, py::tuple tuple_val) {
size_t tuple_size = tuple_val.size();
size_t ndim_sum = 0, cur_sum = 0;
int pos = -1;
bool has_unknown_ndim_bool_index = false;
for (size_t i = 0; i < tuple_size; ++i) {
py::object handle = tuple_val[i];
if (handle.ptr() == Py_Ellipsis) {
pos = static_cast<int>(i);
for (size_t j = 0; j < i; ++j) {
py::object t = tuple_val[j];
if (t.ptr() == Py_Ellipsis) {
throw py::index_error("only one ellipsis is allowed.");
}
}
} else {
size_t ndim_incr = 1;
if (hasattr(handle, "dtype") && is_bool_dtype(handle.ptr()) &&
hasattr(handle, "ndim")) {
py::object ndim = getattr(handle, "ndim");
if (PyLong_Check(ndim.ptr())) {
ndim_incr = PyLong_AsLong(ndim.ptr());
} else {
has_unknown_ndim_bool_index = true;
}
}
cur_sum += ndim_incr;
}
}
if (pos == -1) {
return tuple_val;
} else {
if (has_unknown_ndim_bool_index) {
throw py::index_error(
"does not support bool index with unknown shape when using "
"Ellipsis.");
}
try {
ndim_sum = getattr(tensor, "ndim").cast<size_t>();
} catch (py::error_already_set& err) {
throw py::index_error(
"does not support Ellipsis when tensor's ndim is unknown.");
}
py::tuple ret(ndim_sum - cur_sum + tuple_size - 1);
size_t idx = 0;
for (size_t i = 0; i < tuple_size; ++i) {
if (i == pos) {
for (size_t j = cur_sum; j < ndim_sum; ++j) {
ret[idx++] = PySlice_New(NULL, NULL, NULL);
}
} else {
ret[idx++] = tuple_val[i];
}
}
return ret;
}
}
py::tuple _expand_bool_dim(py::object tensor, py::tuple tuple_val) {
py::tuple cur_shape = _make_shape_tuple(py::handle(getattr(tensor, "shape")));
py::list new_tuple_val(0);
size_t offset = 0;
size_t tdim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::handle k = tuple_val[i];
if (is_bool_dtype(k.ptr())) {
size_t ndim = getattr(k, "ndim").cast<size_t>();
if (ndim > 1) {
py::tuple ishape = _make_shape_tuple(py::handle(getattr(k, "shape")));
for (size_t j = 0; j < ndim; ++j) {
if (cur_shape[tdim + j - offset].cast<size_t>() !=
ishape[j].cast<size_t>()) {
std::string msg =
"boolean index did not match tensor along dimension " +
std::to_string(tdim + j) + "; dimension is " +
std::to_string(
cur_shape[tdim + j - offset].cast<size_t>()) +
" but corresponding boolean dimension is " +
std::to_string(ishape[j].cast<size_t>());
throw py::index_error(msg.c_str());
}
}
py::object new_k = getattr(k, "reshape")(-1);
py::object kshape = getattr(new_k, "shape");
py::list new_shape(0);
PyObject* sym = PyObject_CallObject(cpp_use_symbolic_shape, nullptr);
bool is_sym = (sym == Py_True);
Py_XDECREF(sym);
if (is_sym) {
py::object tshape = getattr(tensor, "shape");
for (size_t j = 0; j < i; ++j) {
new_shape.append(tshape[py::int_(j)]);
}
new_shape.append(kshape[py::int_(0)]);
for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
new_shape.append(cur_shape[j]);
}
py::tuple args = py::make_tuple(new_shape);
PyObject* shape_tensor =
PyObject_CallObject(cpp_astensor1d, args.ptr());
py::object reshape_func = getattr(tensor, "reshape");
Py_INCREF(shape_tensor);
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, shape_tensor);
PyObject* new_tensor =
PyObject_CallObject(reshape_func.ptr(), Args);
Py_XDECREF(Args);
tensor = py::reinterpret_steal<py::object>(new_tensor);
cur_shape = _make_shape_tuple(py::handle(shape_tensor));
Py_XDECREF(shape_tensor);
} else {
for (size_t j = 0; j < i; ++j) {
new_shape.append(cur_shape[j]);
}
new_shape.append(py::reinterpret_borrow<py::tuple>(kshape)[0]);
for (size_t j = tdim + ndim - offset; j < cur_shape.size(); ++j) {
new_shape.append(cur_shape[j]);
}
cur_shape = new_shape;
tensor = getattr(tensor, "reshape")(cur_shape);
}
offset++;
tdim += ndim;
}
new_tuple_val.append(k);
} else {
new_tuple_val.append(k);
tdim++;
}
}
return py::make_tuple(tensor, py::reinterpret_borrow<py::tuple>(new_tuple_val));
}
py::tuple _unpack_indexes(py::handle inp_hdl, py::handle idx_hdl) {
py::object inp = py::reinterpret_borrow<py::object>(inp_hdl);
py::tuple tuple_val;
if (py::isinstance<py::tuple>(idx_hdl)) {
tuple_val = py::reinterpret_borrow<py::tuple>(idx_hdl);
} else {
tuple_val = py::make_tuple(idx_hdl);
}
bool use_subtensor = true;
bool need_remove_ellipsis = false;
bool need_expand_bool_dim = false;
size_t idx_ndim = 0;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object k = tuple_val[i];
if (k.ptr() == Py_None) {
throw py::index_error("newaxis is not allowed here");
} else if (k.ptr() == Py_Ellipsis) {
need_remove_ellipsis = true;
} else {
if (is_bool_dtype(k.ptr()) && hasattr(k, "ndim")) {
size_t ndim = getattr(k, "ndim").cast<size_t>();
idx_ndim += ndim;
if (ndim > 1) {
need_expand_bool_dim = true;
}
} else {
idx_ndim++;
}
}
}
try {
size_t inp_ndim = getattr(inp, "ndim").cast<size_t>();
if (idx_ndim > inp_ndim) {
std::string msg = "too many indices for tensor: tensor is " +
std::to_string(inp_ndim) + "-dimensional, but " +
std::to_string(idx_ndim) + " were indexed";
throw py::index_error(msg.c_str());
}
} catch (py::error_already_set& err) {
; // ignore
}
if (need_remove_ellipsis) {
tuple_val = _remove_ellipsis(inp, tuple_val);
}
if (need_expand_bool_dim) {
py::object shape = getattr(inp, "shape");
if (shape.ptr() != Py_None) {
py::tuple ret = _expand_bool_dim(inp, tuple_val);
inp = ret[0];
tuple_val = ret[1];
}
}
py::list items;
py::list tensors;
int cur_axis = -1;
for (size_t i = 0; i < tuple_val.size(); ++i) {
py::object handle = tuple_val[i];
cur_axis++;
if (!is_scalar(handle.ptr()) && !PySlice_Check(handle.ptr())) {
use_subtensor = false;
}
py::list item;
item.append(cur_axis);
auto push = [&](PyObject* v) {
if (v == Py_None) {
item.append(false);
} else {
item.append(true);
tensors.append(_get_index(py::reinterpret_borrow<py::object>(v), inp));
}
};
if (PySlice_Check(handle.ptr())) {
PySliceObject* s = (PySliceObject*)handle.ptr();
if (s->start == Py_None && s->stop == Py_None && s->step == Py_None) {
continue;
}
push(s->start);
push(s->stop);
push(s->step);
item.append(false);
} else {
for (size_t j = 0; j < 3; j++)
item.append(false);
push(handle.ptr());
}
items.append(item);
}
return py::make_tuple(inp, tensors, items, use_subtensor, need_expand_bool_dim);
}
py::object _getitem_cpp(py::handle inp_hdl, py::handle idx_hdl) {
py::tuple try_res = _try_cond_take(inp_hdl, idx_hdl);
if (try_res.size() == 2) {
return try_res[0];
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
static std::shared_ptr<OpDef> op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> p;
p.resize(tensors.size() + 2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
p[i + 2] = tensors[i].ptr();
}
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
return ret[0];
}
py::object _setitem_cpp(py::handle inp_hdl, py::handle idx_hdl, py::handle val_hdl) {
py::object org_shape = getattr(inp_hdl, "shape");
py::object val = py::reinterpret_borrow<py::object>(val_hdl);
if (!TensorWrapper::try_cast(val.ptr()) && !py::isinstance<PySymbolVar>(val)) {
val =
_Const(val_hdl, getattr(inp_hdl, "dtype"), getattr(inp_hdl, "device"),
inp_hdl);
}
py::tuple up = _unpack_indexes(inp_hdl, idx_hdl);
py::object tensor = py::reinterpret_borrow<py::object>(up[0]);
py::list tensors = py::reinterpret_borrow<py::list>(up[1]);
py::list py_items = py::reinterpret_borrow<py::list>(up[2]);
std::vector<std::tuple<int8_t, bool, bool, bool, bool>> cpp_items;
for (size_t i = 0; i < py_items.size(); ++i) {
py::list item = py::reinterpret_borrow<py::list>(py_items[i]);
cpp_items.push_back(
{item[0].cast<int8_t>(), item[1].cast<bool>(), item[2].cast<bool>(),
item[3].cast<bool>(), item[4].cast<bool>()});
}
static std::shared_ptr<OpDef> op, set_op;
if (up[3].cast<bool>()) {
op = Subtensor::make(cpp_items);
} else {
op = IndexingMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> p;
p.resize(tensors.size() + 2);
py::object Op = py::cast(op);
p[0] = Op.ptr();
p[1] = tensor.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
p[i + 2] = tensors[i].ptr();
}
py::tuple ret =
py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size()));
py::object tmp_result = ret[0];
try {
py::object value_tuple_shape = val.attr("_tuple_shape");
py::object tmp_result_tuple_shape = tmp_result.attr("_tuple_shape");
py::tuple value_shape = py::reinterpret_borrow<py::tuple>(value_tuple_shape);
py::tuple tmp_result_shape =
py::reinterpret_borrow<py::tuple>(tmp_result_tuple_shape);
for (size_t i = 0; i < value_shape.size() && i < tmp_result_shape.size(); ++i) {
size_t vs = value_shape[value_shape.size() - i - 1].cast<size_t>();
size_t ts =
tmp_result_shape[tmp_result_shape.size() - i - 1].cast<size_t>();
if (vs != 1 && vs != ts) {
std::string lhs = "", rhs = "";
for (size_t j = 0; j < tmp_result_shape.size(); ++j) {
lhs += std::to_string(tmp_result_shape[j].cast<size_t>());
if (j)
lhs += ",";
}
for (size_t j = 0; j < value_shape.size(); ++j) {
rhs += std::to_string(value_shape[j].cast<size_t>());
if (j)
rhs += ",";
}
throw py::value_error(
"cannot copy tensor with shape (" + rhs +
") to subtensor with shape (" + lhs + ")");
}
}
} catch (py::error_already_set& err) {
;
}
py::object broadcast_func = getattr(val, "_broadcast");
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, getattr(tmp_result, "shape").release().ptr());
PyObject* new_val = PyObject_CallObject(broadcast_func.ptr(), Args);
Py_XDECREF(Args);
val = py::reinterpret_steal<py::object>(new_val);
if (up[3].cast<bool>()) {
set_op = SetSubtensor::make(cpp_items);
} else {
set_op = IndexingSetMultiAxisVec::make(cpp_items);
}
std::vector<PyObject*> q;
q.resize(tensors.size() + 3);
py::object Set_Op = py::cast(set_op);
q[0] = Set_Op.ptr();
q[1] = tensor.ptr();
q[2] = val.ptr();
for (size_t i = 0; i < tensors.size(); ++i) {
q[i + 3] = tensors[i].ptr();
}
py::tuple result =
py::reinterpret_steal<py::object>(py_apply(NULL, q.data(), q.size()));
py::object res = result[0];
if (up[4].cast<bool>()) {
py::object reshape_func = getattr(res, "reshape");
PyObject* Args = PyTuple_New(1);
PyTuple_SetItem(Args, 0, org_shape.release().ptr());
PyObject* new_tensor = PyObject_CallObject(reshape_func.ptr(), Args);
Py_XDECREF(Args);
res = py::reinterpret_steal<py::object>(new_tensor);
}
return res;
}
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _make_shape_tuple(py::handle(args[0])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _getitem_cpp(py::handle(args[0]), py::handle(args[1])).release().ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs) {
try {
return _setitem_cpp(
py::handle(args[0]), py::handle(args[1]), py::handle(args[2]))
.release()
.ptr();
}
PYEXT17_TRANSLATE_EXC_RET(nullptr)
}
} // namespace mgb::imperative::python
#pragma once
namespace mgb::imperative::python {
PyObject* make_shape_tuple(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* getitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
PyObject* setitem_cpp(PyObject* self, PyObject* const* args, size_t nargs);
} // namespace mgb::imperative::python
\ No newline at end of file
......@@ -642,7 +642,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
m_dtr.update_used_time(dest);
MGB_RECORD_EVENT(
TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
ptr->dev_tensor().raw_ptr());
ptr->dev_tensor(false).raw_ptr());
// update tensor desc for static infer
if (dest->desc.layout.ndim) {
mgb_assert(
......@@ -730,10 +730,20 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
inputs, apply_functor, const_functor);
return outputs;
}
return OpDef::apply_on_physical_tensor(def, inputs, output_descs, validated);
// Check Input Layout
// Get the input layout constraints, and if the constraint is not satisfied
// inplace update the layout and blob to make the tensor contiguous
auto&& constraints = OpDef::get_input_layout_constraint(def, inputs);
for (size_t idx = 0; idx < inputs.size(); ++idx) {
auto&& layout_checker = constraints[idx];
if (layout_checker) {
inputs[idx]->to_contiguous_inplace(layout_checker);
}
}
return OpDef::apply_on_physical_tensor(
def, std::move(inputs), output_descs, validated);
};
MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason);
// Begin profiling operator
SmallVector<std::pair<CompNode, uint64_t>> kernels;
if (profiling_device) {
// Collecting devices
......
#include "../../../src/core/impl/graph/cg_impl.h"
#include "../../../src/core/impl/graph/var_node_mem_mgr.h"
......@@ -60,6 +60,11 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_falli
return def.trait()->infer_output_attrs_fallible(def, inputs);
}
SmallVector<VarNode::LayoutConstraintCallback> OpDef::get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
return def.trait()->get_input_layout_constraint(def, inputs);
}
EncodedSubgraph OpDef::make_backward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -47,6 +47,10 @@ void OpMethFallbackByProxyGraph::impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible);
}
void OpMethFallbackByProxyGraph::impl(
GetInputLayoutConstraint& func, op_meth_tag::GetInputLayoutConstraint) {
func.Base::operator=(proxy_graph_detail::get_input_layout_constraint);
}
void OpMethFallbackByProxyGraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
func.Base::operator=(proxy_graph_detail::make_backward_graph);
}
......@@ -63,6 +67,10 @@ void OpMethFallbackFromSubgraph::impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible) {
func.Base::operator=(subgraph_detail::infer_output_attrs_fallible);
}
void OpMethFallbackFromSubgraph::impl(
GetInputLayoutConstraint& func, op_meth_tag::GetInputLayoutConstraint) {
func.Base::operator=(subgraph_detail::get_input_layout_constraint);
}
void OpMethFallbackFromSubgraph::impl(GradMaker& func, op_meth_tag::GradMaker) {
func.Base::operator=(subgraph_detail::make_backward_graph);
}
......
......@@ -73,6 +73,9 @@ OpMethType(ApplyOnVarNode,
OpMethType(InferOutputAttrsFallible,
decltype(OpDef::infer_output_attrs_fallible));
OpMethType(GetInputLayoutConstraint,
decltype(OpDef::get_input_layout_constraint));
OpMethType(GradMaker,
decltype(OpDef::make_backward_graph));
......@@ -119,6 +122,8 @@ struct OpMethFallbackByProxyGraph : OpMethImplBase {
static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor);
static void impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
static void impl(
GetInputLayoutConstraint& func, op_meth_tag::GetInputLayoutConstraint);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
};
......@@ -128,6 +133,8 @@ struct OpMethFallbackFromSubgraph : OpMethImplBase {
static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode);
static void impl(
InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible);
static void impl(
GetInputLayoutConstraint& func, op_meth_tag::GetInputLayoutConstraint);
static void impl(GradMaker& func, op_meth_tag::GradMaker);
};
......@@ -179,6 +186,7 @@ struct OpTrait {
ApplyOnDeviceTensorND apply_on_device_tensornd;
ApplyOnVarNode apply_on_var_node;
InferOutputAttrsFallible infer_output_attrs_fallible;
GetInputLayoutConstraint get_input_layout_constraint;
GradMaker make_backward_graph;
Props props;
HashFunc hash;
......@@ -199,6 +207,7 @@ struct OpTrait {
cb(apply_on_device_tensornd) \
cb(apply_on_var_node) \
cb(infer_output_attrs_fallible) \
cb(get_input_layout_constraint) \
cb(make_backward_graph) \
cb(props) \
cb(hash) \
......
......@@ -117,7 +117,7 @@ void InputCallback::scn_do_execute() {
layout.init_contiguous_stride();
dev_tensor.reset(dev_tensor.storage(), layout);
}
output(0)->reset_dev_tensor_from_tensor(dev_tensor);
output(0)->force_assign_dev_tensor_from_tensor(dev_tensor);
}
cg::OperatorNodeBase* InputCallback::shallow_copy(
......@@ -311,7 +311,7 @@ cg::OperatorNodeBase::NodeProp* MutableTensor::do_make_node_prop() const {
}
void MutableTensor::scn_do_execute() {
output(0)->reset_dev_tensor_from_tensor(*m_dev_tensor);
output(0)->force_assign_dev_tensor_from_tensor(*m_dev_tensor);
}
void MutableTensor::init_output_static_infer_desc() {
......
......@@ -83,28 +83,18 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto& input = inputs[0];
TensorShape target_shape;
if (validated) {
target_shape = output_descs[0].layout;
} else {
cg::copy_tensor_value_to_shape(
target_shape, inputs[1]->get_value().proxy_to_default_cpu());
}
TensorPtr output = Tensor::make(
TensorLayout(target_shape, input->dtype()), input->comp_node());
if (output->layout().is_empty()) {
return {output};
}
if (input->shape().eq_shape(output->shape())) {
mgb_assert(input->layout().eq_layout(output->layout()));
output->dev_tensor().copy_from_fixlayout(input->dev_tensor());
} else {
TensorLayout input_layout = input->layout().broadcast(output->shape());
output->dev_tensor().copy_from_fixlayout(
input->dev_tensor().sub(SubTensorSpec::make_from_layout(input_layout)));
}
return {output};
def.cast_final_safe<Broadcast>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
TensorLayout tlayout = slayout.broadcast(tshp);
// memory forward
return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast)
......@@ -184,10 +174,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
auto&& tshp_nd = inputs[1];
auto slayout = src->layout();
if (validated) {
return {Tensor::make(src->blob(), 0, output_descs[0].layout)};
}
TensorShape tshp;
cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
......@@ -195,13 +181,39 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
tshp[op_def.axis] = 1;
tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems();
}
return {Tensor::make(src->blob(), 0, slayout.reshape(tshp))};
TensorLayout tlayout;
mgb_assert(slayout.try_reshape(tlayout, tshp));
return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto&& op_def = def.cast_final_safe<Reshape>();
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [&](const TensorLayout& layout) {
TensorShape tshp;
TensorLayout ret;
cg::copy_tensor_value_to_shape(
tshp, inputs[1]->get_value().proxy_to_default_cpu());
if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) {
mgb_assert(tshp[op_def.axis] == -1);
tshp[op_def.axis] = 1;
tshp[op_def.axis] = layout.total_nr_elems() / tshp.total_nr_elems();
}
if (layout.try_reshape(ret, tshp)) {
return true;
} else {
return false;
}
};
return layout_checker;
}
OP_TRAIT_REG(Reshape, Reshape)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace reshape
......
......@@ -220,12 +220,22 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node(
SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
mgb_assert(
inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(),
"This inplace modification may change the elements of other tensors. "
"Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs "
"correctly.");
auto dest = inputs[0], delta = inputs[1], alpha = inputs[2], beta = inputs[3];
if (!(inputs[0]->blob().unique() && inputs[0]->blob()->storage().unique())) {
mgb_log_warn(
"This inplace modification may change the elements of other tensors. "
"Fallback to non-inplace update.");
DeviceTensorStorage storage;
storage.reset(dest->comp_node(), dest->blob()->size(), dest->blob()->storage());
storage = storage.sub(dest->offset());
DeviceTensorND dv;
dv.reset(storage, dest->layout());
DeviceTensorND dv_new;
dv_new.copy_from(dv);
dest = Tensor::make(dv_new);
}
auto tensor_to_scalar = [](const TensorPtr& tensor) -> float {
return *tensor->get_value().ptr<float>();
};
......
......@@ -54,7 +54,8 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
if (memory_forward_success(def, inputs)) {
return {Tensor::make(inputs[0]->blob(), 0, inputs[0]->layout())};
return {Tensor::make(
inputs[0]->blob(), inputs[0]->offset(), inputs[0]->layout())};
}
return proxy_graph_detail::apply_on_physical_tensor(
def, inputs, output_descs, validated);
......@@ -73,11 +74,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {output_descs, validated};
}
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
layout_checker[0] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
return layout_checker;
}
OP_TRAIT_REG(Reduce, Reduce, opr::Reduce)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.get_input_layout_constraint(get_input_layout_constraint)
.fallback();
} // namespace reduce
} // namespace
......
......@@ -594,6 +594,13 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible<Dro
return {dests, true};
}
template <typename Op>
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<VarNode::LayoutConstraintCallback> layout_checker(inputs.size());
return layout_checker;
}
} // anonymous namespace
Handle new_handle(CompNode comp_node, uint64_t seed) {
......@@ -622,6 +629,7 @@ CompNode get_rng_handle_compnode(Handle handle) {
.apply_on_var_node(apply_on_var_node<NAME, Output>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.get_input_layout_constraint(get_input_layout_constraint<NAME>) \
.fallback(); \
}
......
......@@ -60,9 +60,55 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::Dimshuffle::make(inputs[0], ds.pattern, 0UL, config);
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& ds = static_cast<const Dimshuffle&>(def);
mgb_assert(
ds.pattern.size() <= TensorShape::MAX_NDIM,
"Dimshuffle pattern exceeds max length of %zd", TensorShape::MAX_NDIM);
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "Dimshuffle expects 1 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto inp_layout = src->layout();
size_t pattern_ndim = *std::max_element(ds.pattern.begin(), ds.pattern.end()) + 1;
mgb_assert(
inp_layout.ndim == pattern_ndim,
"input ndim mismatch for Dimshuffle: expect=%zd actual=%zd", pattern_ndim,
inp_layout.ndim);
TensorLayout out_layout{inp_layout.dtype};
out_layout.ndim = ds.pattern.size();
size_t idx = 0;
bool input_used[TensorLayout::MAX_NDIM] = {0};
for (auto i : ds.pattern) {
if (i < 0) {
out_layout.shape[idx] = 1;
out_layout.stride[idx] = 1;
} else {
input_used[i] = true;
out_layout.shape[idx] = inp_layout.shape[i];
out_layout.stride[idx] = inp_layout.stride[i];
}
++idx;
}
if (out_layout.is_contiguous()) {
out_layout.init_contiguous_stride();
}
for (size_t i = 0; i < pattern_ndim; ++i) {
mgb_assert(
input_used[i] || inp_layout.shape[i] == 1,
"non-1 dim discarded in Dimshuffle: ishp=%s dim=%zd",
inp_layout.megdnn::TensorShape::to_string().c_str(), i);
}
// memory forward
return {Tensor::make(src->blob(), src->offset(), out_layout)};
}
OP_TRAIT_REG(Dimshuffle, Dimshuffle, opr::Dimshuffle)
.make_from_op_node(make_from_op_node)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace dimshuffle
} // namespace
......@@ -80,7 +126,25 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::AxisAddRemove::make(inputs[0], param, config);
}
OP_TRAIT_REG(AddAxis, AddAxis).apply_on_var_node(apply_on_var_node).fallback();
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<AddAxis>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "AddAxis expects 1 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto tlayout = src->layout();
for (auto&& i : op_def.axis) {
tlayout.add_axis_cont_inplace(i);
}
// memory forward
return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
OP_TRAIT_REG(AddAxis, AddAxis)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace add_axis
} // namespace
......@@ -97,7 +161,36 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
return opr::AxisAddRemove::make(inputs[0], param, config);
}
OP_TRAIT_REG(RemoveAxis, RemoveAxis).apply_on_var_node(apply_on_var_node).fallback();
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<RemoveAxis>();
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "RemoveAxis expects 1 inputs; got %lu actually", nr_inp);
auto&& src = inputs[0];
auto tlayout = src->layout();
for (auto&& i : op_def.axis) {
if (tlayout.ndim == 1) {
mgb_assert(
tlayout.shape[0] == 1 && i == 0,
"can not remove axis %u from tensor of shape=%s", i,
tlayout.megdnn::TensorShape::to_string().c_str());
} else {
mgb_assert(
i < tlayout.ndim && tlayout.shape[i] == 1,
"can not remove axis %u from tensor of shape=%s", i,
tlayout.megdnn::TensorShape::to_string().c_str());
tlayout.remove_axis_inplace(i);
}
}
// memory forward
return {Tensor::make(src->blob(), src->offset(), tlayout)};
}
OP_TRAIT_REG(RemoveAxis, RemoveAxis)
.apply_on_var_node(apply_on_var_node)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace remove_axis
} // namespace
......
......@@ -411,7 +411,7 @@ struct ComputingGraphHolder {
executable->wait();
size_t nr_inputs = inputs.size();
for (size_t i = 0; i < nr_inputs; ++i) {
auto input_dev_tensor = input_tensors[i]->dev_tensor();
auto input_dev_tensor = input_tensors[i]->dev_tensor(false);
inputs[i].device_value->reset(
input_dev_tensor.storage(), input_dev_tensor.layout());
if (inputs[i].host_value) {
......
......@@ -95,7 +95,13 @@ const Blob::RawStorage& Blob::storage() {
Tensor::Tensor(
BlobPtr blob, const TensorLayout& layout, size_t offset, const HostTensorND& hv)
: m_layout(layout), m_blob(std::move(blob)), m_offset(offset), m_value(hv) {}
: m_cn(blob->comp_node()),
m_shape(layout),
m_dtype(layout.dtype),
m_layout(layout),
m_blob(std::move(blob)),
m_offset(offset),
m_value(hv) {}
Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
constexpr int size_threshold = TensorShape::MAX_NDIM;
......@@ -107,7 +113,12 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
MGB_RECORD_EVENT(
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
dev_tensor().raw_ptr());
dev_tensor().copy_from_fixlayout(hv);
DeviceTensorStorage storage;
storage.reset(m_cn, m_blob->size(), m_blob->storage());
storage = storage.sub(m_offset);
DeviceTensorND dv;
dv.reset(storage, m_layout);
dv.copy_from_fixlayout(hv);
// even though hv is saved in m_value, Tensor itself could be
// released before copy completes
MGB_RECORD_EVENT(
......@@ -117,25 +128,36 @@ Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
}
}
Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv) {
Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv)
: m_offset(dv.storage().offset()),
m_cn(dv.comp_node()),
m_shape(dv.layout()),
m_dtype(dv.layout().dtype),
m_blob(Blob::make(dv.storage())),
m_layout(dv.layout()) {
if (!hv.empty()) {
mgb_assert(dv.comp_node() == hv.comp_node());
mgb_assert(dv.dtype() == hv.dtype());
mgb_assert(dv.shape().eq_shape(hv.shape()));
m_value = hv;
}
m_layout = dv.layout();
m_blob = Blob::make(dv.storage());
m_offset = dv.storage().offset();
}
Tensor::Tensor(const TensorLayout& layout, const CompNode& cn)
: m_layout{layout},
m_blob{Blob::make(cn, layout.span().dist_byte())},
m_offset{0} {}
m_offset{0},
m_cn(cn),
m_shape(layout),
m_dtype(layout.dtype) {}
Tensor::Tensor(const BlobPtr blob, const size_t offset, const TensorLayout& layout)
: m_layout{layout}, m_blob{blob}, m_offset{offset} {}
: m_layout{layout},
m_blob{blob},
m_offset{offset},
m_cn(blob->comp_node()),
m_shape(layout),
m_dtype(layout.dtype) {}
TensorPtr Tensor::make(const HostTensorND& hv) {
auto&& blob = MultiCNConstTensorCache::inst().lookup(hv);
......@@ -145,10 +167,45 @@ TensorPtr Tensor::make(const HostTensorND& hv) {
return std::make_shared<Tensor>(hv);
}
DeviceTensorND Tensor::dev_tensor() {
void Tensor::to_contiguous_inplace(VarNode::LayoutConstraintCallback& layout_checker) {
MGB_LOCK_GUARD(m_blob_mtx);
if (!m_layout.is_empty() && !layout_checker(m_layout)) {
DeviceTensorStorage storage;
storage.reset(m_cn, m_blob->size(), m_blob->storage());
storage = storage.sub(m_offset);
DeviceTensorND dv;
dv.reset(storage, m_layout);
DeviceTensorND dv_contig;
dv_contig.copy_from(dv);
m_layout = dv_contig.layout();
std::atomic_store(&m_blob, Blob::make(dv_contig.storage()));
mgb_assert(m_layout.is_contiguous());
m_offset = 0;
}
}
void Tensor::to_contiguous_inplace() {
static VarNode::LayoutConstraintCallback default_cb =
[](const TensorLayout& layout) { return layout.is_contiguous(); };
to_contiguous_inplace(default_cb);
}
void Tensor::assign_from_dev_tensor(DeviceTensorND dv) {
MGB_LOCK_GUARD(m_blob_mtx);
std::atomic_store(&m_blob, Blob::make(dv.storage()));
m_offset = dv.storage().offset();
m_layout = dv.layout();
}
DeviceTensorND Tensor::dev_tensor(bool contiguous) {
mgb_assert(m_blob, "uninitialized tensor.");
if (contiguous) {
to_contiguous_inplace();
}
MGB_LOCK_GUARD(m_blob_mtx);
DeviceTensorStorage storage;
storage.reset(m_blob->comp_node(), m_blob->size(), m_blob->storage());
storage.reset(m_cn, m_blob->size(), m_blob->storage());
storage = storage.sub(m_offset);
DeviceTensorND ret;
ret.reset(storage, m_layout);
......@@ -156,16 +213,22 @@ DeviceTensorND Tensor::dev_tensor() {
}
void Tensor::fetch_value() {
MGB_LOCK_GUARD(m_mtx);
MGB_LOCK_GUARD(m_blob_mtx);
MGB_LOCK_GUARD(m_value_mtx);
if (m_value.empty()) {
m_value.copy_from(dev_tensor());
DeviceTensorStorage storage;
storage.reset(m_cn, m_blob->size(), m_blob->storage());
storage = storage.sub(m_offset);
DeviceTensorND dv;
dv.reset(storage, m_layout);
m_value.copy_from(dv);
m_value_ready.reset(EventPool::without_timer().alloc(comp_node()));
m_value_ready->record();
}
}
bool Tensor::value_fetched() {
MGB_LOCK_GUARD(m_mtx);
MGB_LOCK_GUARD(m_value_mtx);
return m_value.layout().ndim != 0;
}
......@@ -178,7 +241,7 @@ const HostTensorND& Tensor::get_value() {
}
const HostTensorND* Tensor::try_get_value() {
MGB_LOCK_GUARD(m_mtx);
MGB_LOCK_GUARD(m_value_mtx);
if (!m_value.empty() && (!m_value_ready || m_value_ready->finished())) {
return &m_value;
}
......@@ -193,7 +256,7 @@ TensorPtr Tensor::make_scalar(DTypeScalar value, CompNode cn) {
}
TensorPtr Tensor::sub(size_t offset, TensorShape shape) {
TensorLayout layout(shape, m_layout.dtype);
TensorLayout layout(shape, m_dtype);
return Tensor::make(m_blob, offset + m_offset, layout);
}
......
......@@ -73,7 +73,7 @@ public:
static SymbolVar make(ComputingGraph& graph, Tensor& tensor) {
auto opr = graph.insert_opr(std::make_unique<InputPlaceholder>(graph, &tensor));
auto var = opr->output(0);
auto&& dev_tensor = tensor.dev_tensor();
auto&& dev_tensor = tensor.dev_tensor(false);
var->m_comp_node = dev_tensor.comp_node();
var->m_shape = dev_tensor.shape();
if (dev_tensor.empty()) {
......@@ -81,10 +81,7 @@ public:
layout.init_contiguous_stride();
dev_tensor.reset(dev_tensor.storage(), layout);
}
var->m_dev_tensor = dev_tensor;
var->m_mem_plan.reset_from_owner_var()
.chunk()
.mem_alloc_status.set_from_owner_var();
var->force_assign_dev_tensor_from_tensor(dev_tensor);
return var;
}
......
......@@ -314,15 +314,11 @@ public:
size_t idx = 0;
for (auto&& input : opr_inputs) {
mgb_assert(input->owner_opr()->same_type<InputPlaceholder>());
input->m_dev_tensor.storage({});
auto&& dev_tensor = inputs[input_remap[idx]]->dev_tensor();
auto&& dev_tensor = inputs[input_remap[idx]]->dev_tensor(false);
auto&& layout = dev_tensor.layout();
input->shape(dev_tensor.shape());
auto&& chk = input->m_mem_plan.reset_from_owner_var().chunk();
input->m_dev_tensor.reset(dev_tensor.storage(), layout);
input->m_mem_plan.layout(layout);
chk.mem_alloc_status.set_from_owner_var();
input->force_assign_dev_tensor_from_tensor(dev_tensor);
mgb_assert(input->comp_node() == dev_tensor.comp_node());
mgb_assert(input->shape().eq_shape(layout));
......@@ -335,9 +331,14 @@ public:
mgb_assert(m_opr->usable_output().size() == outputs.size());
::mgb::opr::intl::WorkspaceLimitHook::set_impl(
m_opr->owner_graph(), get_workspace_limit);
size_t j = 0;
for (auto&& var : m_opr->output()) {
auto&& chk = var->m_mem_plan.reset_from_owner_var().chunk();
chk.mem_alloc_status.set_from_owner_var();
}
m_opr->mem_plan_fwd_in2out_readonly();
size_t j = 0;
for (auto&& var : m_opr->output()) {
if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
TensorLayout layout{var->shape(), var->dtype(), var->format()};
var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(
......@@ -349,18 +350,16 @@ public:
mgb_assert(var->comp_node() == tensor->comp_node());
mgb_assert(var->shape().eq_shape(layout));
mgb_assert(var->dtype() == layout.dtype);
if (var->m_mem_plan.chunk().owner_var != var) {
tensor->assign_from_dev_tensor(
var->m_dev_tensor); // memory forwarding
} else {
var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
}
++j;
}
chk.mem_alloc_status.set_from_owner_var();
}
mgb_assert(j == outputs.size());
// Memory forwarding was bypassed in megbrain with graph option
// imerative_proxy_graph on, here we call mem_plan_fwd_in2out_readonly
// to initialize some opr(e.g. Subtensor)'s internal state
// TODO: implement memory forwarding
m_opr->mem_plan_fwd_in2out_readonly();
{
// some opr (e.g. Reduce) rely on on_mem_status_changed to set
// input/output tensor corretly, since we bypass var_node_mem_mgr
......@@ -840,7 +839,7 @@ public:
Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
auto raw_outputs = to_raw_ptr_array(outputs);
auto raw_outputs = to_raw_ptr_array(outputs, false);
CompNode::UnorderedSet used_cns;
for (auto&& out : raw_outputs) {
auto cn = out->comp_node();
......
......@@ -9,8 +9,12 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "../mgb_cg_impl.h"
#include "./mini_graph.h"
#include "megbrain/opr/io.h"
using LayoutConstraintLevel = mgb::cg::VarNodeMemManager::LayoutConstraintLevel;
using LayoutConstraintCallback = mgb::VarNode::LayoutConstraintCallback;
namespace mgb::imperative::proxy_graph {
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder);
......@@ -34,4 +38,81 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return ret;
}
std::unordered_map<size_t, SmallVector<LayoutConstraintCallback>>
input_layout_constraints_cache;
SmallVector<LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
auto get_input_layout_constraint_hash_key =
[](const OpDef& def, const SmallVector<TensorPtr>& inputs) {
XXHash state;
size_t length = 0, data[1 + inputs.size()];
data[length++] = def.hash();
for (auto&& i : inputs) {
data[length++] = mgb::hash(i->comp_node());
}
state.update(data, length * sizeof(size_t));
return state.digest();
};
auto hash_key = get_input_layout_constraint_hash_key(def, inputs);
auto&& iter = input_layout_constraints_cache.find(hash_key);
if (iter != input_layout_constraints_cache.end()) {
return iter->second;
}
static cg::ComputingGraphImpl* graph =
imperative::ResourceManager::create_global<cg::ComputingGraphImpl>();
VarNodeArray vinputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
OperatorNodeConfig config;
auto&& layout = inputs[i]->layout();
layout.init_contiguous_stride();
vinputs[i] = graph->insert_opr(std::make_unique<mgb::opr::SharedDeviceTensor>(
*graph,
std::make_shared<DeviceTensorND>(
inputs[i]->comp_node(), layout),
false, config))
->output(0);
}
auto&& opr = OpDef::apply_on_var_node(def, vinputs)[0]->owner_opr();
opr->add_input_layout_constraint();
SmallVector<LayoutConstraintCallback> res(inputs.size());
auto& mem_mgr = graph->var_node_mem_manager();
for (size_t i = 0; i < vinputs.size(); ++i) {
auto& trait = mem_mgr.get_var_node_mem_trait(vinputs[i]);
switch (trait.layout_constraint.level) {
case LayoutConstraintLevel::CONTIG:
res[i] = [](const TensorLayout& layout) {
return layout.is_contiguous();
};
break;
case LayoutConstraintLevel::MONOTONE:
res[i] = [&trait](const TensorLayout& layout) {
if (!layout.is_abs_monotonous_allow_brdcst()) {
return false;
}
for (auto&& i : trait.layout_constraint.custom)
if (!i(layout))
return false;
return true;
};
break;
case LayoutConstraintLevel::NONE:
if (!trait.layout_constraint.custom.empty()) {
res[i] = [&trait](const TensorLayout& layout) {
for (auto&& i : trait.layout_constraint.custom)
if (!i(layout))
return false;
return true;
};
}
break;
default:
mgb_throw(InternalError, "invalid layout_constraint_level");
}
}
input_layout_constraints_cache.emplace(hash_key, res);
return res;
}
} // namespace mgb::imperative::proxy_graph_detail
......@@ -17,6 +17,8 @@
#include "./op_trait.h"
using LayoutConstraintCallback = mgb::VarNode::LayoutConstraintCallback;
namespace mgb {
namespace imperative {
namespace subgraph_detail {
......@@ -73,6 +75,13 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
const std::shared_ptr<OpDef>& op,
const SmallVector<TensorPtr>& inputs,
size_t nr_outputs) {
auto&& constraints = OpDef::get_input_layout_constraint(*op, inputs);
for (size_t idx = 0; idx < inputs.size(); ++idx) {
auto&& layout_checker = constraints[idx];
if (layout_checker) {
inputs[idx]->to_contiguous_inplace(layout_checker);
}
}
// do not use infered output_desc in subgraph
return OpDef::apply_on_physical_tensor(*op, inputs, output_descs, false);
};
......@@ -81,6 +90,12 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return outputs;
}
SmallVector<LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<LayoutConstraintCallback> res(inputs.size());
return res;
}
static EncodedSubgraph make_backward_graph_from_forward(
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
......
......@@ -78,6 +78,9 @@ public:
static EncodedSubgraph make_forward_graph(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs);
static SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs);
const OpTrait* trait() const;
std::string to_string() const;
......
......@@ -14,6 +14,7 @@
#include <memory>
#include <mutex>
#include "megbrain/graph.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/tensor.h"
......@@ -90,18 +91,24 @@ public:
CompNode comp_node() const {
mgb_assert(m_blob, "uninitialized tensor.");
return m_blob->comp_node();
return m_cn;
}
DType dtype() const { return m_layout.dtype; }
DType dtype() const { return m_dtype; }
TensorLayout layout() const { return m_layout; }
const TensorShape& shape() const { return m_layout; }
const TensorShape& shape() const { return m_shape; }
size_t offset() const { return m_offset; }
DeviceTensorND dev_tensor();
void to_contiguous_inplace(VarNode::LayoutConstraintCallback&);
void to_contiguous_inplace();
DeviceTensorND dev_tensor(bool contiguous = true);
void assign_from_dev_tensor(DeviceTensorND);
static TensorPtr make_scalar(DTypeScalar value, CompNode cn);
......@@ -110,7 +117,7 @@ public:
return make_scalar(value, m_blob->comp_node());
}
BlobPtr& blob() { return m_blob; }
BlobPtr blob() { return m_blob; }
void fetch_value();
bool value_fetched();
......@@ -131,10 +138,16 @@ public:
static void static_initialize();
private:
TensorLayout m_layout;
BlobPtr m_blob;
size_t m_offset;
std::mutex m_mtx;
const CompNode m_cn;
const TensorShape m_shape;
const DType m_dtype;
std::mutex m_blob_mtx;
BlobPtr m_blob;
TensorLayout m_layout;
std::mutex m_value_mtx;
HostTensorND m_value;
EventPtr m_value_ready = nullptr;
};
......
......@@ -33,6 +33,9 @@ EncodedSubgraph make_backward_graph(
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs);
} // namespace proxy_graph_detail
} // namespace imperative
} // namespace mgb
......
......@@ -36,6 +36,9 @@ EncodedSubgraph make_backward_graph(
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad);
SmallVector<VarNode::LayoutConstraintCallback> get_input_layout_constraint(
const OpDef& def, const SmallVector<TensorPtr>& inputs);
} // namespace subgraph_detail
} // namespace imperative
} // namespace mgb
\ No newline at end of file
......@@ -322,7 +322,7 @@ void ComputingGraphImpl::free_varnode_storage(void* ptr) {
m_var_node_pool.free_raw(ptr);
};
OperatorNodeBase* ComputingGraphImpl::insert_opr(
MGE_WIN_DECLSPEC_FUC OperatorNodeBase* ComputingGraphImpl::insert_opr(
std::unique_ptr<OperatorNodeBase> opr_uniqp) {
auto opr = opr_uniqp.get();
......
......@@ -148,8 +148,8 @@ class ComputingGraphImpl final : public ComputingGraph {
public:
class ComputingSequence;
ComputingGraphImpl();
~ComputingGraphImpl();
MGE_WIN_DECLSPEC_FUC ComputingGraphImpl();
MGE_WIN_DECLSPEC_FUC ~ComputingGraphImpl();
template <typename T>
static ComputingGraphImpl* downcast(T* ptr) = delete;
......@@ -166,7 +166,8 @@ public:
SmallVector<std::unique_ptr<AsyncExecutable>> compile_multi_part(
const SmallVector<OutputSpec>& out_specs) override;
OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr) override;
MGE_WIN_DECLSPEC_FUC OperatorNodeBase* insert_opr(
std::unique_ptr<OperatorNodeBase> opr) override;
void* alloc_varnode_storage() override;
......
......@@ -93,6 +93,23 @@ MemAllocPlan& MemAllocPlan::assign_for_forward(
return *this;
}
MemAllocPlan& MemAllocPlan::force_assign_for_forward(
const MemAllocPlan& src, const SubTensorSpec& sub) {
mgb_assert(valid() && src.valid() && m_layout.eq_shape(sub.layout()));
++(m_chunk = src.m_chunk)->m_refcnt;
m_layout = sub.layout();
// make layout strong-contig
for (int i = static_cast<int>(m_layout.ndim) - 1; i >= 0; --i) {
if (m_layout.shape[i] == 1) {
m_layout.stride[i] = i + 1 < static_cast<int>(m_layout.ndim)
? m_layout.stride[i + 1] * m_layout.shape[i + 1]
: 1;
}
}
m_layout.dtype = dtype();
return *this;
}
MemAllocPlan& MemAllocPlan::reset_from_owner_var() {
auto owner_var = m_chunk_storage.owner_var;
m_layout.dtype = dtype();
......@@ -223,8 +240,13 @@ VarNode& VarNode::format(TensorFormat format) {
bool VarNode::set_fwd_in2out_readonly(VarNode* input, const SubTensorSpec& sub) {
if (owner_graph()->options().imperative_proxy_graph) {
if (input->comp_node() != comp_node()) {
return false;
}
m_mem_plan.force_assign_for_forward(input->m_mem_plan, sub);
m_dev_tensor = input->dev_tensor().sub(sub);
return true;
}
return ComputingGraphImpl::downcast(owner_graph())
->var_node_mem_manager()
.fwd_in2out_readonly(input, sub, this);
......@@ -361,6 +383,13 @@ VarNode& VarNode::reset_dev_tensor_from_tensor(const DeviceTensorND& value) {
return *this;
}
void VarNode::force_assign_dev_tensor_from_tensor(const DeviceTensorND& value) {
m_dev_tensor = value;
shape(value.shape());
m_mem_plan.reset_from_owner_var().chunk().mem_alloc_status.set_from_owner_var();
m_mem_plan.layout(value.layout());
}
void VarNode::assign_dev_tensor_from_tensor(const DeviceTensorND& value) {
mgb_assert(
(value.layout().is_contiguous() || value.empty()) &&
......
......@@ -475,7 +475,7 @@ DEF(CompNode node, const TensorShape& shape, DType dtype, TensorFormat format)
DEF(CompNode node, const TensorLayout& layout)
: TensorND(node, layout, layout.dtype, layout.format) {
mgb_assert(
layout.is_contiguous(),
layout.is_contiguous() || layout.is_empty(),
"non-contiguous layout used for initializing a tensor: %s",
layout.to_string().c_str());
}
......
......@@ -241,7 +241,8 @@ public:
* \return the node in the graph (maybe another node due to
* deduplication)
*/
virtual OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr) = 0;
MGE_WIN_DECLSPEC_FUC virtual OperatorNodeBase* insert_opr(
std::unique_ptr<OperatorNodeBase> opr) = 0;
/*!
* \brief used by OperatorNodeBase to allocate its outputs
......
......@@ -194,6 +194,10 @@ public:
MGE_WIN_DECLSPEC_FUC MemAllocPlan& assign_for_forward(
const MemAllocPlan& src, const SubTensorSpec& sub);
//! force assign for readonly forward
MGE_WIN_DECLSPEC_FUC MemAllocPlan& force_assign_for_forward(
const MemAllocPlan& src, const SubTensorSpec& sub);
/*!
* \brief next readonly-forward reader of this MemAllocPlan
*
......@@ -509,6 +513,9 @@ public:
//! NO_SYS_MEM_ALLOC can be modified.
MGE_WIN_DECLSPEC_FUC bool is_graph_dest_varnode();
MGE_WIN_DECLSPEC_FUC void force_assign_dev_tensor_from_tensor(
const DeviceTensorND& value);
private:
//! whether its memory should be allocated by mgb system during graph
//! execution; initialized in VarNodeMemManager::reset_opr_seq()
......
......@@ -24,7 +24,7 @@ namespace intl {
* \brief base class for IO nodes between device and host
*/
class HostIONodeBase : public cg::SingleCNOperatorNodeBase {
void init_output_static_infer_desc() override final;
MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override final;
protected:
using cg::SingleCNOperatorNodeBase::SingleCNOperatorNodeBase;
......@@ -32,9 +32,10 @@ protected:
/*!
* \brief src_type for static shape and value infer
*/
virtual cg::static_infer::SourceType static_infer_src_type() const;
MGE_WIN_DECLSPEC_FUC virtual cg::static_infer::SourceType static_infer_src_type()
const;
virtual const TensorShape& get_output_shape() = 0;
MGE_WIN_DECLSPEC_FUC virtual const TensorShape& get_output_shape() = 0;
/*!
* \brief fill value in *dest* for static inference
......@@ -52,10 +53,10 @@ protected:
class DeviceTensorHolder : public HostIONodeBase {
class DevValueExecDep;
void init_output_format() override;
void init_output_mem_plan(bool dynamic) override final;
void scn_do_execute() override final;
void record_execute_deps(ExecDependencyArray& deps) override;
MGE_WIN_DECLSPEC_FUC void init_output_format() override;
MGE_WIN_DECLSPEC_FUC void init_output_mem_plan(bool dynamic) override final;
MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;
MGE_WIN_DECLSPEC_FUC void record_execute_deps(ExecDependencyArray& deps) override;
protected:
using HostIONodeBase::HostIONodeBase;
......@@ -77,20 +78,20 @@ MGB_DEFINE_CLS_WITH_SUPER(SharedDeviceTensorBase, DeviceTensorHolder) // {
std::shared_ptr<DeviceTensorND> m_dev_data;
bool m_const_value;
const TensorShape& get_output_shape() override;
MGE_WIN_DECLSPEC_FUC const TensorShape& get_output_shape() override;
bool fill_in_static_infer(DeviceTensorND* dest) override {
MGB_MARK_USED_VAR(dest);
return false;
}
void init_output_comp_node() override;
MGE_WIN_DECLSPEC_FUC void init_output_comp_node() override;
public:
//! const_value marks whether the device value of this operator should
//! be treated as constant during graph execution. Should be false in
//! most cases.
SharedDeviceTensorBase(
MGE_WIN_DECLSPEC_FUC SharedDeviceTensorBase(
ComputingGraph& graph, const std::shared_ptr<DeviceTensorND>& dev_data,
bool const_value, const OperatorNodeConfig& config);
......@@ -248,7 +249,8 @@ private:
*/
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
SharedDeviceTensor, intl::SharedDeviceTensorBase) // {
cg::static_infer::SourceType static_infer_src_type() const override;
MGE_WIN_DECLSPEC_FUC cg::static_infer::SourceType static_infer_src_type()
const override;
public:
using Super::Super;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册