提交 ca001777 编写于 作者: M Megvii Engine Team

perf(dispatch): speed up dispatch system

GitOrigin-RevId: eabbe3e0219ff989801751c726eb0828b1b7a740
上级 187c1dc0
......@@ -16,6 +16,7 @@ import numpy as np
from .. import _config
from .._imperative_rt.common import CompNode
from .._imperative_rt.core2 import SymbolVar, Tensor, apply, dtype_promotion
from .._imperative_rt.core2 import reduce_to_scalar as _reduce_to_scalar
from ..ops import builtin
from . import amp
from .indexing import getitem, setitem
......@@ -508,12 +509,8 @@ def _reduce(mode):
elif self.dtype == np.bool_:
data = data.astype("int32")
if axis is None:
data = data.reshape(-1)
assert not keepdims, "can not set axis=None and keepdims=True"
op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data)
result = _remove_axis(result, 0)
result = _reduce_to_scalar(builtin.Reduce(mode=mode), data)
elif isinstance(axis, collections.abc.Iterable):
axis = _normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
......
......@@ -69,7 +69,7 @@ class SGD(Optimizer):
inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
if inplace_mode:
_neg_lr = tensor(-lr, dtype="float32")
c1 = tensor([1.0])
c1 = tensor(1.0)
for param in param_group["params"]:
if param.grad is None:
......
......@@ -84,14 +84,15 @@ class Tensor(_Tensor, ArrayMethodMixin):
device: str = None,
is_const: bool = False,
no_cache: bool = False,
name: str = "",
name: str = None,
):
if name is None:
name = ""
else:
self._set_name(name)
self._custom_name = name
self._name = name
self._short_name = name
self._set_name(self._name)
self._prefix = None
@property
......
......@@ -46,17 +46,17 @@ void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
if (args[1] != Py_None) {
callback = py::reinterpret_borrow<py::object>(args[1]);
}
GenericFunction generic_callback =
[=](Span<ValueRef> inputs) -> std::vector<ValueRef> {
GenericFunction generic_callback = [=](Span<ValueRef> inputs) -> ValueRefList {
mgb_assert(inputs.size() == 1);
if (callback) {
callback(TensorWrapper::make(py_tensor_type, inputs[0]));
}
return {};
};
tw->m_tensor->reset(imperative::apply(
auto attached_value = imperative::apply(
AttachGrad(m_key), tw->m_tensor->data(),
FunctionValue::make(generic_callback))[0]);
FunctionValue::make(generic_callback))[0];
tw->m_tensor->reset(attached_value);
}
void GradKeyWrapper::backward(GradKeyWrapper* self, py::list tensors, py::list grads) {
......
......@@ -98,7 +98,7 @@ ValueRef make_empty_tensor(
return res;
}
std::optional<std::vector<ValueRef>> elemwise_grad_rule(
std::optional<ValueRefList> elemwise_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto& elemwise = op.cast_final_safe<Elemwise>();
......@@ -117,7 +117,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule(
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(2);
ValueRefList ret(2);
if (!grad) {
return ret;
}
......@@ -132,7 +132,7 @@ std::optional<std::vector<ValueRef>> elemwise_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<std::vector<ValueRef>> reshape_grad_rule(
std::optional<ValueRefList> reshape_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 2);
......@@ -147,7 +147,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule(
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(2);
ValueRefList ret(2);
if (!grad) {
return ret;
}
......@@ -162,7 +162,7 @@ std::optional<std::vector<ValueRef>> reshape_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<std::vector<ValueRef>> subtensor_grad_rule(
std::optional<ValueRefList> subtensor_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& subtensor = op.cast_final_safe<Subtensor>();
......@@ -180,9 +180,9 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule(
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
ValueRefList ret(1);
if (grad && inputs[0]) {
SmallVector<ValueRef> args_(inputs.size() + 1);
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = grad;
......@@ -197,7 +197,7 @@ std::optional<std::vector<ValueRef>> subtensor_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& indexingMultiAxisVec = op.cast_final_safe<IndexingMultiAxisVec>();
......@@ -215,9 +215,9 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
ValueRefList ret(1);
if (grad && inputs[0]) {
SmallVector<ValueRef> args_(inputs.size() + 1);
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
args_[0] = zeros;
args_[1] = grad;
......@@ -232,7 +232,7 @@ std::optional<std::vector<ValueRef>> indexingMultiAxisVec_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<std::vector<ValueRef>> reduce_grad_rule(
std::optional<ValueRefList> reduce_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto& reduce = op.cast_final_safe<Reduce>();
......@@ -251,7 +251,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule(
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
ValueRefList ret(1);
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0]);
}
......@@ -261,7 +261,7 @@ std::optional<std::vector<ValueRef>> reduce_grad_rule(
return imperative::apply(ApplyOp(op), inputs);
}
std::optional<std::vector<ValueRef>> addAxis_grad_rule(
std::optional<ValueRefList> addAxis_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& addAxis = op.cast_final_safe<AddAxis>();
......@@ -274,7 +274,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule(
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
ValueRefList ret(1);
if (grad && flag_) {
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
......@@ -284,7 +284,7 @@ std::optional<std::vector<ValueRef>> addAxis_grad_rule(
return imperative::apply(op, inputs);
}
std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
std::optional<ValueRefList> removeAxis_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
auto&& removeAxis = op.cast_final_safe<RemoveAxis>();
......@@ -297,7 +297,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
ValueRefList ret(1);
if (grad && flag_) {
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
......@@ -307,7 +307,7 @@ std::optional<std::vector<ValueRef>> removeAxis_grad_rule(
return imperative::apply(op, inputs);
}
std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
std::optional<ValueRefList> fastpathcopy_grad_rule(
const OpDef& op, Span<ValueRef> inputs, Span<bool> inputs_require_grad,
CustomBackward& backward) {
mgb_assert(inputs.size() == 1);
......@@ -316,7 +316,7 @@ std::optional<std::vector<ValueRef>> fastpathcopy_grad_rule(
maker.backward([](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
std::vector<ValueRef> ret(1);
ValueRefList ret(1);
if (grad) {
ret[0] = grad;
}
......
......@@ -25,24 +25,23 @@ private:
py::function m_hook_fn;
int m_enabled = 0;
std::vector<ValueRef> apply_module_trace_hook(
const OpDef& op, Span<ValueRef> input_values) {
ValueRefList apply_module_trace_hook(const OpDef& op, Span<ValueRef> input_values) {
py::list input_tws;
for (auto&& input_value : input_values) {
input_tws.append(TensorWrapper::make(py_tensor_type, input_value));
}
py::list output_tws = m_hook_fn(py::cast(op.shared_from_this()), *input_tws);
std::vector<ValueRef> outputs;
ValueRefList outputs(output_tws.size());
auto it = outputs.begin();
for (auto&& output_tw : output_tws) {
outputs.push_back(
TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data());
*(it++) = TensorWrapper::try_cast(output_tw.ptr())->m_tensor->data();
}
return outputs;
}
public:
ModuleTraceTransformation(py::function hook_fn) : m_hook_fn(hook_fn) {}
std::vector<ValueRef> apply_transformation(
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (op.is<ApplyOp>() && m_enabled > 0) {
auto outputs = apply_module_trace_hook(op.cast<ApplyOp>().op(), inputs);
......
......@@ -87,7 +87,7 @@ PyObject* py_apply(
--nargs;
auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
SmallVector<ValueRef, 64> tensors(nargs);
SmallVector<ValueRef, 8> tensors(nargs);
if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
// swap to a special context to reuse scalar handle
......@@ -100,16 +100,15 @@ PyObject* py_apply(
Transformation::top());
std::make_shared<ScalarTransformation>()->register_at(
Transformation::top());
SmallVector<ValueRef> inputs(nargs);
for (size_t i = 0; i < nargs; ++i) {
auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
ValueRef input = SymbolValue::make(py_input->m_node);
if (py_input->is_scalar) {
input = ScalarValue::make(input);
}
inputs[i] = input;
tensors[i] = input;
}
auto outputs = imperative::apply(*op, inputs);
auto outputs = imperative::apply(*op, tensors);
auto ret = pybind11::tuple(outputs.size());
auto typeobj = py::handle(args[0]).get_type();
for (size_t i = 0; i < outputs.size(); ++i) {
......@@ -140,7 +139,7 @@ PyObject* py_apply(
}
}
auto outputs = imperative::apply(ApplyOp(*op), {tensors.data(), nargs});
auto outputs = imperative::apply(*op, tensors);
size_t nout = outputs.size();
auto ret = py::tuple(nout);
for (size_t i = 0; i < nout; ++i) {
......@@ -214,16 +213,10 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
if (!name.empty()) {
m_tensor->reset(
imperative::apply(RenameValue(name), m_tensor->data())[0]);
mgb_assert(
((std::string&)*m_tensor->data().name()) == name,
"result name incorrect");
}
if (data.ndim() == 0) {
mgb_assert(m_tensor->is_scalar(), "result should be scalar");
}
}
}
mgb_assert(m_tensor->data());
}
PyObject* TensorWrapper::module_trace_info() {
......@@ -1384,15 +1377,20 @@ void init_tensor(py::module m) {
std::function<bool(py::object, py::object)> array_comparator;
bool compare_value(ValueRef lhs, ValueRef rhs) {
if (!lhs.shape()->eq(*rhs.shape())) {
auto lvalue = lhs.numpy();
auto rvalue = rhs.numpy();
if (lvalue->shape() != rvalue->shape()) {
return false;
}
HostTensorND lvalue = lhs.numpy()->as_nd(true);
HostTensorND rvalue = rhs.numpy()->as_nd(true);
if (lvalue->shape().is_scalar()) {
return lvalue->item() == rvalue->item();
}
HostTensorND lnd = lvalue->as_nd(true);
HostTensorND rnd = rvalue->as_nd(true);
auto larr = py::reinterpret_steal<py::array>(
npy::ndarray_from_tensor(lvalue, npy::ShareType::TRY_SHARE));
npy::ndarray_from_tensor(lnd, npy::ShareType::TRY_SHARE));
auto rarr = py::reinterpret_steal<py::array>(
npy::ndarray_from_tensor(rvalue, npy::ShareType::TRY_SHARE));
npy::ndarray_from_tensor(rnd, npy::ShareType::TRY_SHARE));
return array_comparator(larr, rarr);
}
......@@ -1539,6 +1537,19 @@ void init_tensor(py::module m) {
}
});
m.def("reduce_to_scalar", [](py::object op, py::object tensor) {
auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto make_scalar_shape = [&](CompNode device) {
return imperative::apply(
CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
HostStorage::make(device))[0];
};
auto output = imperative::apply(
*op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(),
make_scalar_shape(tw->m_tensor->comp_node()))[0];
return TensorWrapper::make(py_tensor_type, output);
});
m.def("name_tensor", [](std::string name, py::object tensor) {
auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto output = imperative::apply(TraceMarkVar(name), tw->m_tensor->data())[0];
......@@ -1546,9 +1557,9 @@ void init_tensor(py::module m) {
});
m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
SmallVector<ValueRef> values;
for (auto&& tensor : tensors) {
values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
ValueRefList values(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
}
auto outputs = imperative::apply(GetGradKey(), values);
if (outputs[0].is<GradKeyValue>()) {
......@@ -1559,9 +1570,9 @@ void init_tensor(py::module m) {
});
m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
SmallVector<ValueRef> values;
for (auto&& tensor : tensors) {
values.push_back(tensor.cast<TensorWrapper>().m_tensor->data());
ValueRefList values(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
}
auto outputs = imperative::apply(GetGradKey(), values);
if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) {
......@@ -1578,7 +1589,7 @@ void init_tensor(py::module m) {
mgb_assert(GradKeyWrapper::wrap_t::type().isinstance(py_key.ptr()));
auto* key = reinterpret_cast<GradKeyWrapper::wrap_t*>(py_key.ptr())->inst();
GenericFunction generic_backward_fn =
[backward_fn](Span<ValueRef> output_grads) -> std::vector<ValueRef> {
[backward_fn](Span<ValueRef> output_grads) -> ValueRefList {
py::list output_grad_tws;
for (auto&& output_grad : output_grads) {
if (output_grad) {
......@@ -1589,23 +1600,25 @@ void init_tensor(py::module m) {
}
}
py::tuple input_grad_tws = backward_fn(*output_grad_tws);
std::vector<ValueRef> input_grads;
for (auto&& input_grad_tw : input_grad_tws) {
ValueRefList input_grads(input_grad_tws.size());
for (size_t i = 0; i < input_grad_tws.size(); ++i) {
auto input_grad_tw = input_grad_tws[i];
if (!input_grad_tw.is_none()) {
input_grads.push_back(
py::cast<TensorWrapper>(input_grad_tw).m_tensor->data());
input_grads[i] =
py::cast<TensorWrapper>(input_grad_tw).m_tensor->data();
} else {
input_grads.push_back({});
input_grads[i] = {};
}
}
return input_grads;
};
SmallVector<ValueRef> values;
for (auto&& input : inputs) {
values.push_back(input.cast<TensorWrapper>().m_tensor->data());
ValueRefList values(inputs.size() + outputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
}
for (auto&& output : outputs) {
values.push_back(output.cast<TensorWrapper>().m_tensor->data());
for (size_t i = 0; i < outputs.size(); ++i) {
values[i + inputs.size()] =
outputs[i].cast<TensorWrapper>().m_tensor->data();
}
auto wrapped_output_values = imperative::apply(
SetGrad(key->m_key, generic_backward_fn, inputs.size()), values);
......
......@@ -39,7 +39,7 @@ namespace mgb::imperative::python {
extern interpreter::Interpreter::Channel* interpreter_for_py;
extern PyTypeObject* py_tensor_type;
struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj {
struct Tensor : NonCopyableObj {
private:
std::string m_name;
ValueRef m_data;
......@@ -52,7 +52,7 @@ public:
~Tensor() = default;
inline std::shared_ptr<Tensor> copy() {
auto ret = std::make_shared<Tensor>(m_data.unwrap());
auto ret = std::make_shared<Tensor>(m_data);
ret->m_name = m_name;
return ret;
}
......
......@@ -11,7 +11,15 @@
#pragma once
#include <optional>
#include <string>
#include "pybind11/pybind11.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/value.h"
#include "megbrain/utils/small_vector.h"
namespace mgb::imperative::python {
struct TransformationManager {
......@@ -58,4 +66,14 @@ struct TransformationManager {
return sl_instance;
}
};
class PyValue final : public MixinValueImpl<PyValue, pybind11::object> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const {
return pybind11::str((const pybind11::object&)*this).cast<std::string>();
}
};
} // namespace mgb::imperative::python
......@@ -45,7 +45,7 @@ CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
}
auto CreateTensor::parse(Span<ValueRef> inputs) -> Args {
auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
Args result;
for (auto&& input : inputs) {
if (auto host_storage = input.as_ref<HostStorage>()) {
......
......@@ -16,70 +16,67 @@
#include "megbrain/imperative/utils/map.h"
namespace mgb {
void imperative_log_profile_begin(const char* message);
void imperative_log_profile(const char* message);
void imperative_log_profile_end(const char* message);
namespace imperative {
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs) {
static bool log_dispatch = MGB_GETENV("MGE_LOG_OP_DISPATCH");
bool enable_watch = ValueRef::any_watching();
auto& context = Transformation::get_context();
size_t& depth = context.next_transformation;
static const char tabs_storage[] = "\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t";
const char* tabs = tabs_storage + sizeof(tabs_storage) / sizeof(char) - depth - 1;
bool log_current_dispatch = log_dispatch;
if (enable_watch) {
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
if (input.watching()) {
log_current_dispatch = true;
mgb_log_debug("%sinput[%zu] is %s", tabs, i, input.to_string().c_str());
debug::notify_event("apply");
}
}
}
// entrance
std::vector<ValueRef> outputs;
if (depth >= context.transformations.size()) {
// fallback
if (log_current_dispatch) {
mgb_log_debug(
"%sfallback apply %s in %s", tabs, op.to_string().c_str(),
imperative::to_string(inputs).c_str());
namespace {
MGB_NOINLINE void copy_outputs(
ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) {
size_t nr_outputs = outputs.size();
if (mgb_likely(nr_outputs == 1)) {
ValueRef output_copy;
output_copy = outputs[0];
allocator.clear();
outputs = ValueRefList({output_copy});
} else if (!outputs.empty()) {
SmallVector<ValueRef> outputs_copy(nr_outputs);
for (size_t i = 0; i < nr_outputs; ++i) {
outputs_copy[i] = outputs[i];
}
outputs = op.fallback(inputs);
outputs.clear();
allocator.clear();
outputs = {outputs_copy.begin(), outputs_copy.end()};
} else {
// dispatch to stack top
auto& transformation = *context.transformations[depth];
++depth;
context.frames.push_back({op, inputs});
CleanupGuard _{[&] {
context.frames.pop_back();
--depth;
}};
if (log_current_dispatch) {
mgb_log_debug(
"%s%s apply %s in %s", tabs, transformation.name().c_str(),
op.to_string().c_str(), imperative::to_string(inputs).c_str());
}
outputs = transformation.apply_transformation(op, inputs);
allocator.clear();
}
if (log_current_dispatch) {
mgb_log_debug("%sreturn %s", tabs, imperative::to_string(outputs).c_str());
}
} // namespace
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
auto& context = Transformation::get_context();
size_t& depth = context.next_transformation;
bool top = depth == 0;
auto outputs = ([&] {
if (mgb_unlikely(depth >= context.transformations.size())) {
return op.fallback(inputs);
} else {
auto& transformation = *context.transformations[depth++];
CleanupGuard _{[&] { --depth; }};
return transformation.apply_transformation(op, inputs);
}
})();
if (mgb_unlikely(top)) {
copy_outputs(context.allocator, outputs);
}
return outputs;
}
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs) {
ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
return imperative::apply(ApplyOp{def}, inputs);
}
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) {
ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
SmallVector<ValueRef> inputs_storage;
for (size_t i = 0; i < inputs.size(); ++i) {
inputs_storage.push_back(inputs[i]);
}
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs,
size_t) {
auto outputs = imperative::apply(ApplyOp(*op), inputs);
auto outputs = imperative::apply(*op, inputs);
return SmallVector<ValueRef>(outputs.begin(), outputs.end());
};
auto make_const = [](TensorPtr constant) -> ValueRef {
......@@ -101,7 +98,7 @@ std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs) {
DeviceStorage::make(device_value.storage()))[0];
};
auto outputs = graph.apply(inputs_storage, apply_functor, make_const);
return {outputs.begin(), outputs.end()};
return ValueRefList{outputs.begin(), outputs.end()};
}
} // namespace imperative
......
......@@ -126,7 +126,7 @@ public:
m_frames[m_frames.size() - 1 - i] = {node, node->version()};
node = node->parent();
}
mgb_assert(node->is_root(), "");
mgb_assert(node->is_root());
}
Trace() = default;
std::string to_string() const {
......
......@@ -3,7 +3,7 @@
namespace mgb {
namespace imperative {
std::vector<ValueRef> Operator::fallback(Span<ValueRef> inputs) const {
ValueRefList Operator::fallback(Span<ValueRef> inputs) const {
mgb_throw(MegBrainError, "no fallback implementation for %s", to_string().c_str());
}
......
......@@ -99,19 +99,22 @@ Tensor::Tensor(
Tensor::Tensor(const HostTensorND& hv) : Tensor(hv.layout(), hv.comp_node()) {
constexpr int size_threshold = TensorShape::MAX_NDIM;
if (hv.layout().total_nr_elems() <= size_threshold) {
size_t nr_elems = hv.layout().total_nr_elems();
if (nr_elems <= size_threshold) {
m_value = hv;
}
MGB_RECORD_EVENT(
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
dev_tensor().raw_ptr());
dev_tensor().copy_from_fixlayout(hv);
// even though hv is saved in m_value, Tensor itself could be
// released before copy completes
MGB_RECORD_EVENT(
profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(),
hv.raw_ptr(), dev_tensor().raw_ptr());
AsyncReleaser::inst()->add(hv);
if (nr_elems) {
MGB_RECORD_EVENT(
profiler::HostToDeviceEvent, hv.layout(), hv.comp_node(), hv.raw_ptr(),
dev_tensor().raw_ptr());
dev_tensor().copy_from_fixlayout(hv);
// even though hv is saved in m_value, Tensor itself could be
// released before copy completes
MGB_RECORD_EVENT(
profiler::HostToDeviceFinishEvent, hv.layout(), hv.comp_node(),
hv.raw_ptr(), dev_tensor().raw_ptr());
AsyncReleaser::inst()->add(hv);
}
}
Tensor::Tensor(const DeviceTensorND& dv, const HostTensorND& hv) {
......
......@@ -310,7 +310,8 @@ struct ChromeTimelineEventVisitor : EventVisitor<ChromeTimelineEventVisitor> {
} else if constexpr (std::is_same_v<TEvent, TensorGetPropEvent>) {
new_host_event("TensorGetProp", 'X')
.dur(0)
.args(current_tensor->detail(current->time));
.args(current_tensor->detail(current->time))
.arg("kind", imperative::to_string(event.prop));
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropEvent>) {
new_host_event("TensorWaitProp", 'B');
} else if constexpr (std::is_same_v<TEvent, TensorWaitPropFinishEvent>) {
......
......@@ -15,71 +15,109 @@
namespace mgb {
namespace imperative {
std::vector<ValueRef> InterpreterTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* op_val = op.as<ApplyOp>()) {
if (op_val->op().same_type<FastpathCopy>()) {
return {inputs[0]};
}
SmallVector<Handle> input_handles;
SmallVector<Handle> output_handles;
CleanupGuard _{[&] {
for (auto handle : output_handles) {
if (handle) {
m_channel->del(handle);
}
DTypeValue::ref_t InterpreterInfo::dtype() const {
if (!m_dtype) {
m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle()));
}
return m_dtype;
}
CompNodeValue::ref_t InterpreterInfo::comp_node() const {
if (!m_comp_node) {
m_comp_node = CompNodeValue::make(
handle()->channel()->get_device(handle()->handle()));
}
return m_comp_node;
}
ShapeValue::ref_t InterpreterInfo::shape() const {
if (!m_shape) {
m_shape = ShapeValue::make(
ValueShape::from(handle()->channel()->get_shape(handle()->handle())));
}
return m_shape;
}
ValueRefList InterpreterTransformation::apply_op(
const ApplyOp& apply_op, Span<ValueRef> inputs) {
if (apply_op.op().same_type<FastpathCopy>()) {
return {inputs[0]};
}
SmallVector<Handle> input_handles;
SmallVector<Handle> output_handles;
CleanupGuard _{[&] {
for (auto handle : output_handles) {
if (handle) {
m_channel->del(handle);
}
}};
for (auto input : inputs) {
input_handles.push_back(*input.cast<InterpreterValue>().handle());
}
output_handles =
m_channel->apply_op(op_val->op().shared_from_this(), input_handles);
std::vector<ValueRef> outputs;
for (auto& handle : output_handles) {
outputs.push_back(InterpreterValue::make(share_handle(handle)));
handle = nullptr;
}
return outputs;
}};
for (auto input : inputs) {
input_handles.push_back(input.cast<InterpreterValue>().handle()->handle());
}
output_handles =
m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
ValueRefList outputs(output_handles.size());
for (size_t i = 0; i < output_handles.size(); ++i) {
outputs[i] = InterpreterValue::make(share_handle(output_handles[i]));
output_handles[i] = nullptr;
}
return outputs;
}
ValueRefList InterpreterTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
auto& input = inputs.item().cast<InterpreterValue>();
ValueRef output;
switch (get_attr.attr()) {
case GetAttr::DType:
output = input.dtype();
break;
case GetAttr::Shape:
output = input.shape();
break;
case GetAttr::Device:
output = input.comp_node();
break;
case GetAttr::Value:
output = HostValue::make(m_channel->get_value(input.handle()->handle()));
break;
case GetAttr::Data:
output = DeviceValue::make(
m_channel->get_dev_tensor(input.handle()->handle()));
break;
default:
mgb_throw(
MegBrainError, "Interpreter: malformed GetAttr: %s",
get_attr.to_string().c_str());
}
return {output};
}
ValueRefList InterpreterTransformation::apply_create_tensor(
const CreateTensor& create_tensor, Span<ValueRef> inputs) {
auto args = create_tensor.parse(inputs);
if (!args.device) {
// implies H2D
mgb_assert(args.host, "neither host and device value is valid");
return {InterpreterValue::make(share_handle(
m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
} else {
return {InterpreterValue::make(share_handle(m_channel->put(
*args.device, args.host ? *args.host : HostTensorND())))};
}
}
ValueRefList InterpreterTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* op_val = op.as<ApplyOp>()) {
return apply_op(*op_val, inputs);
} else if (auto* get_attr = op.as<GetAttr>()) {
Handle handle = *inputs[0].cast<InterpreterValue>().handle();
ValueRef output;
switch (get_attr->attr()) {
case GetAttr::DType:
output = DTypeValue::make(m_channel->get_dtype(handle));
break;
case GetAttr::Shape:
output = ShapeValue::make(
ValueShape::from(m_channel->get_shape(handle)));
break;
case GetAttr::Device:
output = CompNodeValue::make(m_channel->get_device(handle));
break;
case GetAttr::Value:
output = HostValue::make(m_channel->get_value(handle));
break;
case GetAttr::Data:
output = DeviceValue::make(m_channel->get_dev_tensor(handle));
break;
default:
mgb_throw(
MegBrainError, "Interpreter: malformed GetAttr: %s",
op.to_string().c_str());
}
return {output};
return apply_get_attr(*get_attr, inputs);
} else if (auto* create_tensor = op.as<CreateTensor>()) {
auto args = create_tensor->parse(inputs);
if (!args.device) {
// implies H2D
mgb_assert(args.host, "neither host and device value is valid");
return {InterpreterValue::make(share_handle(
m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
} else {
return {InterpreterValue::make(share_handle(m_channel->put(
*args.device, args.host ? *args.host : HostTensorND())))};
}
return apply_create_tensor(*create_tensor, inputs);
} else if (auto* dtr_command = op.as<DTRCommand>()) {
auto handle = *inputs[0].cast<InterpreterValue>().handle();
auto handle = inputs[0].cast<InterpreterValue>().handle()->handle();
switch (dtr_command->kind()) {
case DTRCommand::Drop:
m_channel->drop(handle);
......
......@@ -64,12 +64,13 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
size_t count = std::count_if(
save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
if (!backward_graph->precomp.empty()) {
SmallVector<ValueRef> inputs_and_outputs;
ValueRefList inputs_and_outputs(inputs.size() + outputs.size());
auto it = inputs_and_outputs.begin();
for (auto&& input : inputs) {
inputs_and_outputs.push_back(input);
*it++ = input;
}
for (auto&& output : outputs) {
inputs_and_outputs.push_back(output);
*it++ = output;
}
auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
closure.reserve(precomp.size() + count);
......@@ -89,7 +90,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
}
}
void BackwardGraphWithClosure::operator()(
std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
ValueRef args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& value : closure) {
......@@ -120,7 +121,7 @@ void BackwardGraphWithClosure::operator()(
}
void CustomBackward::operator()(
std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
size_t nargs = grads.size();
ValueRef args[nargs];
for (size_t i = 0; i < nargs; ++i) {
......@@ -201,9 +202,10 @@ void GradKey::backward() {
mgb_throw(AssertionError, "invalid backward");
} else {
mgb_assert(grad_fn->m_slots.size() > 0);
std::vector<ValueRef> grads;
ValueRefList grads (grad_fn->m_slots.size());
auto iter = grads.begin();
for (auto&& slot : grad_fn->m_slots) {
grads.push_back(slot.m_grad);
*iter++ = slot.m_grad;
}
backward(grads, grad_receiver);
}
......@@ -254,21 +256,28 @@ void GradKey::freeze() {
m_frozen = true;
}
std::vector<ValueRef> GradTransformation::apply_transformation(
ValueRefList GradTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> {
SmallVector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
unwrapped_inputs.push_back(grad_value->m_value);
auto fallback = [&] {
ValueRefList unwrapped_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto grad_value = as_grad_value(inputs[i])) {
unwrapped_inputs[i] = grad_value->m_value;
} else {
unwrapped_inputs.push_back(input);
unwrapped_inputs[i] = inputs[i];
}
}
return unwrapped_inputs;
return imperative::apply(op, unwrapped_inputs);
};
if (auto* get_attr = op.as<GetAttr>()) {
if (auto grad_value = as_grad_value(inputs.item())) {
return imperative::apply(op, grad_value->m_value);
} else {
return imperative::apply(op, inputs);
}
}
if (m_suppressed) {
return imperative::apply(op, unwrap_inputs(inputs));
return fallback();
}
if (auto* op_val = op.as<ApplyOp>()) {
size_t nr_require_grad = 0;
......@@ -284,20 +293,21 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
if (nr_require_grad == 0) {
return imperative::apply(op, inputs);
}
SmallVector<ValueRef> captured_inputs;
SmallVector<bool> inputs_require_grad;
ValueRefList captured_inputs(inputs.size());
SmallVector<bool> inputs_require_grad(inputs.size());
// capture value so that trace could assume input as same
auto capture_value = [](ValueRef value) {
// TODO: fastpath copy shouldn't be an OpDef
return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
};
for (auto& input : inputs) {
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
if (auto grad_value = as_grad_value(input)) {
captured_inputs.push_back(capture_value(grad_value->m_value));
inputs_require_grad.push_back(true);
captured_inputs[i] = capture_value(grad_value->m_value);
inputs_require_grad[i] = true;
} else {
captured_inputs.push_back(capture_value(input));
inputs_require_grad.push_back(false);
captured_inputs[i] = capture_value(input);
inputs_require_grad[i] = false;
}
}
decltype(std::declval<GradFn>().m_backward) backward_storage;
......@@ -373,9 +383,11 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
mgb_assert(!grad_fn->m_slots.empty());
m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
return outputs;
} else if (op.is<CreateTensor>()) {
return imperative::apply(op, inputs);
} else if (auto* attach_grad = op.as<AttachGrad>()) {
if (!has_key(attach_grad->key())) {
return imperative::apply(op, unwrap_inputs(inputs));
return fallback();
}
auto tensor = inputs[0];
GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
......@@ -386,7 +398,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
return {record_grad(output)};
} else if (auto* grad_backward = op.as<GradBackward>()) {
if (!has_key(grad_backward->key())) {
return imperative::apply(op, unwrap_inputs(inputs));
return fallback();
}
size_t nr_grads = inputs.size() / 2;
mgb_assert(nr_grads * 2 == inputs.size());
......@@ -416,7 +428,7 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
backward.m_output_attrs =
SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
backward.m_backward = set_grad->grad_fn();
std::vector<ValueRef> outputs;
ValueRefList outputs(nr_outputs);
grad_fn->m_key = m_key;
grad_fn->m_slots.resize(nr_outputs);
grad_fn->m_dests.reserve(nr_inputs);
......@@ -439,13 +451,13 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
} else {
grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
}
outputs.push_back(record_grad(grad_value));
outputs[i] = record_grad(grad_value);
}
m_key->m_tape.push_back({grad_fn, nullptr});
return outputs;
} else if (auto* gbc = op.as<GetBackwardColsure>()) {
if (gbc->key() != m_key) {
return imperative::apply(op, unwrap_inputs(inputs));
return fallback();
}
return {FunctionValue::make(make_backward_closure(inputs))};
} else if (op.is<DetachGrad>()) {
......@@ -471,21 +483,8 @@ std::vector<ValueRef> GradTransformation::apply_transformation(
} else {
return imperative::apply(op, inputs);
}
} else if (op.is<CreateTensor>()) {
return imperative::apply(op, inputs);
} else {
SmallVector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
unwrapped_inputs.push_back(grad_value->m_value);
} else {
unwrapped_inputs.push_back(input);
}
}
auto outputs = imperative::apply(
op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());
return outputs;
return fallback();
}
}
......@@ -500,8 +499,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
y_slots.emplace_back();
}
}
GenericFunction closure = [grad_key,
y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> {
GenericFunction closure = [grad_key, y_slots](Span<ValueRef> dys) -> ValueRefList {
size_t nr_grads = y_slots.size();
mgb_assert(dys.size() == nr_grads);
for (size_t i = 0; i < nr_grads; ++i) {
......
......@@ -21,7 +21,7 @@
namespace mgb {
namespace imperative {
std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
ValueRefList LazyEvalTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* op_val = op.as<ApplyOp>()) {
static std::unordered_set<Typeinfo*> mm_io_ops = {
......@@ -59,9 +59,9 @@ std::vector<ValueRef> LazyEvalTransformation::apply_transformation(
mgb_assert(!output_nodes.empty());
m_io_link = SymbolVar(output_nodes[0]);
}
std::vector<ValueRef> outputs;
for (auto&& output_node : output_nodes) {
outputs.push_back(record_var(output_node));
ValueRefList outputs(output_nodes.size());
for (size_t i = 0; i < output_nodes.size(); ++i) {
outputs[i] = record_var(output_nodes[i]);
}
return outputs;
} else if (auto* create_tensor = op.as<CreateTensor>()) {
......
......@@ -19,26 +19,8 @@ namespace imperative {
namespace {
using ScalarRule = std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>;
static std::unordered_map<
Typeinfo*, std::function<std::vector<ValueRef>(const OpDef&, Span<ValueRef>)>>
scalar_rules;
ValueRef unwrap_input(ValueRef input) {
if (auto scalar_input = input.as_ref<ScalarValue>()) {
return scalar_input->value();
} else {
return input;
}
}
std::vector<ValueRef> unwrap_inputs(Span<ValueRef> inputs) {
std::vector<ValueRef> unwrapped_inputs;
for (auto&& input : inputs) {
unwrapped_inputs.push_back(unwrap_input(input));
}
return unwrapped_inputs;
}
using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>);
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;
ValueRef make_scalar_shape(CompNode device) {
HostTensorND scalar_shape(device, {1}, dtype::Int32());
......@@ -49,9 +31,6 @@ ValueRef make_scalar_shape(CompNode device) {
}
bool is_scalar_shape(ValueRef shape) {
if (shape.is<ScalarValue>()) {
return false;
}
// may have performance issue
auto shape_of_shape = shape.shape();
if (!shape_of_shape) {
......@@ -61,74 +40,65 @@ bool is_scalar_shape(ValueRef shape) {
return *shape_of_shape == ValueShape{0};
}
template <typename T>
void register_scalar_rule(std::vector<ValueRef> (*rule)(const T&, Span<ValueRef>)) {
scalar_rules[T::typeinfo()] = [rule](const OpDef& def, Span<ValueRef> inputs) {
return (*rule)(def.cast_final_safe<T>(), inputs);
template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)>
void register_scalar_rule() {
scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask);
};
}
std::vector<ValueRef> elemwise_rule(const Elemwise& elem, Span<ValueRef> inputs) {
template <typename TOpDef, size_t nr_inputs>
ValueRefList elemwise_rule(
const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) {
if constexpr (nr_inputs != 0) {
mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch");
}
bool all_scalar = true;
for (auto&& input : inputs) {
if (!input.is<ScalarValue>()) {
for (auto&& input_mask : inputs_mask) {
if (!input_mask) {
all_scalar = false;
break;
}
}
auto output = imperative::apply(elem, unwrap_inputs(inputs))[0];
auto outputs = imperative::apply(op_def, inputs);
if (all_scalar) {
return {ScalarValue::make(output)};
} else {
return {output};
outputs[0] = ScalarValue::make(outputs[0]);
}
return outputs;
}
std::vector<ValueRef> remove_axis_rule(
const RemoveAxis& remove_axis, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 1);
mgb_assert(!inputs[0].is<ScalarValue>());
auto output = imperative::apply(remove_axis, inputs)[0];
bool is_scalar = inputs[0].shape()->ndim == remove_axis.axis.size();
ValueRefList remove_axis_rule(
const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
mgb_assert(!inputs_mask.item());
bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size();
if (is_scalar && remove_axis.axis.size() == 1) {
return {ScalarValue::make(inputs.item())};
}
auto outputs = imperative::apply(remove_axis, inputs);
if (is_scalar) {
return {ScalarValue::make(output)};
} else {
return {output};
outputs[0] = ScalarValue::make(outputs[0]);
}
return outputs;
}
std::vector<ValueRef> reduce_rule(const Reduce& reduce, Span<ValueRef> inputs) {
ValueRefList reduce_rule(
const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) {
if (inputs.size() == 1) {
return imperative::apply(reduce, unwrap_inputs(inputs));
return imperative::apply(reduce, inputs);
}
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) {
auto unwrapped_input = unwrap_input(inputs[0]);
CompNode device = *unwrapped_input.device();
return {ScalarValue::make(imperative::apply(
reduce, unwrapped_input, make_scalar_shape(device))[0])};
}
auto output = imperative::apply(reduce, unwrap_inputs(inputs))[0];
if (is_scalar) {
return {ScalarValue::make(output)};
} else {
return {output};
}
}
std::vector<ValueRef> typecvt_rule(const TypeCvt& typecvt, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 1);
if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
CompNode device = *inputs[0].device();
return {ScalarValue::make(
imperative::apply(typecvt, scalar_input->value())[0])};
} else {
return imperative::apply(typecvt, inputs);
imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])};
}
return imperative::apply(reduce, inputs);
}
std::vector<ValueRef> collective_comm_rule(
const CollectiveComm& collective_comm, Span<ValueRef> inputs) {
ValueRefList collective_comm_rule(
const CollectiveComm& collective_comm, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
mgb_assert(inputs.size() == 1);
static std::unordered_set<CollectiveComm::Mode> modes = {
CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
......@@ -138,17 +108,17 @@ std::vector<ValueRef> collective_comm_rule(
if (modes.count(collective_comm.mode) == 0) {
return imperative::apply(collective_comm, inputs);
}
if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
return {ScalarValue::make(
imperative::apply(collective_comm, scalar_input->value())[0])};
if (inputs_mask.item()) {
return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])};
} else {
return imperative::apply(collective_comm, inputs);
}
}
std::vector<ValueRef> param_pack_split_rule(
const ParamPackSplit& param_pack_split, Span<ValueRef> inputs) {
auto outputs = imperative::apply(param_pack_split, unwrap_inputs(inputs));
ValueRefList param_pack_split_rule(
const ParamPackSplit& param_pack_split, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
auto outputs = imperative::apply(param_pack_split, inputs);
size_t nr_outputs = outputs.size();
mgb_assert(nr_outputs == param_pack_split.shapes.size());
for (size_t i = 0; i < nr_outputs; ++i) {
......@@ -159,29 +129,28 @@ std::vector<ValueRef> param_pack_split_rule(
return outputs;
}
std::vector<ValueRef> dot_rule(const Dot& dot, Span<ValueRef> inputs) {
return {ScalarValue::make(imperative::apply(dot, unwrap_inputs(inputs))[0])};
ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) {
return {ScalarValue::make(imperative::apply(dot, inputs)[0])};
}
std::vector<ValueRef> add_axis_rule(const AddAxis& add_axis, Span<ValueRef> inputs) {
ValueRefList add_axis_rule(
const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
mgb_assert(inputs.size() == 1);
if (auto scalar_input = inputs[0].as_ref<ScalarValue>()) {
if (inputs_mask.item()) {
mgb_assert(add_axis.axis[0] == 0);
if (add_axis.axis.size() == 1) {
return {scalar_input->value()};
return {inputs[0]};
} else {
std::vector<int32_t> axis(add_axis.axis.begin() + 1, add_axis.axis.end());
return imperative::apply(
ApplyOp(*AddAxis::make(axis, add_axis.scope())),
scalar_input->value());
return imperative::apply(*AddAxis::make(axis, add_axis.scope()), inputs[0]);
}
} else {
return imperative::apply(add_axis, inputs);
}
}
std::vector<ValueRef> remote_recv_rule(
const RemoteRecv& remote_recv, Span<ValueRef> inputs) {
ValueRefList remote_recv_rule(
const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) {
if (remote_recv.shape.empty()) {
std::vector<int32_t> shape = {1};
auto remote_recv_no_scalar = RemoteRecv::make(
......@@ -189,32 +158,32 @@ std::vector<ValueRef> remote_recv_rule(
remote_recv.rank_from, remote_recv.cn, shape, remote_recv.dtype,
remote_recv.backend);
remote_recv_no_scalar->set_scope(remote_recv.scope());
return imperative::apply(
ApplyOp(*remote_recv_no_scalar), unwrap_inputs(inputs));
return imperative::apply(ApplyOp(*remote_recv_no_scalar), inputs);
} else {
return imperative::apply(remote_recv, unwrap_inputs(inputs));
return imperative::apply(remote_recv, inputs);
}
}
std::vector<ValueRef> check_no_finite_rule(
const CheckNonFinite& check_no_finite, Span<ValueRef> inputs) {
auto outputs = imperative::apply(check_no_finite, unwrap_inputs(inputs));
ValueRefList check_no_finite_rule(
const CheckNonFinite& check_no_finite, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
auto outputs = imperative::apply(check_no_finite, inputs);
mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
outputs.back() = ScalarValue::make(outputs.back());
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs[i].is<ScalarValue>()) {
if (inputs_mask[i]) {
outputs[i] = ScalarValue::make(outputs[i]);
}
}
return outputs;
}
std::vector<ValueRef> subtensor_rule(
const Subtensor& subtensor, Span<ValueRef> inputs) {
ValueRefList subtensor_rule(
const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) {
mgb_assert(inputs.size() >= 1);
auto input = inputs[0];
bool is_scalar;
mgb_assert(!input.is<ScalarValue>(), "subtensor shouldn't have scalar input");
mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
if (auto shape = input.shape()) {
size_t ndim = input.shape()->ndim;
for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
......@@ -226,25 +195,25 @@ std::vector<ValueRef> subtensor_rule(
} else {
is_scalar = false;
}
auto output = imperative::apply(subtensor, unwrap_inputs(inputs))[0];
auto outputs = imperative::apply(subtensor, inputs);
if (is_scalar) {
return {ScalarValue::make(output)};
} else {
return {output};
outputs[0] = ScalarValue::make(outputs[0]);
}
return outputs;
}
std::vector<ValueRef> get_var_shape_rule(
const GetVarShape& get_var_shape, Span<ValueRef> inputs) {
ValueRefList get_var_shape_rule(
const GetVarShape& get_var_shape, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
bool all_scalar = true;
mgb_assert(inputs.size() >= 1);
for (auto&& input : inputs) {
if (!input.is<ScalarValue>()) {
for (auto&& input_mask : inputs_mask) {
if (!input_mask) {
all_scalar = false;
}
}
if (all_scalar) {
auto device = inputs[0].cast<ScalarValue>().value().device();
auto device = inputs[0].device();
auto storage = HostStorage::make(*device);
// storage->ensure_size(1);
return imperative::apply(
......@@ -252,88 +221,49 @@ std::vector<ValueRef> get_var_shape_rule(
CreateTensor::Const, *device, dtype::Int32(), ValueShape{0}),
storage);
} else {
return imperative::apply(get_var_shape, unwrap_inputs(inputs));
}
}
std::vector<ValueRef> fastpath_copy_rule(
const FastpathCopy& fastpath_copy, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 1);
bool is_scalar = inputs[0].is<ScalarValue>();
auto output = imperative::apply(fastpath_copy, unwrap_inputs(inputs))[0];
if (is_scalar) {
return {ScalarValue::make(output)};
} else {
return {output};
return imperative::apply(get_var_shape, inputs);
}
}
std::vector<ValueRef> reshape_rule(const Reshape& reshape, Span<ValueRef> inputs) {
ValueRefList reshape_rule(
const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) {
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
auto unwrapped_input = inputs[0].is<ScalarValue>()
? inputs[0].cast<ScalarValue>().value()
: inputs[0];
if (is_scalar) {
return {ScalarValue::make(imperative::apply(
reshape, unwrapped_input,
make_scalar_shape(*unwrapped_input.device()))[0])};
reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
} else {
return imperative::apply(reshape, unwrap_inputs(inputs));
return imperative::apply(reshape, inputs);
}
}
std::vector<ValueRef> broadcast_rule(
const Broadcast& broadcast, Span<ValueRef> inputs) {
ValueRefList broadcast_rule(
const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) {
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
auto unwrapped_input = inputs[0].is<ScalarValue>()
? inputs[0].cast<ScalarValue>().value()
: inputs[0];
if (is_scalar) {
return {ScalarValue::make(imperative::apply(
broadcast, unwrapped_input,
make_scalar_shape(*unwrapped_input.device()))[0])};
} else {
return imperative::apply(broadcast, unwrap_inputs(inputs));
}
}
std::vector<ValueRef> copy_rule(const Copy& copy, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 1);
bool is_scalar = inputs[0].is<ScalarValue>();
if (is_scalar) {
return {ScalarValue::make(imperative::apply(copy, unwrap_inputs(inputs))[0])};
} else {
return imperative::apply(copy, unwrap_inputs(inputs));
}
}
std::vector<ValueRef> inplace_add_rule(
const InplaceAdd& inplace_add, Span<ValueRef> inputs) {
mgb_assert(inputs.size() == 4);
bool is_scalar = inputs[0].is<ScalarValue>();
if (is_scalar) {
return {ScalarValue::make(
imperative::apply(inplace_add, unwrap_inputs(inputs))[0])};
broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
} else {
return imperative::apply(inplace_add, unwrap_inputs(inputs));
return imperative::apply(broadcast, inputs);
}
}
template <typename T>
std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
ValueRefList subgraph_op_rule(
const T& op, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
// TODO: add flag instead of assume
bool all_scalar = true;
for (auto&& input : inputs) {
if (!input.is<ScalarValue>()) {
for (auto&& input_mask : inputs_mask) {
if (!input_mask) {
all_scalar = false;
}
}
auto outputs = imperative::apply(op, unwrap_inputs(inputs));
auto outputs = imperative::apply(op, inputs);
if (all_scalar) {
for (auto& output : outputs) {
output = ScalarValue::make(output);
output = scalar_type.make(output);
}
}
return outputs;
......@@ -341,67 +271,54 @@ std::vector<ValueRef> subgraph_op_rule(const T& op, Span<ValueRef> inputs) {
struct ScalarRuleRegistry {
ScalarRuleRegistry() {
register_scalar_rule(elemwise_rule);
register_scalar_rule(remove_axis_rule);
register_scalar_rule(reduce_rule);
register_scalar_rule(typecvt_rule);
register_scalar_rule(collective_comm_rule);
register_scalar_rule(param_pack_split_rule);
register_scalar_rule(dot_rule);
register_scalar_rule(add_axis_rule);
register_scalar_rule(remote_recv_rule);
register_scalar_rule(check_no_finite_rule);
register_scalar_rule(subtensor_rule);
register_scalar_rule(get_var_shape_rule);
register_scalar_rule(fastpath_copy_rule);
register_scalar_rule(reshape_rule);
register_scalar_rule(broadcast_rule);
register_scalar_rule(copy_rule);
register_scalar_rule(inplace_add_rule);
register_scalar_rule(subgraph_op_rule<SubgraphOp>);
register_scalar_rule(subgraph_op_rule<CompiledOp>);
register_scalar_rule<Elemwise, elemwise_rule<Elemwise, 0>>();
register_scalar_rule<RemoveAxis, remove_axis_rule>();
register_scalar_rule<Reduce, reduce_rule>();
register_scalar_rule<TypeCvt, elemwise_rule<TypeCvt, 1>>();
register_scalar_rule<CollectiveComm, collective_comm_rule>();
register_scalar_rule<ParamPackSplit, param_pack_split_rule>();
register_scalar_rule<Dot, dot_rule>();
register_scalar_rule<AddAxis, add_axis_rule>();
register_scalar_rule<RemoteRecv, remote_recv_rule>();
register_scalar_rule<CheckNonFinite, check_no_finite_rule>();
register_scalar_rule<Subtensor, subtensor_rule>();
register_scalar_rule<GetVarShape, get_var_shape_rule>();
register_scalar_rule<FastpathCopy, elemwise_rule<FastpathCopy, 1>>();
register_scalar_rule<Reshape, reshape_rule>();
register_scalar_rule<Broadcast, broadcast_rule>();
register_scalar_rule<Copy, elemwise_rule<Copy, 1>>();
register_scalar_rule<InplaceAdd, elemwise_rule<InplaceAdd, 4>>();
register_scalar_rule<SubgraphOp, subgraph_op_rule<SubgraphOp>>();
register_scalar_rule<CompiledOp, subgraph_op_rule<CompiledOp>>();
}
} _;
} // namespace
std::vector<ValueRef> ScalarTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto apply_op = op.as<ApplyOp>()) {
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
if (iter != scalar_rules.end()) {
return iter->second(apply_op->op(), inputs);
} else {
// TODO: repeat op
return imperative::apply(op, unwrap_inputs(inputs));
}
} else if (auto* create_tensor = op.as<CreateTensor>()) {
if (create_tensor->shape().is_scalar()) {
ValueShape scalar_shape = {1};
CreateTensor scalar_op(
create_tensor->kind(), create_tensor->device(),
create_tensor->dtype(), scalar_shape);
return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
} else {
return imperative::apply(op, inputs);
}
} else if (auto* get_attr = op.as<GetAttr>()) {
bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
auto output = imperative::apply(op, unwrap_inputs(inputs))[0];
if (!is_scalar) {
return {output};
ValueRefList ScalarTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
auto&& input = inputs.item();
bool is_scalar = input.is<ScalarValue>();
if (!is_scalar) {
return imperative::apply(get_attr, input);
}
auto unwrapped_input = input.cast<ScalarValue>().value();
if (get_attr.attr() == GetAttr::Shape) {
if (!m_empty_shape) {
m_empty_shape = ShapeValue::make();
}
switch (get_attr->attr()) {
case GetAttr::Shape: {
// Scalar Shape
return {ShapeValue::make()};
}
return {m_empty_shape};
} else {
auto outputs = imperative::apply(get_attr, unwrapped_input);
auto& output = outputs[0];
switch (get_attr.attr()) {
case GetAttr::Value: {
auto& hv = output.cast<HostValue>();
mgb_assert(
hv.shape() == ValueShape({1}),
"underlying value should has shape {1}, got %s",
hv.shape().to_string().c_str());
return {HostValue::make(hv.dtype(), ValueShape(), hv.storage())};
output = HostValue::make(hv.dtype(), ValueShape(), hv.storage());
break;
}
case GetAttr::Data: {
auto& dv = output.cast<DeviceValue>();
......@@ -409,22 +326,67 @@ std::vector<ValueRef> ScalarTransformation::apply_transformation(
dv.shape() == ValueShape({1}),
"underlying value should has shape {1}, got %s",
dv.shape().to_string().c_str());
return {DeviceValue::make(dv.dtype(), ValueShape(), dv.storage())};
output = DeviceValue::make(dv.dtype(), ValueShape(), dv.storage());
break;
}
default:
return {output};
break;
}
return outputs;
}
}
ValueRefList ScalarTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* get_attr = op.as<GetAttr>()) {
// fastpath for GetAttr
return apply_get_attr(*get_attr, inputs);
}
size_t nr_inputs = inputs.size();
ValueRefList unwrapped_inputs(nr_inputs);
bool inputs_mask[nr_inputs];
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) {
unwrapped_inputs[i] = scalar_value->value();
inputs_mask[i] = true;
} else {
unwrapped_inputs[i] = inputs[i];
inputs_mask[i] = false;
}
}
auto fallback = [&] { return imperative::apply(op, unwrapped_inputs); };
if (auto apply_op = op.as<ApplyOp>()) {
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
if (iter != scalar_rules.end()) {
return iter->second(
apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs});
} else {
// TODO: repeat op
return fallback();
}
} else if (auto* create_tensor = op.as<CreateTensor>()) {
if (create_tensor->shape().is_scalar()) {
ValueShape scalar_shape = {1};
CreateTensor scalar_op(
create_tensor->kind(), create_tensor->device(),
create_tensor->dtype(), scalar_shape);
return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
} else {
return imperative::apply(op, inputs);
}
} else if (op.as<IsScalar>()) {
return {BoolValue::make(inputs.as_array<1>()[0].is<ScalarValue>())};
mgb_assert(nr_inputs == 1);
return {BoolValue::make(inputs_mask[0])};
} else if (op.is<Operator::IdentityLike>()) {
bool is_scalar = inputs.as_array<1>()[0].is<ScalarValue>();
mgb_assert(nr_inputs == 1);
bool is_scalar = inputs_mask[0];
auto outputs = fallback();
if (is_scalar) {
return {ScalarValue::make(imperative::apply(op, unwrap_inputs(inputs))[0])};
} else {
return imperative::apply(op, inputs);
outputs[0] = ScalarValue::make(outputs[0]);
}
return outputs;
} else {
return imperative::apply(op, unwrap_inputs(inputs));
return fallback();
}
};
......
/**
* \file imperative/src/impl/transformations/tangent.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/imperative/transformations/tangent.h"
namespace mgb {
namespace imperative {
ValueRefList TangentTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* apply_op = op.as<ApplyOp>()) {
}
mgb_assert(false);
}
} // namespace imperative
} // namespace mgb
......@@ -153,7 +153,7 @@ VarNodeArray TraceResult::dump(
return output_nodes;
}
std::vector<ValueRef> TracingTransformation::apply_transformation(
ValueRefList TracingTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* op_value = op.as<ApplyOp>()) {
SmallVector<ValueRef> unwrapped_inputs;
......@@ -180,11 +180,12 @@ std::vector<ValueRef> TracingTransformation::apply_transformation(
}
const_cast<OpDef&>(op_value->op()).set_scope(scopes_join);
auto unwrapped_outputs = imperative::apply(op, unwrapped_inputs);
std::vector<ValueRef> wrapped_outputs;
ValueRefList wrapped_outputs(unwrapped_outputs.size());
SmallVector<size_t> output_ids;
for (auto&& output : unwrapped_outputs) {
for (size_t i = 0; i < unwrapped_outputs.size(); ++i) {
auto&& output = unwrapped_outputs[i];
auto wrapped_output = record_var(output, false, VarKind::Internal);
wrapped_outputs.push_back(wrapped_output);
wrapped_outputs[i] = wrapped_output;
output_ids.push_back(wrapped_output->id());
}
m_seq.push_back({op_value->op().shared_from_this(), input_ids, output_ids});
......@@ -375,6 +376,11 @@ void CompiledTransformation::compile() {
return accessor;
};
std::vector<VarAccessor> var_accessors(m_vars.size());
auto exc_setter = std::bind(
&CompiledTransformation::set_exception, this, std::placeholders::_1);
for (auto&& accessor : var_accessors) {
accessor.exc_setter = exc_setter;
}
for (auto&& item : m_seq) {
bool require_link = bool(item.op) && mm_io_ops.count(item.op->dyn_typeinfo());
VarNodeArray input_vars;
......@@ -509,8 +515,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
}
}
TracedValue::ref_t CompiledTransformation::trace_output(size_t id) {
auto traced_value = TracedValue::make(id);
auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t {
auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]);
m_weak_values.push_back(traced_value);
return traced_value;
}
......@@ -520,64 +526,99 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
return m_seq[m_pc++];
}
std::vector<ValueRef> CompiledTransformation::apply_transformation(
ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
if (!m_shape) {
trace_assert(m_accessor->shape_getter, "shape unreadable");
m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter()));
}
return m_shape;
}
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const {
if (!m_dtype) {
m_dtype = DTypeValue::make(m_var->dtype);
}
return m_dtype;
}
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const {
if (!m_comp_node) {
m_comp_node = CompNodeValue::make(m_var->device);
}
return m_comp_node;
}
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& {
return *m_accessor;
}
ValueRefList CompiledTransformation::apply_op(
const ApplyOp& apply_op, Span<ValueRef> inputs) {
auto& item = next_instruction();
trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
trace_assert(apply_op.op().is_same(*item.op), "operator mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
trace_input(item.inputs[i], inputs[i]);
}
ValueRefList outputs(item.outputs.size());
for (size_t i = 0; i < item.outputs.size(); ++i) {
outputs[i] = trace_output(item.outputs[i]);
}
return outputs;
}
ValueRefList CompiledTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
if (auto* traced_value = inputs[0].as<TracedValue>()) {
ValueRef output;
auto& var_accessor = traced_value->accessor();
switch (get_attr.attr()) {
case GetAttr::Shape:
output = traced_value->shape();
break;
case GetAttr::Data:
trace_assert(var_accessor.data_getter, "data unreadable");
output = DeviceValue::make(var_accessor.data_getter());
break;
case GetAttr::Value:
trace_assert(var_accessor.value_getter, "value unreadable");
output = HostValue::make(var_accessor.value_getter());
break;
case GetAttr::DType:
output = traced_value->dtype();
break;
case GetAttr::Device:
output = traced_value->comp_node();
default:
break;
}
return {output};
} else {
return imperative::apply(get_attr, inputs);
}
}
ValueRefList CompiledTransformation::apply_create_tensor(
const CreateTensor& create_tensor, Span<ValueRef> inputs) {
if (create_tensor.kind() == CreateTensor::NoTrace) {
return imperative::apply(create_tensor, inputs);
}
auto& item = next_instruction();
trace_assert(item.op == nullptr, "operator mismatch");
auto input_id = item.inputs[0];
auto output_id = item.outputs[0];
auto tensor = imperative::apply(create_tensor, inputs)[0];
trace_input(input_id, tensor);
return {trace_output(output_id)};
}
ValueRefList CompiledTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* op_value = op.as<ApplyOp>()) {
auto& item = next_instruction();
SmallVector<ValueRef> unwrapped_inputs;
SmallVector<ValueRef> wrapped_inputs;
trace_assert(inputs.size() == item.inputs.size(), "input size mismatch");
trace_assert(op_value->op().is_same(*item.op), "operator mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
trace_input(item.inputs[i], inputs[i]);
}
std::vector<ValueRef> outputs;
for (auto&& output_id : item.outputs) {
outputs.push_back(trace_output(output_id));
}
return outputs;
return apply_op(*op_value, inputs);
} else if (auto* create_tensor = op.as<CreateTensor>()) {
if (create_tensor->kind() == CreateTensor::NoTrace) {
return imperative::apply(op, inputs);
}
auto& item = next_instruction();
trace_assert(item.op == nullptr, "operator mismatch");
auto input_id = item.inputs[0];
auto output_id = item.outputs[0];
auto tensor = imperative::apply(op, inputs)[0];
trace_input(input_id, tensor);
return {trace_output(output_id)};
return apply_create_tensor(*create_tensor, inputs);
} else if (auto* get_attr = op.as<GetAttr>()) {
if (auto* traced_value = inputs[0].as<TracedValue>()) {
ValueRef output;
auto& var = m_vars[traced_value->id()];
auto& var_accessor = m_var_accessors[traced_value->id()];
switch (get_attr->attr()) {
case GetAttr::Shape:
trace_assert(var_accessor.shape_getter, "shape unreadable");
output = ShapeValue::make(
ValueShape::from(var_accessor.shape_getter()));
break;
case GetAttr::Data:
trace_assert(var_accessor.data_getter, "data unreadable");
output = DeviceValue::make(var_accessor.data_getter());
break;
case GetAttr::Value:
trace_assert(var_accessor.value_getter, "value unreadable");
output = HostValue::make(var_accessor.value_getter());
break;
case GetAttr::DType:
output = DTypeValue::make(var.dtype);
break;
case GetAttr::Device:
output = CompNodeValue::make(var.device);
default:
break;
}
return {output};
} else {
return imperative::apply(op, inputs);
}
return apply_get_attr(*get_attr, inputs);
} else if (auto* trace_mark_var = op.as<TraceMarkVar>()) {
auto& item = next_instruction();
trace_assert(item.op == nullptr, "operator mismatch");
......
......@@ -8,50 +8,58 @@ namespace mgb {
namespace imperative {
namespace {
static thread_local size_t nr_watched_values = 0;
static thread_local uint64_t nr_values = 0;
static thread_local bool recording_values = false;
static thread_local std::vector<ValueWeakRef> recorded_values;
static /*thread_local*/ size_t nr_watched_values = 0;
static /*thread_local*/ uint64_t nr_values = 0;
static /*thread_local*/ bool recording_values = false;
static /*thread_local*/ std::vector<ValueWeakRef> recorded_values;
static WeakValueMap<uint64_t, ValueWeakRef> registered_values;
} // namespace
ValueRef::storage_t& ValueRef::storage() const {
if (!m_storage) {
if (mgb_likely(!m_storage->m_successor.m_storage)) {
return m_storage;
}
if (auto& storage = m_storage->m_successor.m_storage) {
while (storage->m_successor.m_storage) {
storage = storage->m_successor.m_storage;
}
return storage;
} else {
return m_storage;
while (m_storage->m_successor.m_storage) {
m_storage = m_storage->m_successor.m_storage;
}
return m_storage;
}
const Value* ValueRef::as(size_t typecode) const {
auto&& storage = this->storage();
if (storage->m_typecode != typecode) {
return nullptr;
}
return static_cast<Value*>(storage.get());
}
bool ValueRef::is(size_t typecode) const {
return this->storage()->m_typecode == typecode;
}
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
return imperative::apply(GetAttr(GetAttr::Data), *this)[0].as_ref<DeviceValue>();
return imperative::apply(GetAttr(GetAttr::Data), *this)[0].cast_ref<DeviceValue>();
}
TypedValueRef<HostValue> ValueRef::numpy() const {
return imperative::apply(GetAttr(GetAttr::Value), *this)[0].as_ref<HostValue>();
return imperative::apply(GetAttr(GetAttr::Value), *this)[0].cast_ref<HostValue>();
}
TypedValueRef<CompNodeValue> ValueRef::device() const {
return imperative::apply(GetAttr(GetAttr::Device), *this)[0]
.as_ref<CompNodeValue>();
.cast_ref<CompNodeValue>();
}
TypedValueRef<ShapeValue> ValueRef::shape() const {
return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].as_ref<ShapeValue>();
return imperative::apply(GetAttr(GetAttr::Shape), *this)[0].cast_ref<ShapeValue>();
}
TypedValueRef<DTypeValue> ValueRef::dtype() const {
return imperative::apply(GetAttr(GetAttr::DType), *this)[0].as_ref<DTypeValue>();
return imperative::apply(GetAttr(GetAttr::DType), *this)[0].cast_ref<DTypeValue>();
}
TypedValueRef<StringValue> ValueRef::name() const {
return imperative::apply(GetName(), *this)[0].as_ref<StringValue>();
return imperative::apply(GetName(), *this)[0].cast_ref<StringValue>();
}
bool ValueRef::is_scalar() const {
......@@ -75,13 +83,15 @@ void ValueRef::unwatch() const {
}
ValueRef ValueRef::unwrap() const {
ValueRef value = *this;
auto& context = Transformation::get_context();
for (size_t i = 0; i < context.next_transformation; ++i) {
value = context.transformations[i]->unwrap(value);
if (mgb_unlikely(context.next_transformation)) {
ValueRef value = *this;
for (size_t i = 0; i < context.next_transformation; ++i) {
value = context.transformations[i]->unwrap(value);
}
return value;
}
mgb_assert(value);
return value;
return *this;
}
std::string ValueRef::to_string() const {
......@@ -101,13 +111,11 @@ std::string ValueRef::raw_type() const {
return types[m_storage->m_typecode].name();
}
uint64_t ValueRef::id() const {
return m_storage ? m_storage->m_id : std::numeric_limits<uint64_t>::max();
}
bool ValueRef::watching() const {
auto storage = this->storage();
return storage && storage->m_watching;
if (!m_storage) {
return false;
}
return this->storage()->m_watching;
}
ValueRef ValueRef::make(ValueRef::storage_t storage) {
......@@ -186,5 +194,96 @@ void Value::try_rethrow() {
}
}
inline void ValueRefList::init(size_t nr_elems) {
m_size = nr_elems;
if (m_size > 0) {
if (m_size == 1) {
m_data = inline_storage();
} else {
auto& context = Transformation::get_context();
m_data = context.allocator.allocate(m_size);
}
for (size_t i = 0; i < m_size; ++i) {
new (m_data + i) ValueRef();
}
} else {
m_data = nullptr;
}
}
ValueRefList::ValueRefList(size_t nr_elems) {
init(nr_elems);
}
ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
: ValueRefList(values.begin(), values.end()) {}
ValueRefList::ValueRefList(const ValueRefList& rhs)
: ValueRefList(rhs.cbegin(), rhs.cend()) {}
ValueRefList::ValueRefList(ValueRefList&& rhs) : ValueRefList() {
m_size = rhs.m_size;
if (rhs.m_data == rhs.inline_storage()) {
m_data = inline_storage();
new (m_data) ValueRef();
m_data[0] = std::move(rhs.m_data[0]);
} else {
m_data = rhs.m_data;
rhs.m_data = nullptr;
rhs.m_size = 0;
}
}
ValueRefList& ValueRefList::operator=(const ValueRefList& rhs) {
if (this == &rhs) {
return *this;
}
clear();
init(rhs.m_size);
for (size_t i = 0; i < m_size; ++i) {
m_data[i] = rhs.m_data[i];
}
return *this;
}
ValueRefList& ValueRefList::operator=(ValueRefList&& rhs) {
if (this == &rhs) {
return *this;
}
clear();
if (rhs.m_data == rhs.inline_storage()) {
m_data = inline_storage();
new (m_data) ValueRef();
m_data[0] = rhs.m_data[0];
m_size = 1;
rhs.clear();
} else {
m_data = rhs.m_data;
m_size = rhs.m_size;
rhs.m_data = nullptr;
rhs.m_size = 0;
}
return *this;
}
ValueRefList::~ValueRefList() {
clear();
}
void ValueRefList::clear() {
for (size_t i = 0; i < m_size; ++i) {
m_data[i].~ValueRef();
}
if (m_data) {
if (m_size != 1) {
Transformation::get_context().allocator.deallocate(m_data, m_size);
} else {
mgb_assert(m_data == inline_storage());
}
}
m_data = nullptr;
m_size = 0;
}
} // namespace imperative
} // namespace mgb
......@@ -24,8 +24,6 @@ namespace imperative {
class GradKey;
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
/**
* \brief apply an OpDef to values
*
......@@ -37,7 +35,7 @@ private:
public:
ApplyOp(const OpDef& op) : m_op(op) {}
const OpDef& op() { return m_op; }
const OpDef& op() const { return m_op; }
std::string to_string() const override;
};
......@@ -106,7 +104,7 @@ public:
* \param inputs contains host_storage and device_storage
* \return Args unpacked args
*/
Args parse(Span<ValueRef> inputs);
Args parse(Span<ValueRef> inputs) const;
Kind kind() const { return m_kind; }
CompNode device() const { return m_device; }
......@@ -129,11 +127,11 @@ private:
public:
DTRCommand(Kind kind) : m_kind(kind) {}
Kind kind() { return m_kind; }
Kind kind() const { return m_kind; }
std::string to_string() const override;
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override { return {}; }
ValueRefList fallback(Span<ValueRef> inputs) const override { return {}; }
};
// deprecated
......@@ -141,9 +139,7 @@ class GetName final : public OperatorImpl<GetName, Operator::GetAttrLike> {
public:
std::string to_string() const override;
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
return {ValueRef()};
}
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
/**
......@@ -161,7 +157,7 @@ public:
std::string to_string() const override;
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
ValueRefList fallback(Span<ValueRef> inputs) const override {
return {inputs.as_array<1>()[0]};
}
};
......
......@@ -23,7 +23,7 @@ namespace imperative {
class GradKey;
using GenericFunction = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
public:
......@@ -97,6 +97,10 @@ public:
ValueShape shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); }
HostTensorStorage storage() const { return m_storage; }
DTypeScalar item() const {
mgb_assert(m_shape.is_scalar());
return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
}
HostTensorND as_nd(bool allow_scalar = false) const;
};
......
......@@ -36,11 +36,11 @@ namespace imperative {
*
* \param op
* \param inputs
* \return std::vector<ValueRef>
* \return ValueRefList
*/
std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
std::vector<ValueRef> apply(const OpDef& def, Span<ValueRef> inputs);
std::vector<ValueRef> apply(Subgraph graph, Span<ValueRef> inputs);
ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
ValueRefList apply(const OpDef& def, Span<ValueRef> inputs);
ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs);
template <typename... TArgs>
constexpr bool is_all_value_ref_v =
......@@ -49,7 +49,7 @@ constexpr bool is_all_value_ref_v =
template <typename T, typename... TArgs>
static auto apply(T&& op, TArgs&&... args)
-> std::enable_if_t<is_all_value_ref_v<TArgs...>, std::vector<ValueRef>> {
-> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> {
ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
return imperative::apply(
std::forward<T&&>(op),
......@@ -63,7 +63,7 @@ static auto apply(T&& op, TContainer&& container) -> std::enable_if_t<
ValueRef> &&
std::is_same_v<decltype(container.size()), size_t> &&
!std::is_same_v<std::decay_t<TContainer>, Span<ValueRef>>,
std::vector<ValueRef>> {
ValueRefList> {
return imperative::apply(
std::forward<T&&>(op), Span<ValueRef>(container.data(), container.size()));
}
......
......@@ -25,6 +25,8 @@
namespace mgb {
namespace imperative {
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
/**
* \brief base class for all operators
*
......@@ -49,25 +51,24 @@ public:
Kind kind() const { return m_kind; }
template <typename U>
U* as() const {
const U* as() const {
if (m_typecode != U::TYPE_CODE) {
return nullptr;
}
return static_cast<U*>(const_cast<Operator*>(this));
return static_cast<const U*>(this);
}
template <typename U>
bool is() const {
return as<U>() != nullptr;
return m_typecode == U::TYPE_CODE;
}
template <Kind kKind>
bool is() const {
return kind() == kKind;
}
template <typename U>
U& cast() const {
U* ptr = as<U>();
mgb_assert(ptr);
return *ptr;
const U& cast() const {
mgb_assert(m_typecode == U::TYPE_CODE);
return static_cast<const U&>(*this);
}
virtual std::string to_string() const = 0;
......@@ -77,9 +78,9 @@ public:
* implementation.
*
* \param inputs
* \return std::vector<ValueRef>
* \return ValueRefList
*/
virtual std::vector<ValueRef> fallback(Span<ValueRef> inputs) const;
virtual ValueRefList fallback(Span<ValueRef> inputs) const;
std::type_index type() const { return registered_types()[m_typecode]; }
......
......@@ -123,7 +123,6 @@ public:
template <typename T, typename... TArgs>
static uint64_t record(TArgs&&... args) {
auto& profiler = get_instance();
// auto& mem_pool = get_mem_pool<T>();
if constexpr (sm_debug) {
Status expected = Running;
mgb_assert(profiler.m_status.compare_exchange_strong(expected, Recording));
......
......@@ -18,6 +18,7 @@
#include "megbrain/common.h"
#include "megbrain/imperative/subgraph.h"
#include "megbrain/imperative/utils/allocator.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
......@@ -25,6 +26,7 @@ namespace mgb {
namespace imperative {
class ValueRef;
class ValueRefList;
class Operator;
class Transformation;
......@@ -43,6 +45,7 @@ struct TransformationContext {
// TODO: deprecate TransformationGuard, let next_transformation == frames.size()
size_t next_transformation = 0;
std::vector<TransformationFrame> frames;
ForwardAllocator<ValueRef> allocator;
};
/**
......@@ -86,9 +89,9 @@ public:
*
* \param op
* \param inputs
* \return std::vector<ValueRef>
* \return ValueRefList
*/
virtual std::vector<ValueRef> apply_transformation(
virtual ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) = 0;
virtual ValueRef unwrap(ValueRef value) = 0;
......@@ -187,11 +190,12 @@ public:
std::swap(context.transformations, current_context.transformations);
std::swap(context.scopes, current_context.scopes);
std::swap(context.next_transformation, current_context.next_transformation);
std::swap(context.allocator, current_context.allocator);
}
static TransformationContext& get_context();
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
friend class ValueRef;
};
......
......@@ -23,16 +23,38 @@ public:
using Handle = interpreter::Interpreter::Handle;
using Channel = interpreter::Interpreter::Channel;
class RAIIHandle : public NonCopyableObj {
private:
Handle m_handle = nullptr;
Channel* m_channel = nullptr;
public:
RAIIHandle(Handle handle, Channel* channel)
: m_handle(handle), m_channel(channel) {}
~RAIIHandle() { m_channel->del(m_handle); }
Handle handle() const { return m_handle; }
Channel* channel() const { return m_channel; }
};
private:
std::shared_ptr<Handle> m_handle = nullptr;
LocalPtr<RAIIHandle> m_handle;
std::string m_name;
mutable DTypeValue::ref_t m_dtype;
mutable CompNodeValue::ref_t m_comp_node;
mutable ShapeValue::ref_t m_shape;
public:
InterpreterInfo() = default;
InterpreterInfo(std::shared_ptr<Handle> handle, std::string name = {})
InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {})
: m_handle(handle), m_name(name) {}
std::shared_ptr<Handle> handle() const { return m_handle; }
const LocalPtr<RAIIHandle>& handle() const { return m_handle; }
DTypeValue::ref_t dtype() const;
CompNodeValue::ref_t comp_node() const;
ShapeValue::ref_t shape() const;
std::string name() const { return m_name; }
};
......@@ -60,6 +82,7 @@ class InterpreterTransformation final : public Transformation {
public:
using Interpreter = interpreter::Interpreter;
using Handle = Interpreter::Handle;
using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>;
using Channel = Interpreter::Channel;
private:
......@@ -71,7 +94,14 @@ public:
Channel* channel() { return m_channel.get(); }
std::vector<ValueRef> apply_transformation(
ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs);
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
ValueRefList apply_create_tensor(
const CreateTensor& create_tensor, Span<ValueRef> inputs);
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
......@@ -81,14 +111,8 @@ public:
std::string name() const override { return "InterpreterTransformation"; }
std::shared_ptr<Handle> share_handle(Handle handle) {
return std::shared_ptr<Handle>(
new Handle(handle), [channel = m_channel.get()](Handle* ptr) {
if (ptr) {
channel->del(*ptr);
delete ptr;
}
});
SharedHandle share_handle(Handle handle) {
return SharedHandle::make(handle, m_channel.get());
}
};
......
......@@ -34,9 +34,7 @@ struct BackwardGraphWithClosure {
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs);
void operator()(
std::vector<ValueRef> grads,
std::function<void(size_t, ValueRef)> receiver);
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }
......@@ -50,12 +48,11 @@ struct BackwardGraphWithClosure {
struct CustomBackward;
using GradRuleFn =
std::function<std::vector<ValueRef>(Span<ValueRef> inputs, CustomBackward&)>;
using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>;
struct CustomBackward {
using BackwardFn = std::function<std::vector<ValueRef>(Span<ValueRef>)>;
using BackwardRule = std::function<std::optional<std::vector<ValueRef>>(
using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>;
using BackwardRule = std::function<std::optional<ValueRefList>(
const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>;
BackwardFn m_backward;
SmallVector<bool, 8> m_input_has_grad;
......@@ -65,9 +62,7 @@ struct CustomBackward {
SmallVector<OutputAttr> m_output_attrs;
public:
void operator()(
std::vector<ValueRef> grads,
std::function<void(size_t, ValueRef)> receiver);
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
......@@ -188,7 +183,7 @@ public:
std::string to_string() const override;
bool has_key(std::shared_ptr<GradKey> key) const { return m_key == key; }
bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; }
const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const {
mgb_assert(m_key == key);
......@@ -287,7 +282,7 @@ public:
return false;
}
std::vector<ValueRef> apply_transformation(
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
......@@ -314,7 +309,7 @@ private:
public:
std::string to_string() const override { return "DetachValue"; }
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
ValueRefList fallback(Span<ValueRef> inputs) const override {
return {inputs.as_array<1>()[0]};
}
};
......@@ -325,7 +320,7 @@ private:
public:
AttachGrad(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::shared_ptr<GradKey> key() const { return m_key; }
std::string to_string() const override {
return ssprintf("AttachGradValue{key=%s}", m_key->name().c_str());
......@@ -339,7 +334,7 @@ private:
public:
GradBackward(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::shared_ptr<GradKey> key() const { return m_key; }
std::string to_string() const override {
return ssprintf("GradBackwardValue{key=%s}", m_key->name().c_str());
......@@ -352,13 +347,13 @@ private:
public:
IsAttachedTo(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::shared_ptr<GradKey> key() const { return m_key; }
std::string to_string() const override {
return ssprintf("IsAttachedToValue{key=%s}", m_key->name().c_str());
}
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
ValueRefList fallback(Span<ValueRef> inputs) const override {
return {BoolValue::make(false)};
}
};
......@@ -373,9 +368,9 @@ public:
SetGrad(std::shared_ptr<GradKey> key, GenericFunction grad_fn, size_t nr_inputs)
: m_key(key), m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
GenericFunction grad_fn() { return m_grad_fn; }
GenericFunction grad_fn() const { return m_grad_fn; }
size_t nr_inputs() { return m_nr_inputs; }
size_t nr_inputs() const { return m_nr_inputs; }
std::string to_string() const override {
return ssprintf("SetGradValue{key=%s}", m_key->name().c_str());
......@@ -388,9 +383,7 @@ public:
std::string to_string() const override { return ssprintf("GetGradKeyValue{}"); }
std::vector<ValueRef> fallback(Span<ValueRef> inputs) const override {
return {ValueRef()};
}
ValueRefList fallback(Span<ValueRef> inputs) const override { return {ValueRef()}; }
};
class GetBackwardColsure
......@@ -401,7 +394,7 @@ private:
public:
GetBackwardColsure(std::shared_ptr<GradKey> key) : m_key(key) {}
std::shared_ptr<GradKey> key() { return m_key; }
std::shared_ptr<GradKey> key() const { return m_key; }
std::string to_string() const override {
return ssprintf("GetBackwardClosure{key=%s}", m_key->name().c_str());
......
......@@ -81,7 +81,7 @@ public:
ComputingGraph::Options& options() { return m_graph->options(); }
std::vector<ValueRef> apply_transformation(
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
......
......@@ -11,6 +11,7 @@
#pragma once
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/dispatch.h"
#include "megbrain/imperative/ops/autogen.h"
......@@ -45,8 +46,10 @@ public:
*/
class ScalarTransformation final : public Transformation {
private:
ShapeValue::ref_t m_empty_shape; // []
public:
std::vector<ValueRef> apply_transformation(
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
......
......@@ -50,7 +50,7 @@ private:
public:
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {}
std::vector<ValueRef> apply_transformation(
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override {
if (auto* apply_op = op.as<ApplyOp>()) {
SmallVector<VarNode*> input_nodes;
......@@ -58,9 +58,9 @@ public:
input_nodes.push_back(input.cast<SymbolValue>().node());
}
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes);
std::vector<ValueRef> outputs;
for (auto&& output_node : output_nodes) {
outputs.push_back(SymbolValue::make(output_node));
ValueRefList outputs(output_nodes.size());
for (size_t i = 0; i < output_nodes.size(); ++i) {
outputs[i] = SymbolValue::make(output_nodes[i]);
}
return outputs;
} else if (auto* create_tensor = op.as<CreateTensor>()) {
......
/**
* \file imperative/src/include/megbrain/imperative/grad.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "megbrain/imperative/basic_operators.h"
#include "megbrain/imperative/operator.h"
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/value.h"
namespace mgb::imperative {
struct TangentInfo {
ValueRef value;
ValueRef tangent;
};
class TangentTransformation final : public Transformation {
public:
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override { mgb_assert(false); }
std::string name() const override { return "Tangent"; }
};
} // namespace mgb::imperative
......@@ -126,25 +126,6 @@ public:
void on_unwatch() override { value().unwatch(); }
};
class TracedInfo {
private:
size_t m_id = 0;
public:
TracedInfo() = default;
TracedInfo(size_t id) : m_id(id) {}
size_t id() const { return m_id; }
};
class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf("TracedValue{\"id\"=%zu}", id());
}
};
/**
* \brief trace operation sequence to TraceResult
*
......@@ -202,7 +183,7 @@ public:
return value;
}
std::vector<ValueRef> apply_transformation(
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
......@@ -248,6 +229,40 @@ public:
std::function<DeviceTensorND()> data_getter;
std::function<HostTensorND()> value_getter;
std::function<void(DeviceTensorND)> data_setter;
std::function<void(std::exception_ptr)> exc_setter;
};
class TracedInfo {
private:
size_t m_id = 0;
VarInfo* m_var = nullptr;
VarAccessor* m_accessor = nullptr;
mutable ShapeValue::ref_t m_shape;
mutable DTypeValue::ref_t m_dtype;
mutable CompNodeValue::ref_t m_comp_node;
public:
TracedInfo() = default;
TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor)
: m_id(id), m_var(var), m_accessor(accessor) {}
size_t id() const { return m_id; }
ShapeValue::ref_t shape() const;
DTypeValue::ref_t dtype() const;
CompNodeValue::ref_t comp_node() const;
const VarAccessor& accessor() const;
void set_exception(std::exception_ptr exc) const {
m_accessor->exc_setter(exc);
}
};
class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf("TracedValue{\"id\"=%zu}", id());
}
};
private:
......@@ -319,7 +334,14 @@ public:
TraceResult::SeqItem& next_instruction();
std::vector<ValueRef> apply_transformation(
ValueRefList apply_op(const ApplyOp& apply_op, Span<ValueRef> inputs);
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
ValueRefList apply_create_tensor(
const CreateTensor& create_tensor, Span<ValueRef> inputs);
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
void on_unregister() noexcept override;
......
......@@ -36,12 +36,12 @@ private:
public:
Allocator(pool_type* pool) : m_pool(pool) {}
T* allocate(size_type n) {
pointer allocate(size_type n) {
mgb_assert(n == 1);
return m_pool->alloc(sizeof(T));
}
void deallocate(pointer* p, size_type n) {
void deallocate(pointer p, size_type n) {
mgb_assert(n == 1);
m_pool->free(p);
}
......@@ -68,4 +68,114 @@ public:
bool operator!=(const ThreadLocalAllocatorAdapter& rhs) const { return false; }
};
} // namespace mgb::imperative
\ No newline at end of file
template <typename T>
class ForwardAllocator {
public:
using value_type = T;
using size_type = std::size_t;
using pointer = T*;
static constexpr size_t alignment = alignof(T);
static constexpr size_t element_offset =
sizeof(T) +
((sizeof(T) % alignment) ? 0 : (alignment - sizeof(T) % alignment));
private:
struct Block {
std::unique_ptr<std::byte[]> data;
size_t size = 0;
size_t capacity = 0;
T* allocate(size_type n) {
static_assert(element_offset > std::max(alignment, sizeof(T)));
size_t begin = size;
size_t end = begin + element_offset * n;
if (end > capacity) {
return nullptr;
}
size = end;
return reinterpret_cast<T*>(data.get() + begin);
}
void reset() { size = 0; }
};
std::vector<Block> m_used;
std::optional<Block> m_current;
size_t block_size = 16 * 1024 * 1024;
size_t nr_allocated = 0;
private:
Block allocate_block() {
block_size *= 2;
return Block{std::make_unique<std::byte[]>(block_size), 0, block_size};
}
public:
pointer allocate(size_type n) {
if (!m_current) {
m_current.emplace(allocate_block());
}
pointer pointer = m_current->allocate(n);
while (pointer == nullptr) {
m_used.push_back(allocate_block());
std::swap(m_used.back(), *m_current);
pointer = m_current->allocate(n);
}
nr_allocated++;
return pointer;
}
void deallocate(pointer p, size_type n) {
mgb_assert(nr_allocated > 0);
nr_allocated--;
}
void clear() {
if (mgb_likely(m_used.empty())) {
// fastpath
if (m_current) {
m_current->reset();
}
} else {
// trim
*m_current = allocate_block();
m_used.clear();
}
mgb_assert(nr_allocated == 0);
}
bool operator==(const ForwardAllocator& rhs) const { return &rhs == this; }
bool operator!=(const ForwardAllocator& rhs) const { return &rhs != this; }
};
template <typename T, template <typename> typename TAllocator>
class ProxyAllocator {
public:
using value_type = T;
using size_type = typename TAllocator<T>::size_type;
using pointer = typename TAllocator<T>::pointer;
private:
TAllocator<T>* m_impl;
public:
T* allocate(size_type n) { return m_impl->allocate(n); }
void deallocate(pointer* p, size_type n) { return m_impl->deallocate(p, n); }
bool operator==(const ProxyAllocator<T, TAllocator>& rhs) const {
if (m_impl == rhs.m_impl) {
return true;
} else if (bool(m_impl) ^ bool(rhs.m_impl)) {
return false;
} else {
return *m_impl == *rhs.m_impl;
}
}
bool operator!=(const ProxyAllocator<T, TAllocator>& rhs) const {
return !((*this) == rhs);
}
};
} // namespace mgb::imperative
......@@ -16,6 +16,8 @@
#include "megbrain/imperative/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
#define MGB_FAT_LOCAL_PTR 0
namespace mgb::imperative {
template <typename T>
......@@ -52,6 +54,8 @@ private:
}
}
size_t ref_count() const { return m_ref_count; }
template <typename U>
friend class LocalPtr;
......@@ -88,14 +92,24 @@ public:
using storage_t = LocalPtrStorage<T>;
using pool_t = MemPool<storage_t>;
using weak_type = LocalWeakPtr<T>;
using pointer_t = T*;
private:
storage_t* m_storage = nullptr;
#if MGB_FAT_LOCAL_PTR
pointer_t m_pointer = nullptr;
#endif
// (m_storage == nullptr) == (m_pointer == nullptr)
void emplace(storage_t* ptr) {
if (ptr) {
ptr->inc_ref();
m_storage = ptr;
#if MGB_FAT_LOCAL_PTR
m_pointer = ptr->m_pointer;
#endif
}
}
......@@ -103,8 +117,22 @@ private:
public:
LocalPtr() = default;
LocalPtr(const LocalPtr& rhs) { (*this) = rhs; }
LocalPtr(LocalPtr&& rhs) { (*this) = std::move(rhs); }
LocalPtr(const LocalPtr& rhs) {
auto storage = rhs.m_storage;
if (storage) {
storage->inc_ref();
}
m_storage = storage;
#if MGB_FAT_LOCAL_PTR
m_pointer = rhs.m_pointer;
#endif
}
LocalPtr(LocalPtr&& rhs) {
std::swap(m_storage, rhs.m_storage);
#if MGB_FAT_LOCAL_PTR
std::swap(m_pointer, rhs.m_pointer);
#endif
}
LocalPtr& operator=(const LocalPtr& rhs) {
if (this == &rhs) {
return *this;
......@@ -115,9 +143,11 @@ public:
}
if (m_storage) {
m_storage->dec_ref();
// rhs.m_storage may be invalid here
}
m_storage = storage;
#if MGB_FAT_LOCAL_PTR
m_pointer = rhs.m_pointer;
#endif
return *this;
}
LocalPtr& operator=(LocalPtr&& rhs) {
......@@ -125,6 +155,9 @@ public:
return *this;
}
std::swap(m_storage, rhs.m_storage);
#if MGB_FAT_LOCAL_PTR
std::swap(m_pointer, rhs.m_pointer);
#endif
rhs.reset();
return *this;
}
......@@ -186,10 +219,11 @@ public:
T& operator*() const { return *get(); }
T* get() const {
if ((!m_storage) || !m_storage->m_pointer) {
return nullptr;
}
return m_storage->m_pointer;
#if MGB_FAT_LOCAL_PTR
return m_pointer;
#else
return m_storage ? m_storage->m_pointer : nullptr;
#endif
}
T* operator->() const { return get(); }
......@@ -202,6 +236,9 @@ public:
if (m_storage) {
m_storage->dec_ref();
m_storage = nullptr;
#if MGB_FAT_LOCAL_PTR
m_pointer = nullptr;
#endif
}
}
......
......@@ -49,8 +49,8 @@ public:
instance = std::make_unique<MemPool<T>>();
sm_instance = instance.get();
}
mgb_assert(sm_instance);
}
return *sm_instance;
}
};
......@@ -62,9 +62,9 @@ std::unordered_map<std::thread::id, std::unique_ptr<MemPool<T>>>
MemPoolUtils<T>::sm_instances;
template <typename T>
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance;
thread_local MemPool<T>* MemPoolUtils<T>::tm_instance = nullptr;
template <typename T>
MemPool<T>* MemPoolUtils<T>::sm_instance;
MemPool<T>* MemPoolUtils<T>::sm_instance = nullptr;
} // namespace mgb::imperative
\ No newline at end of file
} // namespace mgb::imperative
......@@ -95,6 +95,8 @@ struct ValueShape {
}
return true;
}
bool operator!=(const ValueShape& rhs) const { return !operator==(rhs); }
};
static_assert(sizeof(size_t) >= sizeof(int));
......
......@@ -47,6 +47,17 @@ class StringValue;
class Operator;
class ValueRefList;
template <typename T>
class Type {
private:
const size_t m_code = T::TYPE_CODE;
public:
inline size_t code() const { return m_code; }
};
/**
* \brief an smart reference of value
*
......@@ -64,8 +75,9 @@ public:
protected:
mutable storage_t m_storage;
size_t m_id = std::numeric_limits<size_t>::max();
ValueRef(storage_t storage) { m_storage = storage; }
inline ValueRef(storage_t storage);
private:
/**
......@@ -75,6 +87,10 @@ private:
*/
storage_t& storage() const;
const Value* as(size_t typecode) const;
bool is(size_t typecode) const;
public:
ValueRef() = default;
......@@ -86,7 +102,7 @@ public:
* \return false if empty or type of value is not TValue
*/
template <typename TValue>
bool is() const;
inline bool is(Type<TValue> type = {}) const;
/**
* \brief try cast value as target type
......@@ -95,7 +111,7 @@ public:
* \return TValue* raw pointer if success, otherwise nullptr
*/
template <typename TValue>
const TValue* as() const;
inline const TValue* as(Type<TValue> type = {}) const;
/**
* \brief cast value to target type
......@@ -104,7 +120,7 @@ public:
* \return TValue& reference of value
*/
template <typename TValue>
const TValue& cast() const;
inline const TValue& cast(Type<TValue> type = {}) const;
/**
* \brief like as(), but returns TypedValueRef instead
......@@ -113,7 +129,13 @@ public:
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template <typename TValue>
inline TypedValueRef<TValue> as_ref() const;
inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const;
template <typename TValue>
inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const;
template <typename TValue>
void on_cast_failure() const;
operator bool() const { return bool(m_storage); }
......@@ -132,7 +154,7 @@ public:
ValueRef unwrap() const;
std::string to_string() const;
std::string raw_type() const;
uint64_t id() const;
uint64_t id() const { return m_id; }
size_t hash() const { return id(); }
static ValueRef make(storage_t storage);
......@@ -144,7 +166,7 @@ public:
friend class TypedValueRef;
template <typename T>
friend class ValueImpl;
friend std::vector<ValueRef> apply(const Operator& op, Span<ValueRef> inputs);
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
};
template <>
......@@ -244,7 +266,7 @@ public:
using ref_t = TypedValueRef<T>;
using weak_ref_t = TypedValueWeakRef<T>;
static inline size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
/**
* \brief helper function for construct a value
......@@ -254,7 +276,7 @@ public:
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
static TypedValueRef<T> make(TArgs&&... args) {
static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) {
static_assert(std::is_final_v<T>);
return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...));
}
......@@ -279,46 +301,60 @@ public:
bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; }
};
inline ValueRef::ValueRef(storage_t storage) {
// mgb_assert(storage);
m_storage = storage;
m_id = m_storage->m_id;
}
template <typename TValue>
const TValue* ValueRef::as() const {
inline const TValue* ValueRef::as(Type<TValue> type) const {
static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>);
auto storage = this->storage();
if (!storage) {
return nullptr;
}
if (storage->m_typecode != TValue::TYPE_CODE) {
return nullptr;
}
return static_cast<TValue*>(storage.get());
return static_cast<const TValue*>(as(type.code()));
}
template <typename TValue>
const TValue& ValueRef::cast() const {
auto* ptr = as<TValue>();
if (!ptr) {
// if this is ErrorValue, rethrow directly
storage()->try_rethrow();
mgb_assert(
ptr, "expect type %s, got %s", typeid(TValue).name(),
to_string().c_str());
inline const TValue& ValueRef::cast(Type<TValue> type) const {
auto* ptr = as<TValue>(type);
if (mgb_unlikely(!ptr)) {
on_cast_failure<TValue>();
}
return *ptr;
return static_cast<const TValue&>(*ptr);
}
template <typename TValue>
inline bool ValueRef::is(Type<TValue> type) const {
return is(type.code());
}
template <typename TValue>
bool ValueRef::is() const {
auto* ptr = as<TValue>();
return ptr != nullptr;
inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const {
if (!is<TValue>(type)) {
return {};
}
return TypedValueRef<TValue>(*this);
}
template <typename TValue>
TypedValueRef<TValue> ValueRef::as_ref() const {
if (!is<TValue>()) {
inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const {
if (!m_storage) {
return {};
}
if (mgb_unlikely(!is<TValue>(type))) {
on_cast_failure<TValue>();
}
return TypedValueRef<TValue>(*this);
}
template <typename TValue>
void ValueRef::on_cast_failure() const {
// if this is ErrorValue, rethrow directly
storage()->try_rethrow();
mgb_assert(
storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s",
typeid(TValue).name(), to_string().c_str());
}
/**
* \brief ValueRef with concrete type, convenient for dereference
*
......@@ -361,11 +397,87 @@ private:
public:
TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {}
TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {}
TypedValueRef<T> lock() { return ValueWeakRef::lock().template as_ref<T>(); }
TypedValueRef<T> lock() {
auto value = ValueWeakRef::lock();
if (value) {
return value.template as_ref<T>();
} else {
return {};
}
}
};
// TODO: add proxy value type, which is meant to be reset in the end
class ValueRefList {
private:
ValueRef* m_data = nullptr;
size_t m_size = 0;
std::aligned_storage_t<sizeof(ValueRef), alignof(ValueRef)> m_storage;
private:
void init(size_t nr_elems);
ValueRef* inline_storage() { return reinterpret_cast<ValueRef*>(&m_storage); }
public:
ValueRefList() = default;
ValueRefList(size_t nr_elems);
ValueRefList(ValueRef item);
ValueRefList(std::initializer_list<ValueRef> values);
template <typename TIterator>
ValueRefList(TIterator begin, TIterator end);
ValueRefList(const ValueRefList& rhs);
ValueRefList(ValueRefList&& rhs);
ValueRefList& operator=(const ValueRefList& rhs);
ValueRefList& operator=(ValueRefList&& rhs);
~ValueRefList();
void clear();
ValueRef* begin() { return m_data; }
ValueRef* end() { return m_data + m_size; }
const ValueRef* cbegin() const { return m_data; }
const ValueRef* cend() const { return m_data + m_size; }
size_t size() const { return m_size; }
ValueRef& at(size_t idx) {
mgb_assert(idx < m_size);
return m_data[idx];
}
const ValueRef& at(size_t idx) const {
mgb_assert(idx < m_size);
return m_data[idx];
}
ValueRef& operator[](size_t idx) { return m_data[idx]; }
const ValueRef& operator[](size_t idx) const { return m_data[idx]; }
ValueRef* data() { return m_data; }
const ValueRef* data() const { return m_data; }
bool empty() const { return m_size == 0; }
ValueRef& front() {
mgb_assert(m_size > 1);
return m_data[0];
}
ValueRef& back() {
mgb_assert(m_size > 1);
return m_data[m_size - 1];
}
};
template <typename TIterator>
ValueRefList::ValueRefList(TIterator begin, TIterator end) : ValueRefList(end - begin) {
for (size_t i = 0; i < m_size; ++i) {
m_data[i] = *(begin + i);
}
}
inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_size(1) {
new (m_data) ValueRef();
m_data[0] = std::move(item);
}
/*class ValueRefList : public SmallVector<ValueRef, 1> {
public:
using SmallVector::SmallVector;
};*/
} // namespace imperative
} // namespace mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册