Commit 177001d5 authored by Megvii Engine Team

refactor(dispatch): allow dynamic type creation

GitOrigin-RevId: 27dde05cff7e1e0bf61c652a3ae1fe4def829ada
Parent 150a6a61
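The refactor below replaces static value construction such as ScalarValue::make(...) or InterpreterValue::make(...) with construction through a type object owned by each transformation (m_value_type.make(...), input.cast(m_value_type), input.is(m_value_type)), so value types can be created per transformation instance at runtime. A minimal sketch of that pattern, using simplified, hypothetical names rather than MegEngine's actual Type/Value API:

    #include <memory>
    #include <utility>

    // Generic value base; concrete values live behind shared ownership.
    struct Value {
        virtual ~Value() = default;
    };

    // A type object created at runtime; values are built through it instead of
    // through a static T::make().
    template <typename T>
    struct Type {
        template <typename... Args>
        std::shared_ptr<T> make(Args&&... args) const {
            return std::make_shared<T>(std::forward<Args>(args)...);
        }
    };

    // A scalar wrapper value, analogous in role to ScalarValue in the diff.
    struct ScalarValue : Value {
        explicit ScalarValue(std::shared_ptr<Value> wrapped) : wrapped(std::move(wrapped)) {}
        std::shared_ptr<Value> wrapped;
    };

    // Each transformation instance owns its value type and hands it to rules,
    // mirroring the extra `const Type<ScalarValue>&` parameter added below.
    struct ScalarTransformation {
        Type<ScalarValue> m_value_type;

        std::shared_ptr<Value> wrap(std::shared_ptr<Value> v) const {
            return m_value_type.make(std::move(v));  // instead of ScalarValue::make(v)
        }
    };

With the type held by the transformation, type checks and casts dispatch against that runtime type object rather than a hard-coded class, which is what the `<XxxValue>` → `(m_value_type)` changes throughout the diff implement.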
@@ -19,6 +19,7 @@
 #include "range/v3/all.hpp"
+#include "./helper.h"
 #include "./transformation.h"

 namespace py = pybind11;

@@ -30,9 +31,7 @@ namespace {
 std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
 }

-GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {
-    grad_key_map[m_key] = this;
-}
+GradKeyWrapper::GradKeyWrapper() {}

 void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
     if (nargs != 2) {

@@ -77,8 +76,8 @@ pybind11::function GradKeyWrapper::get_backward_closure(
     for (auto&& tensor : tensors) {
         args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
     }
-    auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0]
-                           .as<FunctionValue>();
+    auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
+    auto closure = closure_value.as_ref<FunctionValue>();
     auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
         std::vector<ValueRef> args;
         for (auto* tw : tensors) {

@@ -90,11 +89,14 @@
 }

 PyObject* GradKeyWrapper::get_name() {
-    return py::cast(m_key->name()).release().ptr();
+    return py::cast(m_name).release().ptr();
 }

 void GradKeyWrapper::set_name(py::handle name) {
-    m_key->name(py::cast<std::string>(name));
+    m_name = py::cast<std::string>(name);
+    if (m_key) {
+        m_key->name(m_name);
+    }
 }

 PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {

@@ -115,7 +117,10 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
 }

 void GradKeyWrapper::enter() {
-    m_transformation = std::make_shared<GradTransformation>(m_key);
+    m_transformation = std::make_shared<GradTransformation>();
+    m_key = m_transformation->key();
+    m_key->name(m_name);
+    grad_key_map[m_key] = this;
     TransformationManager::get_instance().register_at<TransformationManager::Grad>(
             m_transformation);
 }

@@ -123,6 +128,8 @@ void GradKeyWrapper::enter() {
 void GradKeyWrapper::exit() {
     TransformationManager::get_instance().unregister<TransformationManager::Grad>(
             m_transformation);
+    grad_key_map.erase(m_key);
+    m_key = {};
     m_transformation.reset();
 }

@@ -138,8 +145,6 @@ GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
     return grad_key_map.at(key);
 }

-GradKeyWrapper::~GradKeyWrapper() {
-    grad_key_map.erase(m_key);
-}
+GradKeyWrapper::~GradKeyWrapper() {}

 } // namespace mgb::imperative::python

@@ -26,6 +26,7 @@ struct GradKeyWrapper : NonCopyableObj {
     using wrap_t = pyext17::wrap<GradKeyWrapper>;
     static constexpr auto tp_name = pybind11::detail::_("GradKey");

+    std::string m_name;
     std::shared_ptr<GradKey> m_key;
     std::shared_ptr<GradTransformation> m_transformation;
...
@@ -117,7 +117,7 @@ std::optional<ValueRefList> elemwise_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(2);
+        SmallVector<ValueRef> ret(2);
         if (!grad) {
             return ret;
         }

@@ -147,7 +147,7 @@ std::optional<ValueRefList> reshape_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(2);
+        SmallVector<ValueRef> ret(2);
         if (!grad) {
             return ret;
         }

@@ -180,7 +180,7 @@ std::optional<ValueRefList> subtensor_grad_rule(
             grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && inputs[0]) {
             ValueRefList args_(inputs.size() + 1);
             auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());

@@ -215,7 +215,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
             grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && inputs[0]) {
             ValueRefList args_(inputs.size() + 1);
             auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());

@@ -251,7 +251,7 @@ std::optional<ValueRefList> reduce_grad_rule(
     maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && shapes[0]) {
             ret[0] = broadcast_to(grad, shapes[0]);
         }

@@ -274,7 +274,7 @@ std::optional<ValueRefList> addAxis_grad_rule(
     maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && flag_) {
             ret[0] = imperative::apply(*grad_op_, grad)[0];
         }

@@ -297,7 +297,7 @@ std::optional<ValueRefList> removeAxis_grad_rule(
     maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad && flag_) {
             ret[0] = imperative::apply(*grad_op_, grad)[0];
         }

@@ -316,7 +316,7 @@ std::optional<ValueRefList> fastpathcopy_grad_rule(
     maker.backward([](Span<ValueRef> grads) {
         mgb_assert(grads.size() == 1);
         ValueRef grad = grads[0];
-        ValueRefList ret(1);
+        SmallVector<ValueRef> ret(1);
         if (grad) {
             ret[0] = grad;
         }
...
@@ -56,42 +56,44 @@ WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
 struct SymbolVarContext {
     TransformationContext context;
-    cg::ComputingGraph* graph;
+    std::shared_ptr<SymbolTransformation> symbol_tsf;
+    std::shared_ptr<ScalarTransformation> scalar_tsf;

-    SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
+    SymbolVarContext(cg::ComputingGraph* graph) {
+        symbol_tsf = std::make_shared<SymbolTransformation>(graph);
+        scalar_tsf = std::make_shared<ScalarTransformation>();
         Transformation::swap_context(context);
     }

     void init() {
-        std::make_shared<SymbolTransformation>(graph)->register_at(
-                Transformation::top());
-        std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
+        symbol_tsf->register_at(Transformation::top());
+        scalar_tsf->register_at(Transformation::top());
     }

-    ~SymbolVarContext() { Transformation::swap_context(context); }
-};
-
-ValueRef symvar2val(py::handle py_symbol_var) {
-    auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
-    ValueRef value = SymbolValue::make(symbol_var->m_node);
-    if (symbol_var->is_scalar) {
-        value = ScalarValue::make(value);
+    ValueRef symvar2val(py::handle py_symbol_var) {
+        auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
+        ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
+        if (symbol_var->is_scalar) {
+            value = scalar_tsf->value_type().make(value);
+        }
+        return value;
     }
-    return value;
-}

-py::object val2symvar(py::handle typeobj, ValueRef value) {
-    bool is_scalar = false;
-    if (auto* scalar_value = value.as<ScalarValue>()) {
-        value = scalar_value->value();
-        is_scalar = true;
+    py::object val2symvar(py::handle typeobj, ValueRef value) {
+        bool is_scalar = false;
+        if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
+            value = scalar_value->value();
+            is_scalar = true;
+        }
+        auto* node = value.cast(symbol_tsf->value_type()).node();
+        auto py_symbol_var =
+                typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
+        py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
+        return py_symbol_var;
     }
-    auto* node = value.cast<SymbolValue>().node();
-    auto py_symbol_var =
-            typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
-    py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
-    return py_symbol_var;
-}
+
+    ~SymbolVarContext() { Transformation::swap_context(context); }
+};

 } // namespace

@@ -130,19 +132,21 @@ PyObject* py_apply(
     auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
     SmallVector<ValueRef, 8> tensors(nargs);
-    if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
+    bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) &&
+                         py::isinstance<PySymbolVar>(py::handle(args[0]));
+    if (is_symbol_var) {
         // swap to a special context to reuse scalar handle
         SymbolVarContext context(
                 py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
         context.init();
         for (size_t i = 0; i < nargs; ++i) {
-            tensors[i] = symvar2val(args[i]);
+            tensors[i] = context.symvar2val(args[i]);
         }
         auto outputs = imperative::apply(*op, tensors);
         auto ret = pybind11::tuple(outputs.size());
         auto typeobj = py::handle(args[0]).get_type();
         for (size_t i = 0; i < outputs.size(); ++i) {
-            ret[i] = val2symvar(typeobj, outputs[i]);
+            ret[i] = context.val2symvar(typeobj, outputs[i]);
         }
         return ret.release().ptr();
     }

@@ -161,7 +165,7 @@ PyObject* py_apply(
         }
     }

-    auto outputs = imperative::apply(*op, tensors);
+    auto outputs = [&] { return imperative::apply(*op, tensors); }();
     size_t nout = outputs.size();
     auto ret = py::tuple(nout);
     for (size_t i = 0; i < nout; ++i) {

@@ -1573,9 +1577,9 @@ void init_tensor(py::module m) {
                     SymbolVarContext context(graph);
                     context.init();
                     auto output = reduce_to_scalar(
-                            *op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor));
+                            *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
                     auto typeobj = tensor.get_type();
-                    return val2symvar(typeobj, output);
+                    return context.val2symvar(typeobj, output);
                 } else {
                     auto* tw = TensorWrapper::try_cast(tensor.ptr());
                     auto output = reduce_to_scalar(
...
@@ -67,10 +67,9 @@ struct TransformationManager {
     }
 };

-class PyValue final
-        : public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> {
+class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
 public:
-    using MixinValueImpl::MixinValueImpl;
+    using PrimitiveValue::PrimitiveValue;

     std::string to_string() const {
         return pybind11::str((const pybind11::object&)*this).cast<std::string>();
...
@@ -63,7 +63,7 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
                     MegBrainError,
                     "unknown input type, expects HostStorage or DeviceStorage, got "
                     "%s",
-                    input.name()->c_str());
+                    input.to_string().c_str());
         }
     }
     mgb_assert(
...
@@ -12,7 +12,7 @@ std::string CompNodeValue::to_string() const {
 }

 std::string BoolValue::to_string() const {
-    return (*m_value) ? "true" : "false";
+    return (*this) ? "true" : "false";
 }

 std::string HostStorage::to_string() const {

@@ -26,10 +26,10 @@ std::string DeviceStorage::to_string() const {
 std::string HostValue::to_string() const {
     return ssprintf(
             "HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
-            m_dtype.name(), m_shape.to_string().c_str());
+            dtype().name(), shape().to_string().c_str());
 }

-HostTensorND HostValue::as_nd(bool allow_scalar) const {
+HostTensorND HostTensor::as_nd(bool allow_scalar) const {
     HostTensorND nd;
     TensorShape tensor_shape;
     if (m_shape.is_scalar()) {

@@ -45,10 +45,10 @@ HostTensorND HostValue::as_nd(bool allow_scalar) const {
 std::string DeviceValue::to_string() const {
     return ssprintf(
             "DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
-            m_dtype.name(), m_shape.to_string().c_str());
+            dtype().name(), shape().to_string().c_str());
 }

-DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const {
+DeviceTensorND DeviceTensor::as_nd(bool allow_scalar) const {
     DeviceTensorND nd;
     TensorShape tensor_shape;
     if (m_shape.is_scalar()) {
...
@@ -19,46 +19,18 @@
 namespace mgb {
 namespace imperative {

-namespace {
-MGB_NOINLINE void copy_outputs(
-        ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) {
-    size_t nr_outputs = outputs.size();
-    if (mgb_likely(nr_outputs == 1)) {
-        ValueRef output_copy;
-        output_copy = outputs[0];
-        allocator.clear();
-        outputs = ValueRefList({output_copy});
-    } else if (!outputs.empty()) {
-        SmallVector<ValueRef> outputs_copy(nr_outputs);
-        for (size_t i = 0; i < nr_outputs; ++i) {
-            outputs_copy[i] = outputs[i];
-        }
-        outputs.clear();
-        allocator.clear();
-        outputs = {outputs_copy.begin(), outputs_copy.end()};
-    } else {
-        allocator.clear();
-    }
-}
-}  // namespace
-
 ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
     auto& context = Transformation::get_context();
     size_t& depth = context.next_transformation;
-    bool top = depth == 0;
-    auto outputs = ([&] {
-        if (mgb_unlikely(depth >= context.transformations.size())) {
-            return op.fallback(inputs);
-        } else {
-            auto& transformation = *context.transformations[depth++];
-            CleanupGuard _{[&] { --depth; }};
-            return transformation.apply_transformation(op, inputs);
-        }
-    })();
-    if (mgb_unlikely(top)) {
-        copy_outputs(context.allocator, outputs);
+    // TODO: add fallback transformation
+    bool fallback = depth >= context.transformations.size();
+    if (mgb_unlikely(fallback)) {
+        return op.fallback(inputs);
+    } else {
+        auto& transformation = *context.transformations[depth++];
+        CleanupGuard _{[&] { --depth; }};
+        return transformation.apply_transformation(op, inputs);
     }
-    return outputs;
 }

 ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {

@@ -66,12 +38,7 @@ ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
 }

 ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
-    SmallVector<ValueRef> inputs_storage;
-    for (size_t i = 0; i < inputs.size(); ++i) {
-        inputs_storage.push_back(inputs[i]);
-    }
-    auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs,
-                            size_t) {
+    auto apply_functor = [](std::shared_ptr<OpDef> op, Span<ValueRef> inputs, size_t) {
         auto outputs = imperative::apply(*op, inputs);
         return SmallVector<ValueRef>(outputs.begin(), outputs.end());
     };

@@ -93,7 +60,7 @@ ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
                 HostStorage::make(host_value.storage()),
                 DeviceStorage::make(device_value.storage()))[0];
     };
-    auto outputs = graph.apply(inputs_storage, apply_functor, make_const);
+    auto outputs = graph.apply(inputs, apply_functor, make_const);
     return ValueRefList{outputs.begin(), outputs.end()};
 }
...
@@ -331,6 +331,7 @@ void ChannelImpl::dispatch_kernel(
     cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
     outputs->reserve(output_descs.size());
+
     for (int i = 0; i < output_descs.size(); ++i) {
         auto&& desc = output_descs[i];
         auto info = alloc();

@@ -730,7 +731,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
             input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
         }
         auto forward_graph = OpDef::make_forward_graph(def, input_descs);
-        auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
+        auto outputs = forward_graph.apply<TensorPtr>(
+                inputs, apply_functor, const_functor);
         return outputs;
     }
     return OpDef::apply_on_physical_tensor(def, inputs);
...
@@ -11,6 +11,7 @@
 #include "megbrain/imperative/opr_utility.h"
 #include "megbrain/imperative/ops/autogen.h"
+#include "megbrain/imperative/utils/stats.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/opr/utility.h"

@@ -101,7 +102,7 @@ void apply_on_device_tensornd(
         const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
         SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
-    auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
+    auto&& trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
     mgb_assert(
             inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
             trait.name, trait.arity, inputs.size());
...
@@ -36,7 +36,7 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
                 .node();
     };
     auto subgraph = def.trait()->make_forward_graph(def, input_descs);
-    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    auto outputs = subgraph.apply<VarNode*>(inputs, apply_functor, const_functor);
     return outputs;
 }

@@ -56,7 +56,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
                 value->layout(), value->comp_node(),
                 value->get_value().proxy_to_default_cpu()};
     };
-    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    auto outputs =
+            subgraph.apply<LogicalTensorDesc>(inputs, apply_functor, const_functor);
     return {outputs, all_validated};
 }

@@ -72,7 +73,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
         return OpDef::apply_on_physical_tensor(*op, inputs);
     };
     auto const_functor = [&](const TensorPtr& value) { return value; };
-    auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
+    auto outputs = subgraph.apply<TensorPtr>(inputs, apply_functor, const_functor);
     return outputs;
 }

@@ -94,7 +95,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
     };
     GradContext<var_t> grad_context{accum_grad};
     auto input_vars = builder.write_inputs(inputs);
-    auto outputs = forward_graph.apply(
+    auto outputs = forward_graph.apply<var_t>(
             input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
             [&](TensorPtr constant) {
                 return builder.write_constant(

@@ -102,7 +103,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
             });
     size_t nr_outputs = outputs.size();
     auto apply_mask = [](auto&& values, SmallVector<bool> mask) {
-        mgb_assert(mask.size() == values.size(), "");
+        mgb_assert(mask.size() == values.size());
         std::decay_t<decltype(values)> results;
         for (size_t i = 0; i < mask.size(); ++i) {
             if (mask[i]) {

@@ -143,7 +144,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
                 return builder.write_constant(
                         constant, {constant->layout(), constant->comp_node()});
             };
-            return bg.apply(grad_inputs, apply_functor, const_functor);
+            return bg.apply<var_t>(grad_inputs, apply_functor, const_functor);
         });
     builder.add_outputs(grad_context.get_grads(input_vars));
     for (size_t i = 0; i < nr_outputs; ++i) {
...
@@ -10,20 +10,19 @@
  */
 #include "megbrain/imperative/transformations/eval.h"
-#include "megbrain/imperative/transformations/grad.h"
 #include "megbrain/imperative/utils/stats.h"

 namespace mgb {
 namespace imperative {

-DTypeValue::ref_t InterpreterInfo::dtype() const {
+DTypeValue::ref_t InterpreterValue::dtype() const {
     if (!m_dtype) {
         m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle()));
     }
     return m_dtype;
 }

-CompNodeValue::ref_t InterpreterInfo::comp_node() const {
+CompNodeValue::ref_t InterpreterValue::comp_node() const {
     if (!m_comp_node) {
         m_comp_node = CompNodeValue::make(
                 handle()->channel()->get_device(handle()->handle()));

@@ -31,7 +30,7 @@ CompNodeValue::ref_t InterpreterInfo::comp_node() const {
     return m_comp_node;
 }

-ShapeValue::ref_t InterpreterInfo::shape() const {
+ShapeValue::ref_t InterpreterValue::shape() const {
     if (!m_shape) {
         m_shape = ShapeValue::make(
                 ValueShape::from(handle()->channel()->get_shape(handle()->handle())));

@@ -51,21 +50,22 @@ ValueRefList InterpreterTransformation::apply_op(
         }
     }};
     for (auto input : inputs) {
-        input_handles.push_back(input.cast<InterpreterValue>().handle()->handle());
+        input_handles.push_back(input.cast(m_value_type).handle()->handle());
     }
     output_handles =
             m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
     ValueRefList outputs(output_handles.size());
     for (size_t i = 0; i < output_handles.size(); ++i) {
-        outputs[i] = InterpreterValue::make(share_handle(output_handles[i]));
+        outputs[i] = m_value_type.make(share_handle(output_handles[i]));
         output_handles[i] = nullptr;
     }
+    output_handles.clear();
     return outputs;
 }

 ValueRefList InterpreterTransformation::apply_get_attr(
         const GetAttr& get_attr, Span<ValueRef> inputs) {
-    auto& input = inputs.item().cast<InterpreterValue>();
+    auto& input = inputs.item().cast(m_value_type);
     ValueRef output;
     switch (get_attr.attr()) {
         case GetAttr::DType:

@@ -98,10 +98,10 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
     if (!args.device) {
         // implies H2D
         mgb_assert(args.host, "neither host and device value is valid");
-        return {InterpreterValue::make(share_handle(
+        return {m_value_type.make(share_handle(
                 m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
     } else {
-        return {InterpreterValue::make(share_handle(m_channel->put(
+        return {m_value_type.make(share_handle(m_channel->put(
                 *args.device, args.host ? *args.host : HostTensorND())))};
     }
 }

@@ -119,7 +119,7 @@ ValueRefList InterpreterTransformation::apply_transformation(
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         return apply_create_tensor(*create_tensor, inputs);
     } else if (auto* dtr_command = op.as<DTRCommand>()) {
-        auto handle = inputs[0].cast<InterpreterValue>().handle()->handle();
+        auto handle = inputs[0].cast(m_value_type).handle()->handle();
         switch (dtr_command->kind()) {
             case DTRCommand::Drop:
                 m_channel->drop(handle);

@@ -129,10 +129,10 @@ ValueRefList InterpreterTransformation::apply_transformation(
         }
         return {};
     } else if (auto* rename_value = op.as<RenameValue>()) {
-        auto& input = inputs[0].cast<InterpreterValue>();
-        return {InterpreterValue::make(input.handle(), rename_value->name())};
+        auto& input = inputs[0].cast(m_value_type);
+        return {m_value_type.make(input.handle(), rename_value->name())};
     } else if (op.is<GetName>()) {
-        auto name = inputs[0].cast<InterpreterValue>().name();
+        auto name = inputs[0].cast(m_value_type).name();
         if (!name.empty()) {
             return {StringValue::make(name)};
         } else {
...
@@ -68,7 +68,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
     size_t count = std::count_if(
             save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
     if (!backward_graph->precomp.empty()) {
-        ValueRefList inputs_and_outputs(inputs.size() + outputs.size());
+        SmallVector<ValueRef> inputs_and_outputs(inputs.size() + outputs.size());
         auto it = inputs_and_outputs.begin();
         for (auto&& input : inputs) {
             *it++ = input;

@@ -94,7 +94,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
     }
 }

 void BackwardGraphWithClosure::operator()(
-        ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
+        Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
     ValueRef args[closure.size() + grads.size()];
     size_t nargs = 0;
     for (auto&& value : closure) {

@@ -114,7 +114,9 @@ void BackwardGraphWithClosure::operator()(
     if (null_grad) {
         return;
     }
-    auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
+    auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs));
+    SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()};
+    igrads_.clear();
     auto&& iter = igrads.begin();
     for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
         if (p) {

@@ -125,7 +127,7 @@ void BackwardGraphWithClosure::operator()(
 }

 void CustomBackward::operator()(
-        ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
+        Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
     size_t nargs = grads.size();
     ValueRef args[nargs];
     for (size_t i = 0; i < nargs; ++i) {

@@ -206,7 +208,7 @@ void GradKey::backward() {
             mgb_throw(AssertionError, "invalid backward");
         } else {
             mgb_assert(grad_fn->m_slots.size() > 0);
-            ValueRefList grads (grad_fn->m_slots.size());
+            SmallVector<ValueRef> grads (grad_fn->m_slots.size());
             auto iter = grads.begin();
             for (auto&& slot : grad_fn->m_slots) {
                 *iter++ = slot.m_grad;

@@ -231,11 +233,9 @@ void GradKey::backward() {
 GradValue::ref_t GradKey::attach(
         ValueRef tensor, std::function<void(ValueRef)> callback) {
-    auto grad_value = tensor.as_ref<GradValue>();
-    if (grad_value && grad_value->has_key(shared_from_this())) {
-        mgb_assert(
-                !tensor.cast<GradValue>().slot_for(shared_from_this())->callback,
-                "callback exists");
+    auto grad_value = tensor.as_ref(m_value_type);
+    if (grad_value) {
+        mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists");
     } else {
         GradSlotPtr grad_slot;
         auto& grad_fn = grad_slot.m_fn;

@@ -243,9 +243,9 @@ GradValue::ref_t GradKey::attach(
         grad_fn->m_key = shared_from_this();
         grad_fn->m_slots.resize(1);
         grad_slot.m_index = 0;
-        grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
+        grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot);
     }
-    grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;
+    grad_value->slot().m_fn->m_slots[0].callback = callback;
     return grad_value;
 }

@@ -263,7 +263,7 @@ void GradKey::freeze() {
 ValueRefList GradTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
     auto fallback = [&] {
-        ValueRefList unwrapped_inputs(inputs.size());
+        SmallVector<ValueRef> unwrapped_inputs(inputs.size());
         {
             // overhead
             for (size_t i = 0; i < inputs.size(); ++i) {

@@ -367,7 +367,7 @@ ValueRefList GradTransformation::apply_transformation(
         for (size_t i = 0; i < inputs.size(); ++i) {
             if (backward.input_has_grad(i) && require_grads[i]) {
                 auto& input_grad_slot =
-                        inputs[i].cast<GradValue>().slot_for(m_key);
+                        inputs[i].cast(m_value_type).slot();
                 grad_fn->m_dests.emplace_back(input_grad_slot);
                 grad_fn->m_dests.back().m_producer_record.insert_after(
                         input_grad_slot->m_producer_head);

@@ -378,7 +378,7 @@ ValueRefList GradTransformation::apply_transformation(
         for (size_t i = 0; i < outputs.size(); ++i) {
             if (backward.output_requires_grad(i)) {
                 // little overhead: Value::make
-                auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
+                auto grad_value = m_value_type.make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
                 outputs[i] = record_grad(grad_value);
             }
         }

@@ -435,7 +435,10 @@ ValueRefList GradTransformation::apply_transformation(
         backward.m_input_has_grad = SmallVector(nr_inputs, true);
         backward.m_output_attrs =
                 SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
-        backward.m_backward = set_grad->grad_fn();
+        backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) {
+            auto result = fn(inputs);
+            return SmallVector<ValueRef>(result.begin(), result.end());
+        };
         ValueRefList outputs(nr_outputs);
         grad_fn->m_key = m_key;
         grad_fn->m_slots.resize(nr_outputs);

@@ -454,10 +457,10 @@ ValueRefList GradTransformation::apply_transformation(
             auto& output = outputs_[i];
             auto grad_value = as_grad_value(output);
             if (grad_value) {
-                grad_value = GradValue::make(
+                grad_value = m_value_type.make(
                         grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
             } else {
-                grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
+                grad_value = m_value_type.make(output, m_key, GradSlotPtr(grad_fn, i));
             }
             outputs[i] = record_grad(grad_value);
         }

@@ -485,8 +488,7 @@ ValueRefList GradTransformation::apply_transformation(
         mgb_assert(inputs.size() == 1);
         if (auto&& grad_value = as_grad_value(inputs[0])) {
             auto output = imperative::apply(op, grad_value->m_value)[0];
-            auto grad_output = GradValue::make(
-                    output, grad_value->key(), grad_value->slot_for(m_key));
+            auto grad_output = m_value_type.make(output, m_key, grad_value->slot());
             return {record_grad(grad_output)};
         } else {
             return imperative::apply(op, inputs);

@@ -502,7 +504,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
     std::vector<GradSlotPtr> y_slots;
     for (auto&& y : ys) {
         if (auto&& grad_value = as_grad_value(y)) {
-            y_slots.push_back(grad_value->slot_for(grad_key));
+            y_slots.push_back(grad_value->slot());
         } else {
             y_slots.emplace_back();
         }
...
@@ -32,7 +32,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
         bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo());
         VarNodeArray input_nodes;
         for (auto&& input : inputs) {
-            if (auto* input_node = input.as<LazyEvalValue>()) {
+            if (auto* input_node = input.as(m_value_type)) {
                 input_nodes.push_back(input_node->node());
             } else {
                 // ImmutableTensor has empty shape issues

@@ -112,7 +112,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
             return {record_var(node)};
         }
     } else if (auto* get_attr = op.as<GetAttr>()) {
-        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
+        if (auto* lazy_val = inputs.item().as(m_value_type)) {
            switch (get_attr->attr()) {
                case GetAttr::DType:
                    return {DTypeValue::make(lazy_val->node()->dtype())};

@@ -167,14 +167,14 @@ ValueRefList LazyEvalTransformation::apply_transformation(
            return imperative::apply(op, inputs);
        }
     } else if (auto* rename_value = op.as<RenameValue>()) {
-        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
+        if (auto* lazy_val = inputs.item().as(m_value_type)) {
            return {record_var(
                    lazy_val->node(), lazy_val->bound_data(), rename_value->name())};
        } else {
            return imperative::apply(op, inputs);
        }
     } else if (op.is<GetName>()) {
-        if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
+        if (auto* lazy_val = inputs.item().as(m_value_type)) {
            auto name = lazy_val->name();
            if (!name.empty()) {
                return {StringValue::make(lazy_val->name())};

@@ -255,7 +255,7 @@ void LazyEvalTransformation::on_unregister() noexcept {
                     DeviceStorage::make(data.storage()))[0]);
         }
     for (auto&& lazy_val : lazy_vals) {
-        if (lazy_val.is<LazyEvalValue>()) {
+        if (lazy_val.is(m_value_type)) {
             std::string repr =
                     ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str());
             mgb_log_debug("%s", repr.c_str());
...
...@@ -20,7 +20,8 @@ namespace imperative { ...@@ -20,7 +20,8 @@ namespace imperative {
namespace { namespace {
using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>); using ScalarRule = ValueRefList (*)(
const OpDef&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&);
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules; static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;
ValueRef make_scalar_shape(CompNode device) { ValueRef make_scalar_shape(CompNode device) {
...@@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) { ...@@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) {
return *shape_of_shape == ValueShape{0}; return *shape_of_shape == ValueShape{0};
} }
template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)> template <
typename T,
ValueRefList (*rule)(
const T&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&)>
void register_scalar_rule() { void register_scalar_rule() {
scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs, scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs,
Span<bool> inputs_mask) { Span<bool> inputs_mask,
return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask); const Type<ScalarValue>& value_type) {
return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask, value_type);
}; };
} }
template <typename TOpDef, size_t nr_inputs> template <typename TOpDef, size_t nr_inputs>
ValueRefList elemwise_rule( ValueRefList elemwise_rule(
const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) { const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
if constexpr (nr_inputs != 0) { if constexpr (nr_inputs != 0) {
mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch"); mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch");
} }
...@@ -63,27 +69,29 @@ ValueRefList elemwise_rule( ...@@ -63,27 +69,29 @@ ValueRefList elemwise_rule(
} }
auto outputs = imperative::apply(op_def, inputs); auto outputs = imperative::apply(op_def, inputs);
if (all_scalar) { if (all_scalar) {
outputs[0] = ScalarValue::make(outputs[0]); outputs[0] = scalar_type.make(outputs[0]);
} }
return outputs; return outputs;
} }
ValueRefList remove_axis_rule( ValueRefList remove_axis_rule(
const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(!inputs_mask.item()); mgb_assert(!inputs_mask.item());
bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size(); bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size();
if (is_scalar && remove_axis.axis.size() == 1) { if (is_scalar && remove_axis.axis.size() == 1) {
return {ScalarValue::make(inputs.item())}; return {scalar_type.make(inputs.item())};
} }
auto outputs = imperative::apply(remove_axis, inputs); auto outputs = imperative::apply(remove_axis, inputs);
if (is_scalar) { if (is_scalar) {
outputs[0] = ScalarValue::make(outputs[0]); outputs[0] = scalar_type.make(outputs[0]);
} }
return outputs; return outputs;
} }
ValueRefList reduce_rule( ValueRefList reduce_rule(
const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) { const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
if (inputs.size() == 1) { if (inputs.size() == 1) {
return imperative::apply(reduce, inputs); return imperative::apply(reduce, inputs);
} }
...@@ -91,7 +99,7 @@ ValueRefList reduce_rule( ...@@ -91,7 +99,7 @@ ValueRefList reduce_rule(
bool is_scalar = is_scalar_shape(inputs[1]); bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) { if (is_scalar) {
CompNode device = *inputs[0].device(); CompNode device = *inputs[0].device();
return {ScalarValue::make( return {scalar_type.make(
imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])}; imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])};
} }
return imperative::apply(reduce, inputs); return imperative::apply(reduce, inputs);
...@@ -99,7 +107,7 @@ ValueRefList reduce_rule( ...@@ -99,7 +107,7 @@ ValueRefList reduce_rule(
ValueRefList collective_comm_rule( ValueRefList collective_comm_rule(
const CollectiveComm& collective_comm, Span<ValueRef> inputs, const CollectiveComm& collective_comm, Span<ValueRef> inputs,
Span<bool> inputs_mask) { Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
static std::unordered_set<CollectiveComm::Mode> modes = { static std::unordered_set<CollectiveComm::Mode> modes = {
CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN, CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
...@@ -110,7 +118,7 @@ ValueRefList collective_comm_rule( ...@@ -110,7 +118,7 @@ ValueRefList collective_comm_rule(
return imperative::apply(collective_comm, inputs); return imperative::apply(collective_comm, inputs);
} }
if (inputs_mask.item()) { if (inputs_mask.item()) {
return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])}; return {scalar_type.make(imperative::apply(collective_comm, inputs[0])[0])};
} else { } else {
return imperative::apply(collective_comm, inputs); return imperative::apply(collective_comm, inputs);
} }
...@@ -118,24 +126,27 @@ ValueRefList collective_comm_rule( ...@@ -118,24 +126,27 @@ ValueRefList collective_comm_rule(
ValueRefList param_pack_split_rule( ValueRefList param_pack_split_rule(
const ParamPackSplit& param_pack_split, Span<ValueRef> inputs, const ParamPackSplit& param_pack_split, Span<ValueRef> inputs,
Span<bool> inputs_mask) { Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
auto outputs = imperative::apply(param_pack_split, inputs); auto outputs = imperative::apply(param_pack_split, inputs);
size_t nr_outputs = outputs.size(); size_t nr_outputs = outputs.size();
mgb_assert(nr_outputs == param_pack_split.shapes.size()); mgb_assert(nr_outputs == param_pack_split.shapes.size());
for (size_t i = 0; i < nr_outputs; ++i) { for (size_t i = 0; i < nr_outputs; ++i) {
if (param_pack_split.shapes[i].empty()) { if (param_pack_split.shapes[i].empty()) {
outputs[i] = ScalarValue::make(outputs[i]); outputs[i] = scalar_type.make(outputs[i]);
} }
} }
return outputs; return outputs;
} }
ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) { ValueRefList dot_rule(
return {ScalarValue::make(imperative::apply(dot, inputs)[0])}; const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
return {scalar_type.make(imperative::apply(dot, inputs)[0])};
} }
ValueRefList add_axis_rule( ValueRefList add_axis_rule(
const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 1); mgb_assert(inputs.size() == 1);
if (inputs_mask.item()) { if (inputs_mask.item()) {
mgb_assert(add_axis.axis[0] == 0); mgb_assert(add_axis.axis[0] == 0);
...@@ -151,7 +162,8 @@ ValueRefList add_axis_rule( ...@@ -151,7 +162,8 @@ ValueRefList add_axis_rule(
} }
ValueRefList remote_recv_rule( ValueRefList remote_recv_rule(
const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) { const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
if (remote_recv.shape.empty()) { if (remote_recv.shape.empty()) {
std::vector<int32_t> shape = {1}; std::vector<int32_t> shape = {1};
auto remote_recv_no_scalar = RemoteRecv::make( auto remote_recv_no_scalar = RemoteRecv::make(
...@@ -167,20 +179,21 @@ ValueRefList remote_recv_rule( ...@@ -167,20 +179,21 @@ ValueRefList remote_recv_rule(
ValueRefList check_no_finite_rule( ValueRefList check_no_finite_rule(
const CheckNonFinite& check_no_finite, Span<ValueRef> inputs, const CheckNonFinite& check_no_finite, Span<ValueRef> inputs,
Span<bool> inputs_mask) { Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
auto outputs = imperative::apply(check_no_finite, inputs); auto outputs = imperative::apply(check_no_finite, inputs);
mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch"); mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
outputs.back() = ScalarValue::make(outputs.back()); outputs.back() = scalar_type.make(outputs.back());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs_mask[i]) { if (inputs_mask[i]) {
outputs[i] = ScalarValue::make(outputs[i]); outputs[i] = scalar_type.make(outputs[i]);
} }
} }
return outputs; return outputs;
} }
ValueRefList subtensor_rule( ValueRefList subtensor_rule(
const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) { const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() >= 1); mgb_assert(inputs.size() >= 1);
auto input = inputs[0]; auto input = inputs[0];
bool is_scalar; bool is_scalar;
...@@ -199,14 +212,14 @@ ValueRefList subtensor_rule( ...@@ -199,14 +212,14 @@ ValueRefList subtensor_rule(
} }
auto outputs = imperative::apply(subtensor, inputs); auto outputs = imperative::apply(subtensor, inputs);
if (is_scalar) { if (is_scalar) {
outputs[0] = ScalarValue::make(outputs[0]); outputs[0] = scalar_type.make(outputs[0]);
} }
return outputs; return outputs;
} }
ValueRefList get_var_shape_rule( ValueRefList get_var_shape_rule(
const GetVarShape& get_var_shape, Span<ValueRef> inputs, const GetVarShape& get_var_shape, Span<ValueRef> inputs, Span<bool> inputs_mask,
Span<bool> inputs_mask) { const Type<ScalarValue>& scalar_type) {
bool all_scalar = true; bool all_scalar = true;
mgb_assert(inputs.size() >= 1); mgb_assert(inputs.size() >= 1);
for (auto&& input_mask : inputs_mask) { for (auto&& input_mask : inputs_mask) {
...@@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule( ...@@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule(
} }
ValueRefList reshape_rule( ValueRefList reshape_rule(
const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) { const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]); bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) { if (is_scalar) {
return {ScalarValue::make(imperative::apply( return {scalar_type.make(imperative::apply(
reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
} else { } else {
return imperative::apply(reshape, inputs); return imperative::apply(reshape, inputs);
...@@ -240,11 +254,12 @@ ValueRefList reshape_rule( ...@@ -240,11 +254,12 @@ ValueRefList reshape_rule(
} }
ValueRefList broadcast_rule( ValueRefList broadcast_rule(
const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) { const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 2); mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]); bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) { if (is_scalar) {
return {ScalarValue::make(imperative::apply( return {scalar_type.make(imperative::apply(
broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
} else { } else {
return imperative::apply(broadcast, inputs); return imperative::apply(broadcast, inputs);
...@@ -299,11 +314,11 @@ struct ScalarRuleRegistry { ...@@ -299,11 +314,11 @@ struct ScalarRuleRegistry {
ValueRefList ScalarTransformation::apply_get_attr( ValueRefList ScalarTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) { const GetAttr& get_attr, Span<ValueRef> inputs) {
auto&& input = inputs.item(); auto&& input = inputs.item();
bool is_scalar = input.is<ScalarValue>(); bool is_scalar = input.is(m_value_type);
if (!is_scalar) { if (!is_scalar) {
return imperative::apply(get_attr, input); return imperative::apply(get_attr, input);
} }
auto unwrapped_input = input.cast<ScalarValue>().value(); auto unwrapped_input = input.cast(m_value_type).value();
if (get_attr.attr() == GetAttr::Shape) { if (get_attr.attr() == GetAttr::Shape) {
if (!m_empty_shape) { if (!m_empty_shape) {
m_empty_shape = ShapeValue::make(); m_empty_shape = ShapeValue::make();
...@@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation( ...@@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation(
ValueRefList unwrapped_inputs(nr_inputs); ValueRefList unwrapped_inputs(nr_inputs);
SmallVector<bool> inputs_mask(nr_inputs); SmallVector<bool> inputs_mask(nr_inputs);
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) { if (auto&& scalar_value = inputs[i].as_ref(m_value_type)) {
unwrapped_inputs[i] = scalar_value->value(); unwrapped_inputs[i] = scalar_value->value();
inputs_mask[i] = true; inputs_mask[i] = true;
} else { } else {
...@@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation( ...@@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation(
if (auto apply_op = op.as<ApplyOp>()) { if (auto apply_op = op.as<ApplyOp>()) {
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
if (iter != scalar_rules.end()) { if (iter != scalar_rules.end()) {
return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask); return iter->second(
apply_op->op(), unwrapped_inputs, inputs_mask, m_value_type);
} else { } else {
// TODO: repeat op // TODO: repeat op
return fallback(); return fallback();
...@@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation( ...@@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation(
CreateTensor scalar_op( CreateTensor scalar_op(
create_tensor->kind(), create_tensor->device(), create_tensor->kind(), create_tensor->device(),
create_tensor->dtype(), scalar_shape); create_tensor->dtype(), scalar_shape);
return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; return {m_value_type.make(imperative::apply(scalar_op, inputs)[0])};
} else { } else {
return imperative::apply(op, inputs); return imperative::apply(op, inputs);
} }
...@@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation( ...@@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation(
bool is_scalar = inputs_mask[0]; bool is_scalar = inputs_mask[0];
auto outputs = fallback(); auto outputs = fallback();
if (is_scalar) { if (is_scalar) {
outputs[0] = ScalarValue::make(outputs[0]); outputs[0] = m_value_type.make(outputs[0]);
} }
return outputs; return outputs;
} else { } else {
......
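The scalar rules above now take the transformation's own Type<ScalarValue> handle and wrap outputs through scalar_type.make(...) instead of the former static ScalarValue::make(...); the table keyed by op typeinfo then forwards that handle to whichever rule matches. A minimal standalone sketch of that dispatch shape, using stand-in types and a hypothetical op rather than the real MegEngine classes:

    #include <cstdio>
    #include <functional>
    #include <string>
    #include <typeindex>
    #include <unordered_map>
    #include <vector>

    // Stand-ins: TypeHandle plays the role of Type<ScalarValue>, strings play ValueRef.
    struct TypeHandle {
        std::string name;
        std::string make(const std::string& v) const { return name + "(" + v + ")"; }
    };

    struct OpBase {
        virtual ~OpBase() = default;
    };
    struct FakeReshape : OpBase {};  // hypothetical op, for illustration only

    using Rule = std::function<std::vector<std::string>(
            const OpBase&, const std::vector<std::string>&, const TypeHandle&)>;

    int main() {
        std::unordered_map<std::type_index, Rule> rules;
        // register a rule: re-wrap the single output with the caller's type handle
        rules[typeid(FakeReshape)] = [](const OpBase&, const std::vector<std::string>& inputs,
                                        const TypeHandle& scalar_type) {
            return std::vector<std::string>{scalar_type.make(inputs.at(0))};
        };

        TypeHandle scalar_type{"ScalarValue"};  // each transformation owns its own handle
        FakeReshape op;
        if (auto it = rules.find(typeid(op)); it != rules.end()) {
            auto outputs = it->second(op, {"x"}, scalar_type);
            std::printf("%s\n", outputs[0].c_str());  // prints: ScalarValue(x)
        }
    }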
...@@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation(
SmallVector<TracingValue::ref_t> wrapped_inputs; SmallVector<TracingValue::ref_t> wrapped_inputs;
SmallVector<size_t> input_ids; SmallVector<size_t> input_ids;
for (auto input : inputs) { for (auto input : inputs) {
auto tracing_value = input.as_ref<TracingValue>(); auto tracing_value = input.as_ref(m_value_type);
if (!tracing_value) { if (!tracing_value) {
tracing_value = tracing_value =
record_var(input, m_capture_as_const, VarKind::External); record_var(input, m_capture_as_const, VarKind::External);
...@@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (auto* get_attr = op.as<GetAttr>()) { } else if (auto* get_attr = op.as<GetAttr>()) {
auto unwrapped_input = unwrap_var(inputs[0]); auto unwrapped_input = unwrap_var(inputs[0]);
auto outputs = imperative::apply(op, unwrapped_input); auto outputs = imperative::apply(op, unwrapped_input);
if (auto* tracing_value = inputs[0].as<TracingValue>()) { if (auto* tracing_value = inputs[0].as(m_value_type)) {
auto& var_info = m_vars[tracing_value->id()]; auto& var_info = m_vars[tracing_value->id()];
switch (get_attr->attr()) { switch (get_attr->attr()) {
case GetAttr::Shape: case GetAttr::Shape:
...@@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (auto* trace_mark_var = op.as<TraceMarkVar>()) { } else if (auto* trace_mark_var = op.as<TraceMarkVar>()) {
mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input"); mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input");
auto input = inputs[0]; auto input = inputs[0];
auto tracing_var = input.as_ref<TracingValue>(); auto tracing_var = input.as_ref(m_value_type);
if (!tracing_var) { if (!tracing_var) {
bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" || bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" ||
trace_mark_var->mark().substr(0, 6) == "kwarg_"; trace_mark_var->mark().substr(0, 6) == "kwarg_";
...@@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (auto* trace_name_var = op.as<RenameValue>()) { } else if (auto* trace_name_var = op.as<RenameValue>()) {
mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input"); mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input");
auto input = inputs[0]; auto input = inputs[0];
auto tracing_var = input.as_ref<TracingValue>(); auto tracing_var = input.as_ref(m_value_type);
if (!tracing_var) { if (!tracing_var) {
tracing_var = record_var(input, m_capture_as_const, VarKind::External); tracing_var = record_var(input, m_capture_as_const, VarKind::External);
} else { } else {
...@@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation( ...@@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (op.is<GetName>()) { } else if (op.is<GetName>()) {
mgb_assert(inputs.size() == 1, "GetName expects exactly one input"); mgb_assert(inputs.size() == 1, "GetName expects exactly one input");
auto input = inputs[0]; auto input = inputs[0];
if (auto tracing_var = input.as_ref<TracingValue>()) { if (auto tracing_var = input.as_ref(m_value_type)) {
auto name = m_vars[tracing_var->id()].name; auto name = m_vars[tracing_var->id()].name;
if (!name.empty()) { if (!name.empty()) {
return {StringValue::make(name)}; return {StringValue::make(name)};
...@@ -425,26 +425,12 @@ void CompiledTransformation::compile() { ...@@ -425,26 +425,12 @@ void CompiledTransformation::compile() {
} }
auto& node = var_accessors[input].node; auto& node = var_accessors[input].node;
if (input_vars.empty() && require_link && mm_io_link.node()) { if (input_vars.empty() && require_link && mm_io_link.node()) {
/*mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");*/
auto comp_node = mm_io_link.node()->comp_node(); auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node) node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node)
.node(); .node();
} }
input_vars.push_back(node); input_vars.push_back(node);
} }
/*if (require_link && mm_io_link.node()) {
mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");
auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
input_vars[0] = opr::VirtualDep::make(
{SymbolVar(input_vars[0]), mm_io_link}, comp_node)
.node();
}*/
VarNodeArray output_vars; VarNodeArray output_vars;
if (item.op) { if (item.op) {
output_vars = OpDef::apply_on_var_node(*item.op, input_vars); output_vars = OpDef::apply_on_var_node(*item.op, input_vars);
...@@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { ...@@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
switch (var.kind) { switch (var.kind) {
case VarKind::External: { case VarKind::External: {
trace_assert( trace_assert(
!value.is<TracedValue>(), "expect external node, got internal"); !value.is(m_value_type), "expect external node, got internal");
if (var.bound_data) { if (var.bound_data) {
assert_tensor_equal(var.bound_data, value); assert_tensor_equal(var.bound_data, value);
} else { } else {
...@@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { ...@@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
} }
case VarKind::Internal: { case VarKind::Internal: {
trace_assert( trace_assert(
value.is<TracedValue>(), "expect internal node, got external"); value.is(m_value_type), "expect internal node, got external");
auto& traced_value = value.cast<TracedValue>(); auto& traced_value = value.cast(m_value_type);
trace_assert(traced_value.id() == id, "input id mismatch"); trace_assert(traced_value.id() == id, "input id mismatch");
break; break;
} }
...@@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { ...@@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
} }
auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t {
auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]); auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]);
m_weak_values.push_back(traced_value); m_weak_values.push_back(traced_value);
return traced_value; return traced_value;
} }
...@@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { ...@@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
return m_seq[m_pc++]; return m_seq[m_pc++];
} }
ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const {
if (!m_shape) { if (!m_shape) {
trace_assert(m_accessor->shape_getter, "shape unreadable"); trace_assert(m_accessor->shape_getter, "shape unreadable");
m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter()));
...@@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { ...@@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
return m_shape; return m_shape;
} }
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const {
return m_var->dtype; return m_var->dtype;
} }
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const {
return m_var->device; return m_var->device;
} }
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& {
return *m_accessor; return *m_accessor;
} }
...@@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op( ...@@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op(
ValueRefList CompiledTransformation::apply_get_attr( ValueRefList CompiledTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) { const GetAttr& get_attr, Span<ValueRef> inputs) {
if (auto* traced_value = inputs[0].as<TracedValue>()) { if (auto* traced_value = inputs[0].as(m_value_type)) {
ValueRef output; ValueRef output;
auto& var_accessor = traced_value->accessor(); auto& var_accessor = traced_value->accessor();
switch (get_attr.attr()) { switch (get_attr.attr()) {
...@@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept { ...@@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept {
void CompiledTransformation::execute() { void CompiledTransformation::execute() {
mgb_assert(m_executable != nullptr); mgb_assert(m_executable != nullptr);
m_graph_executor = std::thread([&] { {
try { MGB_LOCK_GUARD(m_mutex);
m_executable->execute(); m_graph_status = 1;
m_executable->wait(); }
} catch (...) { m_cv.notify_all();
auto exc = std::current_exception();
set_exception(exc);
}
});
} }
void CompiledTransformation::wait() { void CompiledTransformation::wait() {
...@@ -735,8 +717,9 @@ void CompiledTransformation::wait() { ...@@ -735,8 +717,9 @@ void CompiledTransformation::wait() {
} catch (...) { } catch (...) {
} }
mgb_assert(m_executable != nullptr); mgb_assert(m_executable != nullptr);
m_graph_executor.join(); std::unique_lock lock{m_mutex};
m_graph_executor = {}; m_cv.wait(lock, [&] { return m_graph_status == 0; });
lock.unlock();
for (auto&& box : m_boxes) { for (auto&& box : m_boxes) {
box->reset(); box->reset();
} }
......
...@@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const { ...@@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const {
return m_storage; return m_storage;
} }
const Value* ValueRef::as(size_t typecode) const { const Value* ValueRef::as(const IType& type) const {
auto&& storage = this->storage(); auto&& storage = this->storage();
if (storage->m_typecode != typecode) { if (storage->type() != type) {
return nullptr; return nullptr;
} }
return static_cast<Value*>(storage.get()); return static_cast<Value*>(storage.get());
} }
bool ValueRef::is(size_t typecode) const { bool ValueRef::is(const IType& type) const {
return this->storage()->m_typecode == typecode; return this->storage()->type() == type;
} }
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const { TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
...@@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const { ...@@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const {
if (!m_storage) { if (!m_storage) {
return "null"; return "null";
} }
auto& types = Value::registered_types(); return m_storage->type().name();
mgb_assert(types.size() > m_storage->m_typecode);
return types[m_storage->m_typecode].name();
} }
bool ValueRef::watching() const { bool ValueRef::watching() const {
...@@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() { ...@@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() {
return {strong_storage}; return {strong_storage};
} }
Value::Value(size_t typecode) : m_typecode{typecode} { Value::Value() {
m_id = nr_values++; m_id = nr_values++;
} }
...@@ -147,17 +145,6 @@ Value::~Value() { ...@@ -147,17 +145,6 @@ Value::~Value() {
} }
} }
size_t Value::register_type(std::type_index type) {
auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
types.push_back(type);
return types.size() - 1;
}
const std::vector<std::type_index>& Value::registered_types() {
static std::vector<std::type_index> sm_registered_types;
return sm_registered_types;
}
void Value::register_value(ValueRef value) { void Value::register_value(ValueRef value) {
registered_values[value.id()] = ValueWeakRef(value); registered_values[value.id()] = ValueWeakRef(value);
} }
...@@ -188,7 +175,7 @@ std::vector<ValueRef> Value::end_record_values() { ...@@ -188,7 +175,7 @@ std::vector<ValueRef> Value::end_record_values() {
} }
void Value::try_rethrow() { void Value::try_rethrow() {
if (m_typecode == ErrorValue::TYPE_CODE) { if (type() == PrimitiveType<ErrorValue>::instance) {
auto message = static_cast<ErrorValue*>(this)->message(); auto message = static_cast<ErrorValue*>(this)->message();
mgb_throw(MegBrainError, "invalid value: %s", message.c_str()); mgb_throw(MegBrainError, "invalid value: %s", message.c_str());
} }
...@@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) { ...@@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) {
m_size = nr_elems; m_size = nr_elems;
if (m_size > 0) { if (m_size > 0) {
if (m_size == 1) { if (m_size == 1) {
m_data = inline_storage(); m_data = new (inline_storage()) ValueRef();
} else { } else {
auto& context = Transformation::get_context(); m_data = new ValueRef[m_size];
m_data = context.allocator.allocate(m_size);
}
for (size_t i = 0; i < m_size; ++i) {
new (m_data + i) ValueRef();
} }
} else { } else {
m_data = nullptr; m_data = nullptr;
...@@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) { ...@@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) {
init(nr_elems); init(nr_elems);
} }
/*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
: ValueRefList(values.begin(), values.end()) {}*/
ValueRefList::ValueRefList(const ValueRefList& rhs) ValueRefList::ValueRefList(const ValueRefList& rhs)
: ValueRefList(rhs.cbegin(), rhs.cend()) {} : ValueRefList(rhs.cbegin(), rhs.cend()) {}
...@@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() { ...@@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() {
} }
void ValueRefList::clear() { void ValueRefList::clear() {
for (size_t i = 0; i < m_size; ++i) {
m_data[i].~ValueRef();
}
if (m_data) { if (m_data) {
if (m_size != 1) { if (m_size != 1) {
Transformation::get_context().allocator.deallocate(m_data, m_size); delete[] m_data;
} else { } else {
mgb_assert(m_data == inline_storage()); mgb_assert(m_data == inline_storage());
m_data->~ValueRef();
} }
} }
m_data = nullptr; m_data = nullptr;
......
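ValueRefList::init and clear above move from the per-context allocator to a plain small-buffer scheme: a single element is placement-new'd into inline storage, larger lists go through new[]/delete[], and clear mirrors that split when destroying. A standalone sketch of the same one-element small-buffer optimization, with a stand-in element type rather than the real ValueRef:

    #include <cassert>
    #include <cstddef>
    #include <memory>
    #include <new>
    #include <string>

    // One element lives in the inline buffer; anything larger is heap-allocated.
    class SmallList {
        std::string* m_data = nullptr;
        size_t m_size = 0;
        alignas(std::string) unsigned char m_inline[sizeof(std::string)];

        std::string* inline_storage() { return reinterpret_cast<std::string*>(m_inline); }

    public:
        explicit SmallList(size_t n) : m_size(n) {
            if (m_size == 1) {
                m_data = new (inline_storage()) std::string();  // no heap allocation
            } else if (m_size > 1) {
                m_data = new std::string[m_size];
            }
        }
        ~SmallList() {
            if (!m_data) {
                return;
            }
            if (m_size == 1) {
                assert(m_data == inline_storage());
                std::destroy_at(m_data);  // destroy the inline element in place
            } else {
                delete[] m_data;
            }
        }
        std::string& operator[](size_t i) { return m_data[i]; }
    };

    int main() {
        SmallList one(1);
        one[0] = "fits inline";
        SmallList many(3);
        many[2] = "heap allocated";
    }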
...@@ -25,79 +25,68 @@ class GradKey; ...@@ -25,79 +25,68 @@ class GradKey;
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>; using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
class ShapeValue final class ShapeValue final : public PrimitiveValue<ShapeValue, ValueShape> {
: public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
class CompNodeValue final class CompNodeValue final : public PrimitiveValue<CompNodeValue, CompNode> {
: public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
// TODO: override factory method class Boolean {
class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> {
private: private:
std::optional<bool> m_value; bool m_value;
public: public:
BoolValue(bool value) : m_value{value} {} Boolean() = default;
operator bool() const { return *m_value; } Boolean(bool value) : m_value(value) {}
std::string to_string() const override; operator bool() const { return m_value; }
};
void clear() override { m_value.reset(); } // TODO: override factory method
class BoolValue final : public PrimitiveValue<BoolValue, Boolean> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
}; };
class HostStorage final class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> {
: public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
class DeviceStorage final class DeviceStorage final : public PrimitiveValue<DeviceStorage, DeviceTensorStorage> {
: public MixinValueImpl<
DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
/** class HostTensor {
* \brief like HostTensorND mixin, but allow scalar value
*
*/
class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> {
private: private:
DType m_dtype; DType m_dtype;
ValueShape m_shape; ValueShape m_shape;
HostTensorStorage m_storage; HostTensorStorage m_storage;
public: public:
HostValue(DType dtype, ValueShape shape, HostTensorStorage storage) HostTensor() = default;
HostTensor(DType dtype, ValueShape shape, HostTensorStorage storage)
: m_dtype(dtype), m_shape(shape), m_storage(storage) {} : m_dtype(dtype), m_shape(shape), m_storage(storage) {}
HostValue(HostTensorND value) HostTensor(HostTensorND value)
: HostValue( : HostTensor(
value.dtype(), ValueShape::from(value.shape()), value.storage()) { value.dtype(), ValueShape::from(value.shape()), value.storage()) {
} }
std::string to_string() const override;
void clear() override {
m_dtype = {};
m_shape = {};
m_storage = {};
}
DType dtype() const { return m_dtype; } DType dtype() const { return m_dtype; }
const ValueShape& shape() const { return m_shape; } const ValueShape& shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); } CompNode device() const { return m_storage.comp_node(); }
...@@ -112,31 +101,31 @@ public: ...@@ -112,31 +101,31 @@ public:
}; };
/** /**
* \brief like DeviceTensorND mixin, but allow scalar value * \brief like HostTensorND mixin, but allow scalar value
* *
*/ */
class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> { class HostValue final : public PrimitiveValue<HostValue, HostTensor> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class DeviceTensor {
private: private:
DType m_dtype; DType m_dtype;
ValueShape m_shape; ValueShape m_shape;
DeviceTensorStorage m_storage; DeviceTensorStorage m_storage;
public: public:
DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage) DeviceTensor() = default;
DeviceTensor(DType dtype, ValueShape shape, DeviceTensorStorage storage)
: m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {} : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
DeviceValue(const DeviceTensorND& value) DeviceTensor(const DeviceTensorND& value)
: DeviceValue( : DeviceTensor(
value.dtype(), ValueShape::from(value.shape()), value.storage()) { value.dtype(), ValueShape::from(value.shape()), value.storage()) {
} }
std::string to_string() const override;
void clear() override {
m_dtype = {};
m_shape = {};
m_storage = {};
}
DType dtype() const { return m_dtype; } DType dtype() const { return m_dtype; }
const ValueShape& shape() const { return m_shape; } const ValueShape& shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); } CompNode device() const { return m_storage.comp_node(); }
...@@ -145,26 +134,34 @@ public: ...@@ -145,26 +134,34 @@ public:
DeviceTensorND as_nd(bool allow_scalar = false) const; DeviceTensorND as_nd(bool allow_scalar = false) const;
}; };
class FunctionValue final /**
: public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> { * \brief like DeviceTensorND mixin, but allow scalar value
*
*/
class DeviceValue final : public PrimitiveValue<DeviceValue, DeviceTensor> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class FunctionValue final : public PrimitiveValue<FunctionValue, GenericFunction> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
class DTypeValue final class DTypeValue final : public PrimitiveValue<DTypeValue, DType> {
: public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
class StringValue final class StringValue final : public PrimitiveValue<StringValue, std::string> {
: public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
...@@ -180,10 +177,9 @@ public: ...@@ -180,10 +177,9 @@ public:
std::string message() const { return m_message; } std::string message() const { return m_message; }
}; };
class ErrorValue final class ErrorValue final : public PrimitiveValue<ErrorValue, Error> {
: public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override; std::string to_string() const override;
}; };
......
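Each value class above now derives from PrimitiveValue<Self, CType>: the wrapper publicly re-exports a plain ctype (Boolean, HostTensor, DeviceTensor, GenericFunction, ...) and only adds the value-specific pieces such as to_string. A rough standalone sketch of that wrapper shape, using stand-in base classes rather than the real Value machinery:

    #include <cstdio>
    #include <string>
    #include <utility>

    // Stand-in for the Value hierarchy: only to_string() matters for this sketch.
    struct ValueBase {
        virtual ~ValueBase() = default;
        virtual std::string to_string() const = 0;
    };

    // The wrapper inherits the ctype publicly, so the wrapped interface stays usable.
    template <typename TSelf, typename CType>
    class PrimitiveValueLike : public ValueBase, public CType {
    public:
        PrimitiveValueLike() = default;
        PrimitiveValueLike(CType value) : CType(std::move(value)) {}
        const CType& value() const { return *this; }
    };

    // ctype first, value class second -- mirrors the Boolean / BoolValue pairing above.
    class Fraction {
        int m_num = 0, m_den = 1;

    public:
        Fraction() = default;
        Fraction(int num, int den) : m_num(num), m_den(den) {}
        double as_double() const { return double(m_num) / m_den; }
    };

    class FractionValue final : public PrimitiveValueLike<FractionValue, Fraction> {
    public:
        using PrimitiveValueLike::PrimitiveValueLike;
        std::string to_string() const override { return std::to_string(as_double()); }
    };

    int main() {
        FractionValue v{Fraction(1, 4)};
        std::printf("%s\n", v.to_string().c_str());  // prints: 0.250000
    }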
...@@ -57,7 +57,7 @@ struct Subgraph { ...@@ -57,7 +57,7 @@ struct Subgraph {
SmallVector<expr_t> exprs; SmallVector<expr_t> exprs;
template <typename T, typename F, typename C> template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { SmallVector<T> apply(Span<T> input_vars, F&& f, C&& c) const {
std::unordered_map<size_t, T> idx2var; std::unordered_map<size_t, T> idx2var;
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
...@@ -71,8 +71,7 @@ struct Subgraph { ...@@ -71,8 +71,7 @@ struct Subgraph {
for (auto idx : expr.inputs) { for (auto idx : expr.inputs) {
expr_inputs.push_back(idx2var[idx]); expr_inputs.push_back(idx2var[idx]);
} }
SmallVector<T> expr_outputs = SmallVector<T> expr_outputs = f(expr.op, expr_inputs, expr.outputs.size());
f(expr.op, std::move(expr_inputs), expr.outputs.size());
mgb_assert( mgb_assert(
expr_outputs.size() == expr.outputs.size(), "output size mismatch"); expr_outputs.size() == expr.outputs.size(), "output size mismatch");
for (size_t i = 0; i < expr_outputs.size(); ++i) { for (size_t i = 0; i < expr_outputs.size(); ++i) {
...@@ -102,9 +101,9 @@ struct EncodedSubgraph { ...@@ -102,9 +101,9 @@ struct EncodedSubgraph {
SmallVector<bool> input_mask; SmallVector<bool> input_mask;
SmallVector<bool> output_mask; SmallVector<bool> output_mask;
template <typename TContainer> template <typename T>
TContainer encode_inputs(TContainer inputs) const { SmallVector<T> encode_inputs(Span<T> inputs) const {
TContainer encoded_inputs; SmallVector<T> encoded_inputs;
size_t index = 0; size_t index = 0;
for (auto&& input : inputs) { for (auto&& input : inputs) {
mgb_assert(index < input_mask.size(), "index out of range"); mgb_assert(index < input_mask.size(), "index out of range");
...@@ -116,9 +115,9 @@ struct EncodedSubgraph { ...@@ -116,9 +115,9 @@ struct EncodedSubgraph {
return encoded_inputs; return encoded_inputs;
} }
template <typename TContainer> template <typename T>
TContainer encode_outputs(TContainer outputs) const { SmallVector<T> encode_outputs(Span<T> outputs) const {
TContainer encoded_outputs; SmallVector<T> encoded_outputs;
size_t index = 0; size_t index = 0;
for (auto&& output : outputs) { for (auto&& output : outputs) {
mgb_assert(index < output_mask.size(), "index out of range"); mgb_assert(index < output_mask.size(), "index out of range");
...@@ -130,9 +129,9 @@ struct EncodedSubgraph { ...@@ -130,9 +129,9 @@ struct EncodedSubgraph {
return encoded_outputs; return encoded_outputs;
} }
template <typename TContainer> template <typename T>
TContainer decode_outputs(TContainer outputs) const { SmallVector<T> decode_outputs(Span<T> outputs) const {
TContainer decoded_outputs; SmallVector<T> decoded_outputs;
size_t index = 0; size_t index = 0;
for (size_t i = 0; i < output_mask.size(); i++) { for (size_t i = 0; i < output_mask.size(); i++) {
mgb_assert(index < output_mask.size(), "index out of range"); mgb_assert(index < output_mask.size(), "index out of range");
...@@ -150,8 +149,8 @@ struct EncodedSubgraph { ...@@ -150,8 +149,8 @@ struct EncodedSubgraph {
EncodedSubgraph result; EncodedSubgraph result;
result.input_mask = graph.gen_input_mask(); result.input_mask = graph.gen_input_mask();
result.output_mask = graph.gen_output_mask(); result.output_mask = graph.gen_output_mask();
graph.inputs = result.encode_inputs(graph.inputs); graph.inputs = result.encode_inputs<Subgraph::var_t>(graph.inputs);
graph.outputs = result.encode_outputs(graph.outputs); graph.outputs = result.encode_outputs<Subgraph::var_t>(graph.outputs);
result.graph = graph; result.graph = graph;
return result; return result;
} }
...@@ -179,11 +178,11 @@ struct EncodedSubgraph { ...@@ -179,11 +178,11 @@ struct EncodedSubgraph {
} }
template <typename T, typename F, typename C> template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { SmallVector<T> apply(Span<T> input_vars, F&& f, C&& c) const {
auto encoded_inputs = encode_inputs(input_vars); auto encoded_inputs = encode_inputs<T>(input_vars);
auto encoded_outputs = auto encoded_outputs =
graph.apply(encoded_inputs, std::forward<F>(f), std::forward<C>(c)); graph.apply<T>(encoded_inputs, std::forward<F>(f), std::forward<C>(c));
return decode_outputs(encoded_outputs); return decode_outputs<T>(encoded_outputs);
} }
std::string repr() const; std::string repr() const;
...@@ -280,4 +279,4 @@ public: ...@@ -280,4 +279,4 @@ public:
}; };
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
\ No newline at end of file
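With encode_inputs, encode_outputs, decode_outputs and apply now taking Span<T> instead of a deduced container type, call sites spell the element type explicitly (encode_inputs<Subgraph::var_t>(graph.inputs) above): template argument deduction does not look through the user-defined container-to-Span conversion. A tiny standalone illustration of that deduction rule, with a minimal Span stand-in shaped like the one in this patch:

    #include <cstddef>
    #include <cstdio>
    #include <vector>

    template <typename T>
    struct Span {
        const T* m_begin;
        const T* m_end;
        template <typename TContainer>
        Span(const TContainer& c) : m_begin(c.data()), m_end(c.data() + c.size()) {}
        std::size_t size() const { return std::size_t(m_end - m_begin); }
    };

    template <typename T>
    std::size_t count(Span<T> xs) {
        return xs.size();
    }

    int main() {
        std::vector<int> v{1, 2, 3};
        // count(v);                          // error: T is not deducible through vector -> Span
        std::printf("%zu\n", count<int>(v));  // explicit element type, as at the call sites above
    }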
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
namespace mgb::imperative { namespace mgb::imperative {
struct InterpreterInfo { class InterpreterValue final : public ObjectValue<InterpreterValue> {
public: public:
using Handle = interpreter::Interpreter::Handle; using Handle = interpreter::Interpreter::Handle;
using Channel = interpreter::Interpreter::Channel; using Channel = interpreter::Interpreter::Channel;
...@@ -46,8 +46,7 @@ private: ...@@ -46,8 +46,7 @@ private:
mutable ShapeValue::ref_t m_shape; mutable ShapeValue::ref_t m_shape;
public: public:
InterpreterInfo() = default; InterpreterValue(LocalPtr<RAIIHandle> handle, std::string name = {})
InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {})
: m_handle(handle), m_name(name) {} : m_handle(handle), m_name(name) {}
const LocalPtr<RAIIHandle>& handle() const { return m_handle; } const LocalPtr<RAIIHandle>& handle() const { return m_handle; }
...@@ -57,18 +56,14 @@ public: ...@@ -57,18 +56,14 @@ public:
ShapeValue::ref_t shape() const; ShapeValue::ref_t shape() const;
std::string name() const { return m_name; } std::string name() const { return m_name; }
};
class InterpreterValue final
: public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override { std::string to_string() const override {
return ssprintf( return ssprintf(
"Handle{ptr=%p, name=%s}", handle().get(), "Handle{ptr=%p, name=%s}", handle().get(),
imperative::quoted(name()).c_str()); imperative::quoted(name()).c_str());
} }
void clear() override { m_handle = {}; }
}; };
/** /**
...@@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation { ...@@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation {
public: public:
using Interpreter = interpreter::Interpreter; using Interpreter = interpreter::Interpreter;
using Handle = Interpreter::Handle; using Handle = Interpreter::Handle;
using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>; using SharedHandle = LocalPtr<InterpreterValue::RAIIHandle>;
using Channel = Interpreter::Channel; using Channel = Interpreter::Channel;
private: private:
std::shared_ptr<Channel> m_channel; std::shared_ptr<Channel> m_channel;
ObjectType<InterpreterValue> m_value_type{"InterpreterValue"};
public: public:
explicit InterpreterTransformation(std::shared_ptr<Channel> channel) explicit InterpreterTransformation(std::shared_ptr<Channel> channel)
...@@ -105,7 +101,7 @@ public: ...@@ -105,7 +101,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override; const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override { ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<InterpreterValue>()); mgb_assert(!value.is(m_value_type));
return value; return value;
} }
......
...@@ -34,7 +34,8 @@ struct BackwardGraphWithClosure { ...@@ -34,7 +34,8 @@ struct BackwardGraphWithClosure {
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs); std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs);
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); void operator()(
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }
...@@ -51,7 +52,7 @@ struct CustomBackward; ...@@ -51,7 +52,7 @@ struct CustomBackward;
using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>; using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>;
struct CustomBackward { struct CustomBackward {
using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>; using BackwardFn = std::function<SmallVector<ValueRef>(Span<ValueRef>)>;
using BackwardRule = std::function<std::optional<ValueRefList>( using BackwardRule = std::function<std::optional<ValueRefList>(
const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>; const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>;
BackwardFn m_backward; BackwardFn m_backward;
...@@ -62,7 +63,8 @@ struct CustomBackward { ...@@ -62,7 +63,8 @@ struct CustomBackward {
SmallVector<OutputAttr> m_output_attrs; SmallVector<OutputAttr> m_output_attrs;
public: public:
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); void operator()(
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return m_input_has_grad[i]; } bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
...@@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const { ...@@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const {
return &m_fn->m_slots[m_index]; return &m_fn->m_slots[m_index];
} }
class GradValue final : public ValueImpl<GradValue, ValueKind::Object> { class GradValue final : public ObjectValue<GradValue> {
private: private:
ValueRef m_value; ValueRef m_value;
std::shared_ptr<GradKey> m_key; std::shared_ptr<GradKey> m_key;
...@@ -187,14 +189,9 @@ public: ...@@ -187,14 +189,9 @@ public:
std::string to_string() const override; std::string to_string() const override;
bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; } const GradSlotPtr& slot() const { return m_slot; }
const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const { // std::shared_ptr<GradKey> key() const { return m_key; }
mgb_assert(m_key == key);
return m_slot;
}
std::shared_ptr<GradKey> key() const { return m_key; }
void clear() override { void clear() override {
m_slot = {}; m_slot = {};
...@@ -216,9 +213,12 @@ private: ...@@ -216,9 +213,12 @@ private:
std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape; std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape; std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
bool m_frozen = false; bool m_frozen = false;
const Type<GradValue>& m_value_type;
public: public:
GradKey() { m_tape.reserve(4 * 1024); } GradKey(const Type<GradValue>& value_type) : m_value_type(value_type) {
m_tape.reserve(4 * 1024);
}
void backward(); void backward();
GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback); GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
...@@ -230,10 +230,9 @@ public: ...@@ -230,10 +230,9 @@ public:
}; };
class GradKeyValue final class GradKeyValue final
: public MixinValueImpl< : public PrimitiveValue<GradKeyValue, std::shared_ptr<GradKey>> {
GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> {
public: public:
using MixinValueImpl::MixinValueImpl; using PrimitiveValue::PrimitiveValue;
std::string to_string() const override { std::string to_string() const override {
return ssprintf("GradKey{%s}", (*this)->name().c_str()); return ssprintf("GradKey{%s}", (*this)->name().c_str());
...@@ -242,26 +241,20 @@ public: ...@@ -242,26 +241,20 @@ public:
class GradTransformation final : public Transformation { class GradTransformation final : public Transformation {
private: private:
ObjectType<GradValue> m_value_type{"GradValue"};
std::shared_ptr<GradKey> m_key; std::shared_ptr<GradKey> m_key;
std::vector<GradValue::weak_ref_t> m_weak_values; std::vector<GradValue::weak_ref_t> m_weak_values;
size_t m_suppressed = 0; size_t m_suppressed = 0;
public: public:
GradTransformation(std::shared_ptr<GradKey> key) : m_key(key) {} GradTransformation() { m_key = std::make_shared<GradKey>(m_value_type); }
auto record_grad(GradValue::ref_t tensor) { auto record_grad(GradValue::ref_t tensor) {
m_weak_values.push_back(tensor); m_weak_values.push_back(tensor);
return tensor; return tensor;
} }
bool is_grad_value(const ValueRef& value) { bool is_grad_value(const ValueRef& value) { return value.is(m_value_type); }
if (auto* grad_value = value.as<GradValue>()) {
if (grad_value->has_key(m_key)) {
return true;
}
}
return false;
}
/** /**
* \brief test whether value is related to this GradTransformation * \brief test whether value is related to this GradTransformation
...@@ -273,13 +266,7 @@ public: ...@@ -273,13 +266,7 @@ public:
* \return GradValue::ref_t * \return GradValue::ref_t
*/ */
const GradValue::ref_t& as_grad_value(const ValueRef& value) { const GradValue::ref_t& as_grad_value(const ValueRef& value) {
auto&& grad_value = value.as_ref<GradValue>(); return value.as_ref(m_value_type);
if (grad_value) {
if (grad_value->has_key(m_key)) {
return grad_value;
}
}
return GradValue::ref_t::nil;
} }
bool has_key(std::shared_ptr<GradKey> key) { bool has_key(std::shared_ptr<GradKey> key) {
...@@ -299,6 +286,8 @@ public: ...@@ -299,6 +286,8 @@ public:
return value; return value;
} }
const std::shared_ptr<GradKey>& key() const { return m_key; }
std::string name() const override { return "GradTransformation"; } std::string name() const override { return "GradTransformation"; }
GenericFunction make_backward_closure(Span<ValueRef> ys); GenericFunction make_backward_closure(Span<ValueRef> ys);
......
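Because each GradTransformation now constructs its own ObjectType<GradValue> and hands it to its GradKey, a tensor wrapped by one transformation can never be mistaken for a wrapper belonging to another, which is why the former has_key/slot_for checks reduce to a plain value.is(m_value_type) test. A small standalone sketch of why per-instance type objects give that guarantee (stand-in classes, not the real ValueRef machinery):

    #include <cstdio>
    #include <string>

    // A "type" is just an object; identity is the object's address.
    struct TypeTag {
        std::string name;
    };

    struct Wrapped {
        const TypeTag* type;  // records which transformation created this wrapper
        int payload;
    };

    bool is(const Wrapped& v, const TypeTag& t) {
        return v.type == &t;  // address comparison, like IType::operator==
    }

    struct GradTransformationLike {
        TypeTag value_type{"GradValue"};  // one distinct type per transformation instance
        Wrapped wrap(int payload) const { return Wrapped{&value_type, payload}; }
    };

    int main() {
        GradTransformationLike a, b;  // e.g. two nested grad scopes
        Wrapped x = a.wrap(42);
        std::printf("%d %d\n", is(x, a.value_type), is(x, b.value_type));  // prints: 1 0
    }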
...@@ -22,32 +22,27 @@ ...@@ -22,32 +22,27 @@
namespace mgb::imperative { namespace mgb::imperative {
class LazyEvalInfo { class LazyEvalValue final : public ObjectValue<LazyEvalValue> {
private: private:
VarNode* m_node = nullptr; VarNode* m_node = nullptr;
ValueRef m_bound_data; ValueRef m_bound_data;
std::string m_name; std::string m_name;
public: public:
LazyEvalInfo() = default; LazyEvalValue(VarNode* node, ValueRef bound_data, std::string name)
LazyEvalInfo(VarNode* node, ValueRef bound_data, std::string name)
: m_node(node), m_bound_data(bound_data), m_name(name) {} : m_node(node), m_bound_data(bound_data), m_name(name) {}
VarNode* node() const { return m_node; } VarNode* node() const { return m_node; }
ValueRef bound_data() const { return m_bound_data; } ValueRef bound_data() const { return m_bound_data; }
std::string name() const { return m_name; } std::string name() const { return m_name; }
};
class LazyEvalValue final
: public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override { std::string to_string() const override {
return ssprintf( return ssprintf(
"LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str()); "LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str());
} }
void clear() override {}
}; };
/** /**
...@@ -67,6 +62,7 @@ private: ...@@ -67,6 +62,7 @@ private:
std::vector<LazyEvalValue::weak_ref_t> m_weak_vars; std::vector<LazyEvalValue::weak_ref_t> m_weak_vars;
SymbolVar m_io_link = nullptr; SymbolVar m_io_link = nullptr;
std::exception_ptr m_graph_exc; std::exception_ptr m_graph_exc;
ObjectType<LazyEvalValue> m_value_type{"LazyEvalValue"};
public: public:
LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) { LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) {
...@@ -75,7 +71,7 @@ public: ...@@ -75,7 +71,7 @@ public:
LazyEvalValue::ref_t record_var( LazyEvalValue::ref_t record_var(
VarNode* node, ValueRef bound_data = {}, std::string name = {}) { VarNode* node, ValueRef bound_data = {}, std::string name = {}) {
auto lazy_eval_val = LazyEvalValue::make(node, bound_data, name); auto lazy_eval_val = m_value_type.make(node, bound_data, name);
m_weak_vars.push_back(lazy_eval_val); m_weak_vars.push_back(lazy_eval_val);
return lazy_eval_val; return lazy_eval_val;
} }
...@@ -86,7 +82,7 @@ public: ...@@ -86,7 +82,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override; const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override { ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<LazyEvalValue>()); mgb_assert(!value.is(m_value_type));
return value; return value;
} }
......
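record_var above builds wrappers through the transformation's own type object and remembers each one in m_weak_vars, so the transformation can later revisit every still-live wrapper (to bind results, attach names, or propagate errors) without extending their lifetime. A standalone sketch of that weak-reference bookkeeping, using std::weak_ptr in place of the project's weak value refs:

    #include <cstdio>
    #include <memory>
    #include <string>
    #include <vector>

    struct Var {
        std::string name;
    };

    class Recorder {
        std::vector<std::weak_ptr<Var>> m_weak_vars;  // does not keep vars alive

    public:
        std::shared_ptr<Var> record_var(std::string name) {
            auto var = std::make_shared<Var>(Var{std::move(name)});
            m_weak_vars.push_back(var);
            return var;
        }
        void visit_live() const {
            for (auto&& weak : m_weak_vars) {
                if (auto var = weak.lock()) {  // skip wrappers already destroyed
                    std::printf("live: %s\n", var->name.c_str());
                }
            }
        }
    };

    int main() {
        Recorder rec;
        auto kept = rec.record_var("kept");
        rec.record_var("dropped");  // returned ref discarded immediately
        rec.visit_live();           // prints only "live: kept"
    }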
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
namespace mgb::imperative { namespace mgb::imperative {
class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> { class ScalarValue final : public ObjectValue<ScalarValue> {
private: private:
ValueRef m_value; ValueRef m_value;
...@@ -47,17 +47,21 @@ public: ...@@ -47,17 +47,21 @@ public:
class ScalarTransformation final : public Transformation { class ScalarTransformation final : public Transformation {
private: private:
ShapeValue::ref_t m_empty_shape; // [] ShapeValue::ref_t m_empty_shape; // []
ObjectType<ScalarValue> m_value_type{"ScalarValue"};
public: public:
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
ValueRefList apply_transformation( ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override; const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override { ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<ScalarValue>()); mgb_assert(!value.is(m_value_type));
return value; return value;
} }
std::string name() const override { return "ScalarTransformation"; } std::string name() const override { return "ScalarTransformation"; }
const Type<ScalarValue>& value_type() const { return m_value_type; }
}; };
} // namespace mgb::imperative } // namespace mgb::imperative
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace mgb::imperative { namespace mgb::imperative {
class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> { class SymbolValue final : public ObjectValue<SymbolValue> {
private: private:
VarNode* m_node = nullptr; VarNode* m_node = nullptr;
...@@ -47,6 +47,7 @@ public: ...@@ -47,6 +47,7 @@ public:
class SymbolTransformation final : public Transformation { class SymbolTransformation final : public Transformation {
private: private:
ComputingGraph* m_graph = nullptr; ComputingGraph* m_graph = nullptr;
ObjectType<SymbolValue> m_value_type{"SymbolValue"};
public: public:
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {}
...@@ -55,12 +56,12 @@ public: ...@@ -55,12 +56,12 @@ public:
if (auto* apply_op = op.as<ApplyOp>()) { if (auto* apply_op = op.as<ApplyOp>()) {
SmallVector<VarNode*> input_nodes; SmallVector<VarNode*> input_nodes;
for (auto&& input : inputs) { for (auto&& input : inputs) {
input_nodes.push_back(input.cast<SymbolValue>().node()); input_nodes.push_back(input.cast(m_value_type).node());
} }
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes);
ValueRefList outputs(output_nodes.size()); ValueRefList outputs(output_nodes.size());
for (size_t i = 0; i < output_nodes.size(); ++i) { for (size_t i = 0; i < output_nodes.size(); ++i) {
outputs[i] = SymbolValue::make(output_nodes[i]); outputs[i] = m_value_type.make(output_nodes[i]);
} }
return outputs; return outputs;
} else if (auto* create_tensor = op.as<CreateTensor>()) { } else if (auto* create_tensor = op.as<CreateTensor>()) {
...@@ -69,9 +70,9 @@ public: ...@@ -69,9 +70,9 @@ public:
args.kind == CreateTensor::Const, args.kind == CreateTensor::Const,
"only const value is allowed here"); "only const value is allowed here");
auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node();
return {SymbolValue::make(node)}; return {m_value_type.make(node)};
} else if (auto* get_attr = op.as<GetAttr>()) { } else if (auto* get_attr = op.as<GetAttr>()) {
auto* node = inputs.as_array<1>()[0].cast<SymbolValue>().node(); auto* node = inputs.item().cast(m_value_type).node();
switch (get_attr->attr()) { switch (get_attr->attr()) {
case GetAttr::DType: case GetAttr::DType:
return {DTypeValue::make(node->dtype())}; return {DTypeValue::make(node->dtype())};
...@@ -121,11 +122,13 @@ public: ...@@ -121,11 +122,13 @@ public:
} }
ValueRef unwrap(ValueRef value) override { ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<SymbolValue>(), "SymbolValue doesn't support unwrap"); mgb_assert(!value.is(m_value_type), "SymbolValue doesn't support unwrap");
return value; return value;
} }
std::string name() const override { return "SymbolTransformation"; } std::string name() const override { return "SymbolTransformation"; }
const Type<SymbolValue>& value_type() const { return m_value_type; }
}; };
} // namespace mgb::imperative } // namespace mgb::imperative
...@@ -100,22 +100,15 @@ public: ...@@ -100,22 +100,15 @@ public:
} }
}; };
class TracingInfo { class TracingValue final : public ObjectValue<TracingValue> {
private: private:
ValueRef m_value = {}; ValueRef m_value = {};
size_t m_id = 0; size_t m_id = 0;
public: public:
TracingInfo() = default; TracingValue(ValueRef value, size_t id) : m_value(value), m_id(id) {}
TracingInfo(ValueRef value, size_t id) : m_value(value), m_id(id) {}
ValueRef value() const { return m_value; } ValueRef value() const { return m_value; }
size_t id() const { return m_id; } size_t id() const { return m_id; }
};
class TracingValue final
: public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override { std::string to_string() const override {
return ssprintf( return ssprintf(
...@@ -126,6 +119,8 @@ public: ...@@ -126,6 +119,8 @@ public:
void on_watch() override { value().watch(); } void on_watch() override { value().watch(); }
void on_unwatch() override { value().unwatch(); } void on_unwatch() override { value().unwatch(); }
void clear() override { m_value = {}; }
}; };
/** /**
...@@ -146,6 +141,7 @@ private: ...@@ -146,6 +141,7 @@ private:
std::vector<TracingValue::weak_ref_t> m_weak_vars; std::vector<TracingValue::weak_ref_t> m_weak_vars;
bool m_capture_as_const = false; bool m_capture_as_const = false;
bool m_record_input_shapes = false; bool m_record_input_shapes = false;
ObjectType<TracingValue> m_value_type{"TracingValue"};
public: public:
TracingTransformation(bool capture_as_const, bool record_input_shapes) TracingTransformation(bool capture_as_const, bool record_input_shapes)
...@@ -162,7 +158,7 @@ public: ...@@ -162,7 +158,7 @@ public:
*/ */
TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) { TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
size_t id = m_vars.size(); size_t id = m_vars.size();
auto wrapped_value = TracingValue::make(value, id); auto wrapped_value = m_value_type.make(value, id);
m_vars.push_back({id, value.dtype(), value.device()}); m_vars.push_back({id, value.dtype(), value.device()});
auto& var = m_vars.back(); auto& var = m_vars.back();
if (capture) { if (capture) {
...@@ -179,7 +175,7 @@ public: ...@@ -179,7 +175,7 @@ public:
return wrapped_value; return wrapped_value;
} }
ValueRef unwrap_var(ValueRef value) { ValueRef unwrap_var(ValueRef value) {
if (auto* tracing_value = value.as<TracingValue>()) { if (auto* tracing_value = value.as(m_value_type)) {
return tracing_value->value(); return tracing_value->value();
} }
return value; return value;
...@@ -189,7 +185,7 @@ public: ...@@ -189,7 +185,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override; const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override { ValueRef unwrap(ValueRef value) override {
if (auto* tracing_value = value.as<TracingValue>()) { if (auto* tracing_value = value.as(m_value_type)) {
return tracing_value->value(); return tracing_value->value();
} }
return value; return value;
...@@ -234,7 +230,7 @@ public: ...@@ -234,7 +230,7 @@ public:
std::function<void(std::exception_ptr)> exc_setter; std::function<void(std::exception_ptr)> exc_setter;
}; };
class TracedInfo { class TracedValue final : public ObjectValue<TracedValue> {
private: private:
size_t m_id = 0; size_t m_id = 0;
VarInfo* m_var = nullptr; VarInfo* m_var = nullptr;
...@@ -244,8 +240,7 @@ public: ...@@ -244,8 +240,7 @@ public:
mutable CompNodeValue::ref_t m_comp_node; mutable CompNodeValue::ref_t m_comp_node;
public: public:
TracedInfo() = default; TracedValue(size_t id, VarInfo* var, VarAccessor* accessor)
TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor)
: m_id(id), m_var(var), m_accessor(accessor) {} : m_id(id), m_var(var), m_accessor(accessor) {}
size_t id() const { return m_id; } size_t id() const { return m_id; }
ShapeValue::ref_t shape() const; ShapeValue::ref_t shape() const;
...@@ -256,16 +251,12 @@ public: ...@@ -256,16 +251,12 @@ public:
void set_exception(std::exception_ptr exc) const { void set_exception(std::exception_ptr exc) const {
m_accessor->exc_setter(exc); m_accessor->exc_setter(exc);
} }
};
class TracedValue final
: public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override { std::string to_string() const override {
return ssprintf("TracedValue{\"id\"=%zu}", id()); return ssprintf("TracedValue{\"id\"=%zu}", id());
} }
void clear() override {}
}; };
private: private:
...@@ -280,9 +271,12 @@ private: ...@@ -280,9 +271,12 @@ private:
std::function<bool(ValueRef, ValueRef)> m_value_comparator; std::function<bool(ValueRef, ValueRef)> m_value_comparator;
bool m_input_shape_static; bool m_input_shape_static;
std::mutex m_mutex; std::mutex m_mutex;
std::condition_variable m_cv;
std::exception_ptr m_graph_exc; std::exception_ptr m_graph_exc;
int m_graph_status = 0; // 0 = stop, 1 = running, 2 = finalizing
std::vector<std::shared_ptr<BoxBase>> m_boxes; std::vector<std::shared_ptr<BoxBase>> m_boxes;
ComputingGraph::OutputSpec m_output_spec; ComputingGraph::OutputSpec m_output_spec;
ObjectType<TracedValue> m_value_type{"TracedValue"};
public: public:
CompiledTransformation(TraceResult result, bool input_shape_static) CompiledTransformation(TraceResult result, bool input_shape_static)
...@@ -292,6 +286,27 @@ public: ...@@ -292,6 +286,27 @@ public:
m_graph = ComputingGraph::make(); m_graph = ComputingGraph::make();
options().no_force_inplace = true; options().no_force_inplace = true;
options().async_exec_level = 0b100; options().async_exec_level = 0b100;
m_graph_executor = std::thread([&] {
while (true) {
std::unique_lock lock{m_mutex};
m_cv.wait(lock, [&] { return m_graph_status != 0; });
lock.unlock();
if (m_graph_status == 2) {
break;
}
try {
m_executable->execute();
m_executable->wait();
} catch (...) {
auto exc = std::current_exception();
set_exception(exc);
}
lock.lock();
m_graph_status = 0;
lock.unlock();
m_cv.notify_all();
}
});
} }
ComputingGraph& graph() { return *m_graph; } ComputingGraph& graph() { return *m_graph; }
...@@ -350,7 +365,7 @@ public: ...@@ -350,7 +365,7 @@ public:
void on_unregister() noexcept override; void on_unregister() noexcept override;
ValueRef unwrap(ValueRef value) override { ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<TracedValue>()); mgb_assert(!value.is(m_value_type));
return value; return value;
} }
...@@ -368,6 +383,15 @@ public: ...@@ -368,6 +383,15 @@ public:
m_boxes.push_back(box); m_boxes.push_back(box);
return box; return box;
} }
~CompiledTransformation() {
{
MGB_LOCK_GUARD(m_mutex);
m_graph_status = 2;
}
m_cv.notify_all();
m_graph_executor.join();
}
}; };
} // namespace mgb::imperative } // namespace mgb::imperative
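The constructor and destructor added above replace the per-execute() thread of the old implementation with one long-lived worker: execute() sets m_graph_status to 1 and notifies, the worker runs the executable and sets it back to 0, wait() blocks until it reads 0, and the destructor uses status 2 as the shutdown signal before joining. A standalone sketch of that status-flag handshake (simplified, without the exception forwarding), assuming wait() is always called before destruction as in the methods above:

    #include <condition_variable>
    #include <cstdio>
    #include <mutex>
    #include <thread>

    // 0 = idle, 1 = run requested, 2 = shut down
    class Worker {
        std::mutex m_mutex;
        std::condition_variable m_cv;
        int m_status = 0;
        std::thread m_thread;

    public:
        Worker() {
            m_thread = std::thread([this] {
                while (true) {
                    std::unique_lock lock{m_mutex};
                    m_cv.wait(lock, [&] { return m_status != 0; });
                    if (m_status == 2) {
                        break;  // shutdown requested
                    }
                    lock.unlock();
                    std::puts("executing graph");  // the real code runs m_executable here
                    lock.lock();
                    m_status = 0;  // report completion
                    lock.unlock();
                    m_cv.notify_all();
                }
            });
        }
        void execute() {
            {
                std::lock_guard lock{m_mutex};
                m_status = 1;
            }
            m_cv.notify_all();
        }
        void wait() {
            std::unique_lock lock{m_mutex};
            m_cv.wait(lock, [&] { return m_status == 0; });
        }
        ~Worker() {
            {
                std::lock_guard lock{m_mutex};
                m_status = 2;
            }
            m_cv.notify_all();
            m_thread.join();
        }
    };

    int main() {
        Worker worker;
        worker.execute();
        worker.wait();  // returns once the worker has finished one run
    }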
...@@ -11,7 +11,9 @@ ...@@ -11,7 +11,9 @@
#pragma once #pragma once
#include <optional>
#include <typeindex> #include <typeindex>
#include <vector>
#include "megbrain/utils/mempool.h" #include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
......
...@@ -34,7 +34,7 @@ public: ...@@ -34,7 +34,7 @@ public:
Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {}
Span(const T* begin, size_t size) : Span(begin, begin + size) {} Span(const T* begin, size_t size) : Span(begin, begin + size) {}
template <typename TContainer> template <typename TContainer>
Span(TContainer& container) : Span(container.data(), container.size()) {} Span(const TContainer& container) : Span(container.data(), container.size()) {}
const T* begin() const { return m_begin; } const T* begin() const { return m_begin; }
const T* end() const { return m_end; } const T* end() const { return m_end; }
const T* data() const { return m_begin; } const T* data() const { return m_begin; }
......
...@@ -2,7 +2,10 @@ ...@@ -2,7 +2,10 @@
#include <chrono> #include <chrono>
#include <iostream> #include <iostream>
#include <map>
#include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
namespace mgb { namespace mgb {
...@@ -18,7 +21,7 @@ public: ...@@ -18,7 +21,7 @@ public:
private: private:
clock_t::duration m_duration = clock_t::duration{0}; clock_t::duration m_duration = clock_t::duration{0};
size_t m_timing = 0; size_t m_timing = 0;
const char* m_name = nullptr; std::string m_name;
uint64_t m_count = 0; uint64_t m_count = 0;
size_t m_enabled = 1; size_t m_enabled = 1;
bool m_default_enabled = true; bool m_default_enabled = true;
...@@ -42,7 +45,8 @@ private: ...@@ -42,7 +45,8 @@ private:
} }
if (timer.m_enabled) { if (timer.m_enabled) {
if (!--timer.m_timing) { if (!--timer.m_timing) {
timer.m_duration += (clock_t::now() - start); auto duration = (clock_t::now() - start);
timer.m_duration += duration;
} }
timer.m_count++; timer.m_count++;
} }
...@@ -67,13 +71,10 @@ private: ...@@ -67,13 +71,10 @@ private:
} }
}; };
using TimeScope = TimeScopeRecursive;
public: public:
Timer(const char* name, bool default_enabled); Timer(std::string name, bool default_enabled = true);
const char* name() { return m_name; } std::string name() { return m_name; }
auto time_scope() { return TimeScope(*this); }
auto time_scope_recursive() { return TimeScopeRecursive(*this); }; auto time_scope_recursive() { return TimeScopeRecursive(*this); };
auto enable_scope() { return EnableScope(*this); } auto enable_scope() { return EnableScope(*this); }
void reset() { void reset() {
...@@ -88,7 +89,14 @@ public: ...@@ -88,7 +89,14 @@ public:
} // namespace stats } // namespace stats
struct Stats { struct Stats {
static inline std::vector<stats::Timer*> sm_timers; struct TimerNode {
std::map<std::string, std::unique_ptr<TimerNode>> children;
stats::Timer* timer = nullptr;
TimerNode() {}
};
static inline TimerNode sm_root;
// register your timers here // register your timers here
// for example: // for example:
...@@ -97,33 +105,84 @@ struct Stats { ...@@ -97,33 +105,84 @@ struct Stats {
// //
// then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code
static void print() { static std::pair<long, long> print_node(
std::vector<const char*> unused_timers; std::string name, TimerNode& node, size_t indent = 0) {
auto print_indent = [&] {
for (auto* timer : sm_timers) { for (size_t i = 0; i < indent; ++i) {
if (timer->count() == 0) { printf(" ");
unused_timers.push_back(timer->name()); }
} else { };
printf("%s costs %ld ns, happens %ld times\n", timer->name(), long ns = 0, count = 0;
timer->get().count(), timer->count()); if (auto* timer = node.timer) {
print_indent();
printf("%s costs %'ld ns, hits %'ld times\n", name.c_str(),
(long)timer->get().count(), (long)timer->count());
ns = timer->get().count();
count = timer->count();
}
if (!node.children.empty()) {
bool collect_children = node.timer == nullptr;
if (collect_children) {
print_indent();
printf("%s:\n", name.c_str());
}
long ns = 0, count = 0;
for (auto&& child : node.children) {
auto [child_ns, child_count] =
print_node(child.first, *child.second, indent + 4);
if (collect_children) {
ns += child_ns;
count += child_count;
}
}
if (collect_children) {
print_indent();
printf("total costs %'ld ns, hits %'ld times\n", ns, count);
} }
} }
return {ns, count};
}
if (!unused_timers.empty()) { static void print() {
printf("%zu timers unused\n", unused_timers.size()); for (auto&& child : sm_root.children) {
print_node(child.first, *child.second);
} }
} }
static void reset() { static void reset() {
for (auto* timer : sm_timers) { auto reset_node = [](TimerNode& node, auto&& reset_node) -> void {
timer->reset(); if (auto* timer = node.timer) {
} timer->reset();
}
for (auto&& child : node.children) {
reset_node(*child.second, reset_node);
}
};
reset_node(sm_root, reset_node);
} }
}; };
inline stats::Timer::Timer(const char* name, bool default_enabled) inline stats::Timer::Timer(std::string name, bool default_enabled)
: m_name(name), m_default_enabled(default_enabled) { : m_name(name), m_default_enabled(default_enabled) {
Stats::sm_timers.push_back(this); std::vector<std::string> terms;
Stats::TimerNode* node = &Stats::sm_root;
while (true) {
auto pos = name.find(".");
if (pos == std::string::npos) {
auto& child = node->children[name];
child = std::make_unique<Stats::TimerNode>();
node = child.get();
node->timer = this;
break;
} else {
auto& child = node->children[name.substr(0, pos)];
if (!child) {
child = std::make_unique<Stats::TimerNode>();
}
node = child.get();
name = name.substr(pos + 1);
}
}
} }
#if MGE_ENABLE_STATS #if MGE_ENABLE_STATS
......
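The rewritten `Timer` constructor splits its dotted name (e.g. "grad.apply") and registers the timer into a tree of `TimerNode`s, which is what lets `Stats::print()` aggregate children under a common prefix. A hedged, self-contained sketch of that registration walk; `TimerNode` and `register_timer` below are simplified stand-ins, not the real `Stats` members:

```cpp
// Simplified stand-ins illustrating the dotted-name walk in the new Timer constructor.
#include <cstdio>
#include <map>
#include <memory>
#include <string>

struct TimerNode {
    std::map<std::string, std::unique_ptr<TimerNode>> children;
    bool has_timer = false;  // stands in for the `stats::Timer* timer` leaf pointer
};

void register_timer(TimerNode& root, std::string name) {
    TimerNode* node = &root;
    while (true) {
        auto pos = name.find('.');
        if (pos == std::string::npos) {
            auto& child = node->children[name];  // final segment: the timer lives here
            if (!child)
                child = std::make_unique<TimerNode>();
            child->has_timer = true;
            return;
        }
        auto& child = node->children[name.substr(0, pos)];  // intermediate segment
        if (!child)
            child = std::make_unique<TimerNode>();
        node = child.get();
        name = name.substr(pos + 1);  // keep parsing the remaining suffix
    }
}

int main() {
    TimerNode root;
    register_timer(root, "grad.apply");     // creates "grad" -> "apply"
    register_timer(root, "grad.backward");  // reuses "grad", adds "backward"
    std::printf("grad has %zu children\n", root.children["grad"]->children.size());
}
```

Intermediate nodes own no timer of their own, which is exactly the `node.timer == nullptr` case that `print_node` uses to sum its children into a "total" line.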
...@@ -50,18 +50,70 @@ class Operator; ...@@ -50,18 +50,70 @@ class Operator;
class ValueRefList; class ValueRefList;
/**
* \brief base class of all value types
*/
class IType : public NonCopyableObj {
private:
std::string m_name;
// TODO: count values, or make a linked list
public:
IType(std::string name) : m_name(std::move(name)) {}
const std::string& name() const { return m_name; }
bool operator==(const IType& rhs) const { return this == &rhs; }
bool operator!=(const IType& rhs) const { return this != &rhs; }
};
/**
* \brief type of values.
*
* \tparam T ctype of value
*/
template <typename T>
class Type : public IType {
protected:
Type(std::string name) : IType(std::move(name)) {}
// TODO: each type owns an allocator
public:
/**
* \brief helper function to construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
TypedValueRef<T> make(TArgs&&... args) const;
};
/**
* \brief type of primitive values.
*
* \tparam T ctype of value
*/
template <typename T> template <typename T>
class Type { class PrimitiveType : public Type<T> {
private: private:
const size_t m_code = T::TYPE_CODE; PrimitiveType();
public: public:
inline size_t code() const { return m_code; } static inline PrimitiveType instance;
}; };
enum class ValueKind { /**
Primitive, * \brief type of object values.
Object, *
* \tparam T ctype of value
*/
template <typename T>
class ObjectType : public Type<T> {
public:
ObjectType(std::string name) : Type<T>(name) {}
}; };
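This block is the core of the change: a type is now an `IType` object carrying a runtime name and compared by identity, so new types (`ObjectType<T>` instances) can be created dynamically instead of being limited to a static typecode registry. A simplified, self-contained sketch of the idea; the classes below are stand-ins for the real headers and `GradValue` is an illustrative payload type:

```cpp
// Stand-in classes illustrating identity-compared, dynamically created type objects.
#include <cstdio>
#include <string>

class IType {
public:
    explicit IType(std::string name) : m_name(std::move(name)) {}
    const std::string& name() const { return m_name; }
    // identity comparison: two type objects are equal only if they are the same object
    bool operator==(const IType& rhs) const { return this == &rhs; }

private:
    std::string m_name;
};

template <typename T>
class ObjectType : public IType {
public:
    using IType::IType;
};

struct GradValue {};  // illustrative payload; the type object itself does not depend on it

int main() {
    // Both type objects wrap the same C++ payload type, yet they are distinct types,
    // because equality is decided by object identity rather than by a typecode.
    ObjectType<GradValue> grad_type("Grad");
    ObjectType<GradValue> another_grad_type("Grad");
    std::printf("'%s' == '%s' ? %d\n", grad_type.name().c_str(),
                another_grad_type.name().c_str(), int(grad_type == another_grad_type));
}
```

Because every `ObjectType` instance is a distinct type, a transformation can mint fresh types at runtime without touching a global registry.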
/** /**
...@@ -71,9 +123,8 @@ enum class ValueKind { ...@@ -71,9 +123,8 @@ enum class ValueKind {
* and only the tail node is valid. ValueRef stores a value node, and it may be * and only the tail node is valid. ValueRef stores a value node, and it may be
* an invalid internal node. When you dereference it, it will check its successor, * an invalid internal node. When you dereference it, it will check its successor,
* automatically find the tail node and return. This list would be modified to reduce * automatically find the tail node and return. This list would be modified to reduce
* list length by change value's successor, but a ValueRef always has steady m_storage * list length by changing a value's successor, but a steady id is kept in the ValueRef
* when not explicitly modified. * so we can use it to identify a ValueRef (hash / equality / id).
* So we use m_storage to identify a ValueRef ( hash / equility / id ).
*/ */
class ValueRef { class ValueRef {
public: public:
...@@ -93,9 +144,7 @@ private: ...@@ -93,9 +144,7 @@ private:
*/ */
storage_t& storage() const; storage_t& storage() const;
const Value* as(size_t typecode) const; const Value* as(const IType& type) const;
bool is(size_t typecode) const;
public: public:
ValueRef() = default; ValueRef() = default;
...@@ -103,45 +152,76 @@ public: ...@@ -103,45 +152,76 @@ public:
/** /**
* \brief whether value is instance of target type or not * \brief whether value is instance of target type or not
* *
* \tparam TValue target type * \param type target type
* \return true if type of value is TValue * \return true if type of value is an instance of type
* \return false if empty or type of value is not TValue * \return false if empty or type of value is not an instance of type
*/ */
template <typename TValue> bool is(const IType& type) const;
inline bool is(Type<TValue> type = {}) const;
/** /**
* \brief try cast value as target type * \brief try cast value as target type
* *
* \tparam TValue target type * \param type target type
* \return TValue* raw pointer if success, otherwise nullptr * \return TValue* raw pointer if success, otherwise nullptr
*/ */
template <typename TValue> template <typename TValue>
inline const TValue* as(Type<TValue> type = {}) const; inline const TValue* as(const Type<TValue>& type) const;
/** /**
* \brief cast value to target type * \brief cast value to target type
* *
* \tparam TValue target type * \param type target type
* \return TValue& reference of value * \return TValue& reference of value
*/ */
template <typename TValue> template <typename TValue>
inline const TValue& cast(Type<TValue> type = {}) const; inline const TValue& cast(const Type<TValue>& type) const;
/** /**
* \brief like as(), but returns TypedValueRef instead * \brief like as(), but returns TypedValueRef instead
* *
* \tparam TValue target type * \param type target type
* \return TypedValueRef<TValue> reference if success, otherwise empty reference * \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/ */
template <typename TValue> template <typename TValue>
inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const; inline const TypedValueRef<TValue>& as_ref(const Type<TValue>& type) const;
/**
* \brief like cast(), but allow empty value and returns TypedValueRef instead
*
* \param type target type
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template <typename TValue>
inline const TypedValueRef<TValue>& cast_ref(const Type<TValue>& type) const;
template <typename TValue>
inline std::enable_if_t<TValue::is_primitive, bool> is() const {
return is(PrimitiveType<TValue>::instance);
}
template <typename TValue>
inline std::enable_if_t<TValue::is_primitive, const TValue*> as() const {
return as(PrimitiveType<TValue>::instance);
}
template <typename TValue>
inline std::enable_if_t<TValue::is_primitive, const TValue&> cast() const {
return cast(PrimitiveType<TValue>::instance);
}
template <typename TValue> template <typename TValue>
inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const; inline std::enable_if_t<TValue::is_primitive, const TypedValueRef<TValue>&> as_ref()
const {
return as_ref(PrimitiveType<TValue>::instance);
}
template <typename TValue> template <typename TValue>
void on_cast_failure() const; inline std::enable_if_t<TValue::is_primitive, const TypedValueRef<TValue>&>
cast_ref() const {
return cast_ref(PrimitiveType<TValue>::instance);
}
void on_cast_failure(const IType& type) const;
operator bool() const { return bool(m_storage); } operator bool() const { return bool(m_storage); }
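Only primitive value types keep the parameterless helpers (`is<T>()`, `as<T>()`, `cast<T>()`, `as_ref<T>()`, `cast_ref<T>()`); they route to the `PrimitiveType<TValue>::instance` singleton, while object values must pass a runtime type instance. The `std::enable_if_t` gating can be illustrated with a small standalone sketch (everything below is a simplified stand-in, not the real `ValueRef`):

```cpp
// Stand-in showing how enable_if restricts the parameterless helpers to primitive values.
#include <cstdio>
#include <type_traits>

struct SomePrimitiveValue { static constexpr bool is_primitive = true;  };
struct SomeObjectValue    { static constexpr bool is_primitive = false; };

struct FakeValueRef {
    template <typename TValue>
    std::enable_if_t<TValue::is_primitive, bool> is() const {
        // the real code would forward to is(PrimitiveType<TValue>::instance)
        return true;
    }
};

int main() {
    FakeValueRef ref;
    std::printf("%d\n", int(ref.is<SomePrimitiveValue>()));  // fine: primitive type
    // ref.is<SomeObjectValue>();  // would not compile: enable_if removes this overload;
    //                             // object values need ref.is(some_type_instance) instead
}
```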
...@@ -172,8 +252,6 @@ public: ...@@ -172,8 +252,6 @@ public:
friend class ValueWeakRef; friend class ValueWeakRef;
template <typename> template <typename>
friend class TypedValueRef; friend class TypedValueRef;
template <typename, ValueKind>
friend class ValueImpl;
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs); friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
}; };
...@@ -195,7 +273,8 @@ protected: ...@@ -195,7 +273,8 @@ protected:
public: public:
ValueWeakRef() = default; ValueWeakRef() = default;
ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {} ValueWeakRef(const ValueRef& value)
: m_id(value.id()), m_storage(value.m_storage) {}
/** /**
* \brief try promote to ValueRef * \brief try promote to ValueRef
...@@ -218,19 +297,15 @@ public: ...@@ -218,19 +297,15 @@ public:
class Value : public NonCopyableObj { class Value : public NonCopyableObj {
private: private:
uint64_t m_id = std::numeric_limits<uint64_t>::max(); uint64_t m_id = std::numeric_limits<uint64_t>::max();
size_t m_typecode = 0; const IType* m_type = nullptr;
ValueRef m_successor; ValueRef m_successor;
size_t m_watching = 0; size_t m_watching = 0;
protected: protected:
Value(size_t typecode); Value();
public: public:
size_t typecode() const { return m_typecode; } const IType& type() const { return *m_type; }
const std::type_index type() const { return registered_types()[m_typecode]; }
static size_t register_type(std::type_index type);
static const std::vector<std::type_index>& registered_types();
static void register_value(ValueRef value); static void register_value(ValueRef value);
static ValueRef get_value_by_id(uint64_t id); static ValueRef get_value_by_id(uint64_t id);
...@@ -251,11 +326,12 @@ public: ...@@ -251,11 +326,12 @@ public:
friend class ValueRef; friend class ValueRef;
friend class ValueWeakRef; friend class ValueWeakRef;
template <typename, ValueKind>
friend class ValueImpl;
template <typename T> template <typename T>
friend class TypedValueRef; friend class TypedValueRef;
template <typename T>
friend class Type;
~Value(); ~Value();
private: private:
...@@ -267,30 +343,17 @@ private: ...@@ -267,30 +343,17 @@ private:
* *
* \tparam T type of value * \tparam T type of value
*/ */
template <typename T, ValueKind Kind> template <typename T>
class ValueImpl : public Value { class ObjectValue : public Value {
protected: protected:
ValueImpl() : Value(TYPE_CODE) {} ObjectValue() {}
public: public:
using ref_t = TypedValueRef<T>; using ref_t = TypedValueRef<T>;
using weak_ref_t = TypedValueWeakRef<T>; using weak_ref_t = TypedValueWeakRef<T>;
static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); static constexpr bool is_primitive = false;
static constexpr ValueKind KIND = Kind; static constexpr bool is_object = true;
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) {
static_assert(std::is_final_v<T>);
return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...));
}
}; };
/** /**
...@@ -299,74 +362,89 @@ public: ...@@ -299,74 +362,89 @@ public:
* \tparam T type of value * \tparam T type of value
* \tparam TMixin type of mixin class * \tparam TMixin type of mixin class
*/ */
template <typename T, ValueKind Kind, typename TMixin> template <typename T, typename TMixin>
class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin { class PrimitiveValue : public Value, public TMixin {
public: public:
using ref_t = TypedValueRef<T>;
using weak_ref_t = TypedValueWeakRef<T>;
using TMixin::TMixin; using TMixin::TMixin;
MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {} PrimitiveValue(TMixin&& mixin) : TMixin(std::move(mixin)) {}
PrimitiveValue(const TMixin& mixin) : TMixin(mixin) {}
public: public:
void clear() override final { ((TMixin&)*this) = {}; } void clear() override final { ((TMixin&)*this) = {}; }
bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; }
/**
* \brief helper function to construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
static TypedValueRef<T> make(TArgs&&... args) {
return PrimitiveType<T>::instance.make(std::forward<TArgs&&>(args)...);
}
static constexpr bool is_primitive = true;
static constexpr bool is_object = false;
}; };
template <typename T>
PrimitiveType<T>::PrimitiveType() : Type<T>(typeid(T).name()) {
static_assert(std::is_base_of_v<Value, T>);
static_assert(!std::is_base_of_v<ObjectValue<T>, T>);
}
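`PrimitiveValue<T, TMixin>` makes a value type out of an ordinary payload struct by inheriting from it, which is why `clear()` and `eq()` can simply treat `*this` as the mixin. A hedged sketch of the pattern with an invented `ShapeMixin` payload (no `Value` base, `LocalPtr`, or type registration here, so this shows only the mixin mechanics):

```cpp
// Hedged sketch of the mixin pattern behind PrimitiveValue; names are illustrative.
#include <cstdio>
#include <string>
#include <utility>

struct ShapeMixin {
    std::string repr;
    bool operator==(const ShapeMixin& rhs) const { return repr == rhs.repr; }
};

template <typename T, typename TMixin>
struct PrimitiveValue : TMixin {
    using TMixin::TMixin;
    PrimitiveValue(TMixin&& mixin) : TMixin(std::move(mixin)) {}
    PrimitiveValue(const TMixin& mixin) : TMixin(mixin) {}
    void clear() { static_cast<TMixin&>(*this) = {}; }  // reset the payload only
    bool eq(const TMixin& value) const { return static_cast<const TMixin&>(*this) == value; }
    static constexpr bool is_primitive = true;
};

struct ShapeValue final : PrimitiveValue<ShapeValue, ShapeMixin> {
    using PrimitiveValue::PrimitiveValue;
};

int main() {
    ShapeValue a{ShapeMixin{"(2, 3)"}};
    std::printf("eq: %d\n", int(a.eq(ShapeMixin{"(2, 3)"})));
    a.clear();
    std::printf("after clear: '%s'\n", a.repr.c_str());
}
```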
inline ValueRef::ValueRef(storage_t storage) { inline ValueRef::ValueRef(storage_t storage) {
// mgb_assert(storage);
m_storage = storage; m_storage = storage;
m_id = m_storage->m_id; m_id = m_storage->m_id;
} }
template <typename TValue> template <typename TValue>
inline const TValue* ValueRef::as(Type<TValue> type) const { inline const TValue* ValueRef::as(const Type<TValue>& type) const {
// auto _ = Stats::time_value_as.time_scope();
static_assert(std::is_base_of_v<Value, TValue>); static_assert(std::is_base_of_v<Value, TValue>);
return static_cast<const TValue*>(as(type.code())); return static_cast<const TValue*>(as((const IType&)type));
} }
template <typename TValue> template <typename TValue>
inline const TValue& ValueRef::cast(Type<TValue> type) const { inline const TValue& ValueRef::cast(const Type<TValue>& type) const {
// auto _ = Stats::time_value_cast.time_scope();
auto* ptr = as<TValue>(type); auto* ptr = as<TValue>(type);
if (mgb_unlikely(!ptr)) { if (mgb_unlikely(!ptr)) {
on_cast_failure<TValue>(); on_cast_failure(type);
} }
return static_cast<const TValue&>(*ptr); return static_cast<const TValue&>(*ptr);
} }
template <typename TValue> template <typename TValue>
inline bool ValueRef::is(Type<TValue> type) const { inline const TypedValueRef<TValue>& ValueRef::as_ref(const Type<TValue>& type) const {
// auto _ = Stats::time_value_is.time_scope(); if (!is(type)) {
return is(type.code());
}
template <typename TValue>
inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const {
if (!is<TValue>(type)) {
return TypedValueRef<TValue>::nil; return TypedValueRef<TValue>::nil;
} }
return *reinterpret_cast<const TypedValueRef<TValue>*>(this); return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
} }
template <typename TValue> template <typename TValue>
inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const { inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type) const {
if (!m_storage) { if (!m_storage) {
return TypedValueRef<TValue>::nil; return TypedValueRef<TValue>::nil;
} }
if (mgb_unlikely(!is<TValue>(type))) { if (mgb_unlikely(!is(type))) {
on_cast_failure<TValue>(); on_cast_failure(type);
} }
return *reinterpret_cast<const TypedValueRef<TValue>*>(this); return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
} }
template <typename TValue> inline void ValueRef::on_cast_failure(const IType& type) const {
void ValueRef::on_cast_failure() const {
// if this is ErrorValue, rethrow directly // if this is ErrorValue, rethrow directly
storage()->try_rethrow(); storage()->try_rethrow();
mgb_assert( mgb_assert(
storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s", storage()->type() != type, "expect type %s, got %s", type.name().c_str(),
typeid(TValue).name(), to_string().c_str()); to_string().c_str());
} }
/** /**
...@@ -382,26 +460,10 @@ private: ...@@ -382,26 +460,10 @@ private:
public: public:
TypedValueRef() = default; TypedValueRef() = default;
const T& operator*() const { const T& operator*() const {
if constexpr (T::KIND == ValueKind::Object) { mgb_assert(m_storage, "empty storage");
return this->template cast<T>(); return static_cast<const T&>(*m_storage);
} else if constexpr (T::KIND == ValueKind::Primitive) {
if (!m_storage) {
on_cast_failure<T>();
}
return static_cast<const T&>(*m_storage);
} else {
static_assert(!std::is_same_v<T, T>);
}
}
const T* operator->() const {
if constexpr (T::KIND == ValueKind::Object) {
return this->template as<T>();
} else if constexpr (T::KIND == ValueKind::Primitive) {
return static_cast<const T*>(m_storage.get());
} else {
static_assert(!std::is_same_v<T, T>);
}
} }
const T* operator->() const { return static_cast<const T*>(m_storage.get()); }
/** /**
* \brief reset underlying value to another value * \brief reset underlying value to another value
...@@ -409,7 +471,7 @@ public: ...@@ -409,7 +471,7 @@ public:
* \param successor new value * \param successor new value
*/ */
inline void reset(ValueRef successor) { inline void reset(ValueRef successor) {
static_assert(T::KIND == ValueKind::Object); static_assert(std::is_base_of_v<ObjectValue<T>, T>);
mgb_assert(m_storage); mgb_assert(m_storage);
mgb_assert(!m_storage->m_successor); mgb_assert(!m_storage->m_successor);
if (m_storage->m_watching) { if (m_storage->m_watching) {
...@@ -422,25 +484,19 @@ public: ...@@ -422,25 +484,19 @@ public:
static inline const TypedValueRef nil; static inline const TypedValueRef nil;
friend class ValueRef; friend class ValueRef;
friend class Type<T>;
template <typename, ValueKind> friend class TypedValueWeakRef<T>;
friend class ValueImpl;
}; };
template <typename T> template <typename T>
class TypedValueWeakRef : public ValueWeakRef { class TypedValueWeakRef : public ValueWeakRef {
private: private:
TypedValueWeakRef(const ValueRef& value) : ValueWeakRef(value) {}
TypedValueWeakRef(const ValueWeakRef& value) : ValueWeakRef(value) {}
public: public:
TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} TypedValueWeakRef(const TypedValueRef<T>& value) : ValueWeakRef(value) {}
TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} TypedValueRef<T> lock() { return (TypedValueRef<T>)ValueWeakRef::lock(); }
TypedValueRef<T> lock() {
auto value = ValueWeakRef::lock();
if (value) {
return value.template as_ref<T>();
} else {
return {};
}
}
}; };
// TODO: add proxy value type, which is meant to be reset in the end // TODO: add proxy value type, which is meant to be reset in the end
...@@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s ...@@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s
m_data[0] = std::move(item); m_data[0] = std::move(item);
} }
/*class ValueRefList : public SmallVector<ValueRef, 1> { template <typename T>
public: template <typename... TArgs>
using SmallVector::SmallVector; TypedValueRef<T> Type<T>::make(TArgs&&... args) const {
};*/ static_assert(std::is_final_v<T>);
auto storage = LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...);
storage->m_type = this;
return ValueRef::make(std::move(storage));
}
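`Type<T>::make()` ties the two halves together: the type instance allocates the value, stamps itself into the value's type pointer, and wraps the storage in a reference. A reduced sketch of that flow, with `std::shared_ptr` standing in for `LocalPtr` and heavily simplified classes (not the real API):

```cpp
// Reduced sketch of the creation flow: the (possibly dynamic) type object stamps
// itself onto every value it constructs. shared_ptr stands in for LocalPtr.
#include <cstdio>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>

class IType {
public:
    explicit IType(std::string name) : m_name(std::move(name)) {}
    const std::string& name() const { return m_name; }

private:
    std::string m_name;
};

struct Value {
    const IType* type = nullptr;  // stands in for Value::m_type
};

template <typename T>
class Type : public IType {
public:
    using IType::IType;
    template <typename... TArgs>
    std::shared_ptr<T> make(TArgs&&... args) const {
        static_assert(std::is_final_v<T>, "values are constructed only for final types");
        auto storage = std::make_shared<T>(std::forward<TArgs>(args)...);
        storage->type = this;  // the value remembers which type instance created it
        return storage;        // the real code wraps the storage in a ValueRef instead
    }
};

struct ScalarValue final : Value {
    explicit ScalarValue(int v) : payload(v) {}
    int payload;
};

int main() {
    Type<ScalarValue> scalar_type("Scalar");
    auto value = scalar_type.make(42);
    std::printf("%s: %d\n", value->type->name().c_str(), value->payload);
}
```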
} // namespace imperative } // namespace imperative
} // namespace mgb } // namespace mgb
......
...@@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) { ...@@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) {
} }
} }
inputs.clear(); inputs.clear();
auto input_grads = result.graph.apply( auto input_grads = result.graph.apply<TensorPtr>(
backward_graph_inputs, apply_shared_on_physical_tensor, backward_graph_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }); [&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size()); mgb_assert(input_grads.size() == input_has_grad.size());
...@@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) { ...@@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
} }
} }
inputs.clear(); inputs.clear();
auto input_grads = result.graph.apply( auto input_grads = result.graph.apply<TensorPtr>(
backward_graph_inputs, apply_shared_on_physical_tensor, backward_graph_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }); [&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size()); mgb_assert(input_grads.size() == input_has_grad.size());
...@@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { ...@@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads = expand_grads( auto grads = expand_grads(
bg.output_mask, bg.output_mask,
bg.graph.apply( bg.graph.apply<TensorPtr>(
backward_graph_inputs, apply_shared_on_physical_tensor, backward_graph_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; })); [&](auto&& x) { return x; }));
auto precomp = obg.precomp.apply( auto precomp = obg.precomp.apply<TensorPtr>(
SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, apply_shared_on_physical_tensor, SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }); [&](auto&& x) { return x; });
ASSERT_EQ(precomp.size(), 2); ASSERT_EQ(precomp.size(), 2);
...@@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { ...@@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads( auto grads2 = expand_grads(
obg.input_has_grad, obg.input_has_grad,
obg.backward.apply( obg.backward.apply<TensorPtr>(
backward_inputs, apply_shared_on_physical_tensor, backward_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; })); [&](auto&& x) { return x; }));
......