提交 177001d5 编写于 作者: M Megvii Engine Team

refactor(dispatch): allow dynamic type creation

GitOrigin-RevId: 27dde05cff7e1e0bf61c652a3ae1fe4def829ada
上级 150a6a61
......@@ -19,6 +19,7 @@
#include "range/v3/all.hpp"
#include "./helper.h"
#include "./transformation.h"
namespace py = pybind11;
......@@ -30,9 +31,7 @@ namespace {
std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map;
}
GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) {
grad_key_map[m_key] = this;
}
GradKeyWrapper::GradKeyWrapper() {}
void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) {
if (nargs != 2) {
......@@ -77,8 +76,8 @@ pybind11::function GradKeyWrapper::get_backward_closure(
for (auto&& tensor : tensors) {
args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data());
}
auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0]
.as<FunctionValue>();
auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0];
auto closure = closure_value.as_ref<FunctionValue>();
auto py_function = [closure](std::vector<TensorWrapper*> tensors) {
std::vector<ValueRef> args;
for (auto* tw : tensors) {
......@@ -90,11 +89,14 @@ pybind11::function GradKeyWrapper::get_backward_closure(
}
PyObject* GradKeyWrapper::get_name() {
return py::cast(m_key->name()).release().ptr();
return py::cast(m_name).release().ptr();
}
void GradKeyWrapper::set_name(py::handle name) {
m_key->name(py::cast<std::string>(name));
m_name = py::cast<std::string>(name);
if (m_key) {
m_key->name(m_name);
}
}
PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
......@@ -115,7 +117,10 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) {
}
void GradKeyWrapper::enter() {
m_transformation = std::make_shared<GradTransformation>(m_key);
m_transformation = std::make_shared<GradTransformation>();
m_key = m_transformation->key();
m_key->name(m_name);
grad_key_map[m_key] = this;
TransformationManager::get_instance().register_at<TransformationManager::Grad>(
m_transformation);
}
......@@ -123,6 +128,8 @@ void GradKeyWrapper::enter() {
void GradKeyWrapper::exit() {
TransformationManager::get_instance().unregister<TransformationManager::Grad>(
m_transformation);
grad_key_map.erase(m_key);
m_key = {};
m_transformation.reset();
}
......@@ -138,8 +145,6 @@ GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) {
return grad_key_map.at(key);
}
GradKeyWrapper::~GradKeyWrapper() {
grad_key_map.erase(m_key);
}
GradKeyWrapper::~GradKeyWrapper() {}
} // namespace mgb::imperative::python
......@@ -26,6 +26,7 @@ struct GradKeyWrapper : NonCopyableObj {
using wrap_t = pyext17::wrap<GradKeyWrapper>;
static constexpr auto tp_name = pybind11::detail::_("GradKey");
std::string m_name;
std::shared_ptr<GradKey> m_key;
std::shared_ptr<GradTransformation> m_transformation;
......
......@@ -117,7 +117,7 @@ std::optional<ValueRefList> elemwise_grad_rule(
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(2);
SmallVector<ValueRef> ret(2);
if (!grad) {
return ret;
}
......@@ -147,7 +147,7 @@ std::optional<ValueRefList> reshape_grad_rule(
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(2);
SmallVector<ValueRef> ret(2);
if (!grad) {
return ret;
}
......@@ -180,7 +180,7 @@ std::optional<ValueRefList> subtensor_grad_rule(
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(1);
SmallVector<ValueRef> ret(1);
if (grad && inputs[0]) {
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
......@@ -215,7 +215,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule(
grad_op_ = std::move(grad_op)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(1);
SmallVector<ValueRef> ret(1);
if (grad && inputs[0]) {
ValueRefList args_(inputs.size() + 1);
auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype());
......@@ -251,7 +251,7 @@ std::optional<ValueRefList> reduce_grad_rule(
maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(1);
SmallVector<ValueRef> ret(1);
if (grad && shapes[0]) {
ret[0] = broadcast_to(grad, shapes[0]);
}
......@@ -274,7 +274,7 @@ std::optional<ValueRefList> addAxis_grad_rule(
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(1);
SmallVector<ValueRef> ret(1);
if (grad && flag_) {
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
......@@ -297,7 +297,7 @@ std::optional<ValueRefList> removeAxis_grad_rule(
maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(1);
SmallVector<ValueRef> ret(1);
if (grad && flag_) {
ret[0] = imperative::apply(*grad_op_, grad)[0];
}
......@@ -316,7 +316,7 @@ std::optional<ValueRefList> fastpathcopy_grad_rule(
maker.backward([](Span<ValueRef> grads) {
mgb_assert(grads.size() == 1);
ValueRef grad = grads[0];
ValueRefList ret(1);
SmallVector<ValueRef> ret(1);
if (grad) {
ret[0] = grad;
}
......
......@@ -56,42 +56,44 @@ WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
struct SymbolVarContext {
TransformationContext context;
cg::ComputingGraph* graph;
std::shared_ptr<SymbolTransformation> symbol_tsf;
std::shared_ptr<ScalarTransformation> scalar_tsf;
SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
SymbolVarContext(cg::ComputingGraph* graph) {
symbol_tsf = std::make_shared<SymbolTransformation>(graph);
scalar_tsf = std::make_shared<ScalarTransformation>();
Transformation::swap_context(context);
}
void init() {
std::make_shared<SymbolTransformation>(graph)->register_at(
Transformation::top());
std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
symbol_tsf->register_at(Transformation::top());
scalar_tsf->register_at(Transformation::top());
}
~SymbolVarContext() { Transformation::swap_context(context); }
};
ValueRef symvar2val(py::handle py_symbol_var) {
auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node);
if (symbol_var->is_scalar) {
value = scalar_tsf->value_type().make(value);
}
return value;
}
ValueRef symvar2val(py::handle py_symbol_var) {
auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
ValueRef value = SymbolValue::make(symbol_var->m_node);
if (symbol_var->is_scalar) {
value = ScalarValue::make(value);
py::object val2symvar(py::handle typeobj, ValueRef value) {
bool is_scalar = false;
if (auto* scalar_value = value.as(scalar_tsf->value_type())) {
value = scalar_value->value();
is_scalar = true;
}
auto* node = value.cast(symbol_tsf->value_type()).node();
auto py_symbol_var =
typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
return py_symbol_var;
}
return value;
}
py::object val2symvar(py::handle typeobj, ValueRef value) {
bool is_scalar = false;
if (auto* scalar_value = value.as<ScalarValue>()) {
value = scalar_value->value();
is_scalar = true;
}
auto* node = value.cast<SymbolValue>().node();
auto py_symbol_var =
typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
return py_symbol_var;
}
~SymbolVarContext() { Transformation::swap_context(context); }
};
} // namespace
......@@ -130,19 +132,21 @@ PyObject* py_apply(
auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>();
SmallVector<ValueRef, 8> tensors(nargs);
if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) &&
py::isinstance<PySymbolVar>(py::handle(args[0]));
if (is_symbol_var) {
// swap to a special context to reuse scalar handle
SymbolVarContext context(
py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
context.init();
for (size_t i = 0; i < nargs; ++i) {
tensors[i] = symvar2val(args[i]);
tensors[i] = context.symvar2val(args[i]);
}
auto outputs = imperative::apply(*op, tensors);
auto ret = pybind11::tuple(outputs.size());
auto typeobj = py::handle(args[0]).get_type();
for (size_t i = 0; i < outputs.size(); ++i) {
ret[i] = val2symvar(typeobj, outputs[i]);
ret[i] = context.val2symvar(typeobj, outputs[i]);
}
return ret.release().ptr();
}
......@@ -161,7 +165,7 @@ PyObject* py_apply(
}
}
auto outputs = imperative::apply(*op, tensors);
auto outputs = [&] { return imperative::apply(*op, tensors); }();
size_t nout = outputs.size();
auto ret = py::tuple(nout);
for (size_t i = 0; i < nout; ++i) {
......@@ -1573,9 +1577,9 @@ void init_tensor(py::module m) {
SymbolVarContext context(graph);
context.init();
auto output = reduce_to_scalar(
*op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor));
*op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor));
auto typeobj = tensor.get_type();
return val2symvar(typeobj, output);
return context.val2symvar(typeobj, output);
} else {
auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto output = reduce_to_scalar(
......
......@@ -67,10 +67,9 @@ struct TransformationManager {
}
};
class PyValue final
: public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> {
class PyValue final : public PrimitiveValue<PyValue, pybind11::object> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const {
return pybind11::str((const pybind11::object&)*this).cast<std::string>();
......
......@@ -63,7 +63,7 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
MegBrainError,
"unknown input type, expects HostStorage or DeviceStorage, got "
"%s",
input.name()->c_str());
input.to_string().c_str());
}
}
mgb_assert(
......
......@@ -12,7 +12,7 @@ std::string CompNodeValue::to_string() const {
}
std::string BoolValue::to_string() const {
return (*m_value) ? "true" : "false";
return (*this) ? "true" : "false";
}
std::string HostStorage::to_string() const {
......@@ -26,10 +26,10 @@ std::string DeviceStorage::to_string() const {
std::string HostValue::to_string() const {
return ssprintf(
"HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
m_dtype.name(), m_shape.to_string().c_str());
dtype().name(), shape().to_string().c_str());
}
HostTensorND HostValue::as_nd(bool allow_scalar) const {
HostTensorND HostTensor::as_nd(bool allow_scalar) const {
HostTensorND nd;
TensorShape tensor_shape;
if (m_shape.is_scalar()) {
......@@ -45,10 +45,10 @@ HostTensorND HostValue::as_nd(bool allow_scalar) const {
std::string DeviceValue::to_string() const {
return ssprintf(
"DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
m_dtype.name(), m_shape.to_string().c_str());
dtype().name(), shape().to_string().c_str());
}
DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const {
DeviceTensorND DeviceTensor::as_nd(bool allow_scalar) const {
DeviceTensorND nd;
TensorShape tensor_shape;
if (m_shape.is_scalar()) {
......
......@@ -19,46 +19,18 @@
namespace mgb {
namespace imperative {
namespace {
MGB_NOINLINE void copy_outputs(
ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) {
size_t nr_outputs = outputs.size();
if (mgb_likely(nr_outputs == 1)) {
ValueRef output_copy;
output_copy = outputs[0];
allocator.clear();
outputs = ValueRefList({output_copy});
} else if (!outputs.empty()) {
SmallVector<ValueRef> outputs_copy(nr_outputs);
for (size_t i = 0; i < nr_outputs; ++i) {
outputs_copy[i] = outputs[i];
}
outputs.clear();
allocator.clear();
outputs = {outputs_copy.begin(), outputs_copy.end()};
} else {
allocator.clear();
}
}
} // namespace
ValueRefList apply(const Operator& op, Span<ValueRef> inputs) {
auto& context = Transformation::get_context();
size_t& depth = context.next_transformation;
bool top = depth == 0;
auto outputs = ([&] {
if (mgb_unlikely(depth >= context.transformations.size())) {
return op.fallback(inputs);
} else {
auto& transformation = *context.transformations[depth++];
CleanupGuard _{[&] { --depth; }};
return transformation.apply_transformation(op, inputs);
}
})();
if (mgb_unlikely(top)) {
copy_outputs(context.allocator, outputs);
// TODO: add fallback transformation
bool fallback = depth >= context.transformations.size();
if (mgb_unlikely(fallback)) {
return op.fallback(inputs);
} else {
auto& transformation = *context.transformations[depth++];
CleanupGuard _{[&] { --depth; }};
return transformation.apply_transformation(op, inputs);
}
return outputs;
}
ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
......@@ -66,12 +38,7 @@ ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) {
}
ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
SmallVector<ValueRef> inputs_storage;
for (size_t i = 0; i < inputs.size(); ++i) {
inputs_storage.push_back(inputs[i]);
}
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs,
size_t) {
auto apply_functor = [](std::shared_ptr<OpDef> op, Span<ValueRef> inputs, size_t) {
auto outputs = imperative::apply(*op, inputs);
return SmallVector<ValueRef>(outputs.begin(), outputs.end());
};
......@@ -93,7 +60,7 @@ ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) {
HostStorage::make(host_value.storage()),
DeviceStorage::make(device_value.storage()))[0];
};
auto outputs = graph.apply(inputs_storage, apply_functor, make_const);
auto outputs = graph.apply(inputs, apply_functor, make_const);
return ValueRefList{outputs.begin(), outputs.end()};
}
......
......@@ -331,6 +331,7 @@ void ChannelImpl::dispatch_kernel(
cmd.inputs = std::move(input_infos);
cmd.outputs.reserve(output_descs.size());
outputs->reserve(output_descs.size());
for (int i = 0; i < output_descs.size(); ++i) {
auto&& desc = output_descs[i];
auto info = alloc();
......@@ -730,7 +731,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) {
input_descs.push_back({{{}, input->dtype()}, input->comp_node()});
}
auto forward_graph = OpDef::make_forward_graph(def, input_descs);
auto outputs = forward_graph.apply(inputs, apply_functor, const_functor);
auto outputs = forward_graph.apply<TensorPtr>(
inputs, apply_functor, const_functor);
return outputs;
}
return OpDef::apply_on_physical_tensor(def, inputs);
......
......@@ -11,6 +11,7 @@
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/utility.h"
......@@ -101,7 +102,7 @@ void apply_on_device_tensornd(
const OpDef& def, const SmallVector<DeviceTensorND>& inputs,
SmallVector<DeviceTensorND>* outputs) {
auto&& op_def = def.cast_final_safe<Elemwise>();
auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
auto&& trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
mgb_assert(
inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually",
trait.name, trait.arity, inputs.size());
......
......@@ -36,7 +36,7 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
.node();
};
auto subgraph = def.trait()->make_forward_graph(def, input_descs);
auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
auto outputs = subgraph.apply<VarNode*>(inputs, apply_functor, const_functor);
return outputs;
}
......@@ -56,7 +56,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
value->layout(), value->comp_node(),
value->get_value().proxy_to_default_cpu()};
};
auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
auto outputs =
subgraph.apply<LogicalTensorDesc>(inputs, apply_functor, const_functor);
return {outputs, all_validated};
}
......@@ -72,7 +73,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
return OpDef::apply_on_physical_tensor(*op, inputs);
};
auto const_functor = [&](const TensorPtr& value) { return value; };
auto outputs = subgraph.apply(inputs, apply_functor, const_functor);
auto outputs = subgraph.apply<TensorPtr>(inputs, apply_functor, const_functor);
return outputs;
}
......@@ -94,7 +95,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
};
GradContext<var_t> grad_context{accum_grad};
auto input_vars = builder.write_inputs(inputs);
auto outputs = forward_graph.apply(
auto outputs = forward_graph.apply<var_t>(
input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3),
[&](TensorPtr constant) {
return builder.write_constant(
......@@ -102,7 +103,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
});
size_t nr_outputs = outputs.size();
auto apply_mask = [](auto&& values, SmallVector<bool> mask) {
mgb_assert(mask.size() == values.size(), "");
mgb_assert(mask.size() == values.size());
std::decay_t<decltype(values)> results;
for (size_t i = 0; i < mask.size(); ++i) {
if (mask[i]) {
......@@ -143,7 +144,7 @@ static EncodedSubgraph make_backward_graph_from_forward(
return builder.write_constant(
constant, {constant->layout(), constant->comp_node()});
};
return bg.apply(grad_inputs, apply_functor, const_functor);
return bg.apply<var_t>(grad_inputs, apply_functor, const_functor);
});
builder.add_outputs(grad_context.get_grads(input_vars));
for (size_t i = 0; i < nr_outputs; ++i) {
......
......@@ -10,20 +10,19 @@
*/
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/utils/stats.h"
namespace mgb {
namespace imperative {
DTypeValue::ref_t InterpreterInfo::dtype() const {
DTypeValue::ref_t InterpreterValue::dtype() const {
if (!m_dtype) {
m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle()));
}
return m_dtype;
}
CompNodeValue::ref_t InterpreterInfo::comp_node() const {
CompNodeValue::ref_t InterpreterValue::comp_node() const {
if (!m_comp_node) {
m_comp_node = CompNodeValue::make(
handle()->channel()->get_device(handle()->handle()));
......@@ -31,7 +30,7 @@ CompNodeValue::ref_t InterpreterInfo::comp_node() const {
return m_comp_node;
}
ShapeValue::ref_t InterpreterInfo::shape() const {
ShapeValue::ref_t InterpreterValue::shape() const {
if (!m_shape) {
m_shape = ShapeValue::make(
ValueShape::from(handle()->channel()->get_shape(handle()->handle())));
......@@ -51,21 +50,22 @@ ValueRefList InterpreterTransformation::apply_op(
}
}};
for (auto input : inputs) {
input_handles.push_back(input.cast<InterpreterValue>().handle()->handle());
input_handles.push_back(input.cast(m_value_type).handle()->handle());
}
output_handles =
m_channel->apply_op(apply_op.op().shared_from_this(), input_handles);
ValueRefList outputs(output_handles.size());
for (size_t i = 0; i < output_handles.size(); ++i) {
outputs[i] = InterpreterValue::make(share_handle(output_handles[i]));
outputs[i] = m_value_type.make(share_handle(output_handles[i]));
output_handles[i] = nullptr;
}
output_handles.clear();
return outputs;
}
ValueRefList InterpreterTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
auto& input = inputs.item().cast<InterpreterValue>();
auto& input = inputs.item().cast(m_value_type);
ValueRef output;
switch (get_attr.attr()) {
case GetAttr::DType:
......@@ -98,10 +98,10 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
if (!args.device) {
// implies H2D
mgb_assert(args.host, "neither host and device value is valid");
return {InterpreterValue::make(share_handle(
return {m_value_type.make(share_handle(
m_channel->put(*args.host, args.kind == CreateTensor::Unique)))};
} else {
return {InterpreterValue::make(share_handle(m_channel->put(
return {m_value_type.make(share_handle(m_channel->put(
*args.device, args.host ? *args.host : HostTensorND())))};
}
}
......@@ -119,7 +119,7 @@ ValueRefList InterpreterTransformation::apply_transformation(
} else if (auto* create_tensor = op.as<CreateTensor>()) {
return apply_create_tensor(*create_tensor, inputs);
} else if (auto* dtr_command = op.as<DTRCommand>()) {
auto handle = inputs[0].cast<InterpreterValue>().handle()->handle();
auto handle = inputs[0].cast(m_value_type).handle()->handle();
switch (dtr_command->kind()) {
case DTRCommand::Drop:
m_channel->drop(handle);
......@@ -129,10 +129,10 @@ ValueRefList InterpreterTransformation::apply_transformation(
}
return {};
} else if (auto* rename_value = op.as<RenameValue>()) {
auto& input = inputs[0].cast<InterpreterValue>();
return {InterpreterValue::make(input.handle(), rename_value->name())};
auto& input = inputs[0].cast(m_value_type);
return {m_value_type.make(input.handle(), rename_value->name())};
} else if (op.is<GetName>()) {
auto name = inputs[0].cast<InterpreterValue>().name();
auto name = inputs[0].cast(m_value_type).name();
if (!name.empty()) {
return {StringValue::make(name)};
} else {
......
......@@ -68,7 +68,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
size_t count = std::count_if(
save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
if (!backward_graph->precomp.empty()) {
ValueRefList inputs_and_outputs(inputs.size() + outputs.size());
SmallVector<ValueRef> inputs_and_outputs(inputs.size() + outputs.size());
auto it = inputs_and_outputs.begin();
for (auto&& input : inputs) {
*it++ = input;
......@@ -94,7 +94,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure(
}
}
void BackwardGraphWithClosure::operator()(
ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
ValueRef args[closure.size() + grads.size()];
size_t nargs = 0;
for (auto&& value : closure) {
......@@ -114,7 +114,9 @@ void BackwardGraphWithClosure::operator()(
if (null_grad) {
return;
}
auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs));
SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()};
igrads_.clear();
auto&& iter = igrads.begin();
for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
if (p) {
......@@ -125,7 +127,7 @@ void BackwardGraphWithClosure::operator()(
}
void CustomBackward::operator()(
ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) {
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
size_t nargs = grads.size();
ValueRef args[nargs];
for (size_t i = 0; i < nargs; ++i) {
......@@ -206,7 +208,7 @@ void GradKey::backward() {
mgb_throw(AssertionError, "invalid backward");
} else {
mgb_assert(grad_fn->m_slots.size() > 0);
ValueRefList grads (grad_fn->m_slots.size());
SmallVector<ValueRef> grads (grad_fn->m_slots.size());
auto iter = grads.begin();
for (auto&& slot : grad_fn->m_slots) {
*iter++ = slot.m_grad;
......@@ -231,11 +233,9 @@ void GradKey::backward() {
GradValue::ref_t GradKey::attach(
ValueRef tensor, std::function<void(ValueRef)> callback) {
auto grad_value = tensor.as_ref<GradValue>();
if (grad_value && grad_value->has_key(shared_from_this())) {
mgb_assert(
!tensor.cast<GradValue>().slot_for(shared_from_this())->callback,
"callback exists");
auto grad_value = tensor.as_ref(m_value_type);
if (grad_value) {
mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists");
} else {
GradSlotPtr grad_slot;
auto& grad_fn = grad_slot.m_fn;
......@@ -243,9 +243,9 @@ GradValue::ref_t GradKey::attach(
grad_fn->m_key = shared_from_this();
grad_fn->m_slots.resize(1);
grad_slot.m_index = 0;
grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot);
}
grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;
grad_value->slot().m_fn->m_slots[0].callback = callback;
return grad_value;
}
......@@ -263,7 +263,7 @@ void GradKey::freeze() {
ValueRefList GradTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
auto fallback = [&] {
ValueRefList unwrapped_inputs(inputs.size());
SmallVector<ValueRef> unwrapped_inputs(inputs.size());
{
// overhead
for (size_t i = 0; i < inputs.size(); ++i) {
......@@ -367,7 +367,7 @@ ValueRefList GradTransformation::apply_transformation(
for (size_t i = 0; i < inputs.size(); ++i) {
if (backward.input_has_grad(i) && require_grads[i]) {
auto& input_grad_slot =
inputs[i].cast<GradValue>().slot_for(m_key);
inputs[i].cast(m_value_type).slot();
grad_fn->m_dests.emplace_back(input_grad_slot);
grad_fn->m_dests.back().m_producer_record.insert_after(
input_grad_slot->m_producer_head);
......@@ -378,7 +378,7 @@ ValueRefList GradTransformation::apply_transformation(
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
// little overhead: Value::make
auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
auto grad_value = m_value_type.make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
outputs[i] = record_grad(grad_value);
}
}
......@@ -435,7 +435,10 @@ ValueRefList GradTransformation::apply_transformation(
backward.m_input_has_grad = SmallVector(nr_inputs, true);
backward.m_output_attrs =
SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
backward.m_backward = set_grad->grad_fn();
backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) {
auto result = fn(inputs);
return SmallVector<ValueRef>(result.begin(), result.end());
};
ValueRefList outputs(nr_outputs);
grad_fn->m_key = m_key;
grad_fn->m_slots.resize(nr_outputs);
......@@ -454,10 +457,10 @@ ValueRefList GradTransformation::apply_transformation(
auto& output = outputs_[i];
auto grad_value = as_grad_value(output);
if (grad_value) {
grad_value = GradValue::make(
grad_value = m_value_type.make(
grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
} else {
grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
grad_value = m_value_type.make(output, m_key, GradSlotPtr(grad_fn, i));
}
outputs[i] = record_grad(grad_value);
}
......@@ -485,8 +488,7 @@ ValueRefList GradTransformation::apply_transformation(
mgb_assert(inputs.size() == 1);
if (auto&& grad_value = as_grad_value(inputs[0])) {
auto output = imperative::apply(op, grad_value->m_value)[0];
auto grad_output = GradValue::make(
output, grad_value->key(), grad_value->slot_for(m_key));
auto grad_output = m_value_type.make(output, m_key, grad_value->slot());
return {record_grad(grad_output)};
} else {
return imperative::apply(op, inputs);
......@@ -502,7 +504,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
std::vector<GradSlotPtr> y_slots;
for (auto&& y : ys) {
if (auto&& grad_value = as_grad_value(y)) {
y_slots.push_back(grad_value->slot_for(grad_key));
y_slots.push_back(grad_value->slot());
} else {
y_slots.emplace_back();
}
......
......@@ -32,7 +32,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo());
VarNodeArray input_nodes;
for (auto&& input : inputs) {
if (auto* input_node = input.as<LazyEvalValue>()) {
if (auto* input_node = input.as(m_value_type)) {
input_nodes.push_back(input_node->node());
} else {
// ImmutableTensor has empty shape issues
......@@ -112,7 +112,7 @@ ValueRefList LazyEvalTransformation::apply_transformation(
return {record_var(node)};
}
} else if (auto* get_attr = op.as<GetAttr>()) {
if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
if (auto* lazy_val = inputs.item().as(m_value_type)) {
switch (get_attr->attr()) {
case GetAttr::DType:
return {DTypeValue::make(lazy_val->node()->dtype())};
......@@ -167,14 +167,14 @@ ValueRefList LazyEvalTransformation::apply_transformation(
return imperative::apply(op, inputs);
}
} else if (auto* rename_value = op.as<RenameValue>()) {
if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
if (auto* lazy_val = inputs.item().as(m_value_type)) {
return {record_var(
lazy_val->node(), lazy_val->bound_data(), rename_value->name())};
} else {
return imperative::apply(op, inputs);
}
} else if (op.is<GetName>()) {
if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) {
if (auto* lazy_val = inputs.item().as(m_value_type)) {
auto name = lazy_val->name();
if (!name.empty()) {
return {StringValue::make(lazy_val->name())};
......@@ -255,7 +255,7 @@ void LazyEvalTransformation::on_unregister() noexcept {
DeviceStorage::make(data.storage()))[0]);
}
for (auto&& lazy_val : lazy_vals) {
if (lazy_val.is<LazyEvalValue>()) {
if (lazy_val.is(m_value_type)) {
std::string repr =
ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str());
mgb_log_debug("%s", repr.c_str());
......
......@@ -20,7 +20,8 @@ namespace imperative {
namespace {
using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>);
using ScalarRule = ValueRefList (*)(
const OpDef&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&);
static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules;
ValueRef make_scalar_shape(CompNode device) {
......@@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) {
return *shape_of_shape == ValueShape{0};
}
template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)>
template <
typename T,
ValueRefList (*rule)(
const T&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&)>
void register_scalar_rule() {
scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask);
Span<bool> inputs_mask,
const Type<ScalarValue>& value_type) {
return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask, value_type);
};
}
template <typename TOpDef, size_t nr_inputs>
ValueRefList elemwise_rule(
const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
if constexpr (nr_inputs != 0) {
mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch");
}
......@@ -63,27 +69,29 @@ ValueRefList elemwise_rule(
}
auto outputs = imperative::apply(op_def, inputs);
if (all_scalar) {
outputs[0] = ScalarValue::make(outputs[0]);
outputs[0] = scalar_type.make(outputs[0]);
}
return outputs;
}
ValueRefList remove_axis_rule(
const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(!inputs_mask.item());
bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size();
if (is_scalar && remove_axis.axis.size() == 1) {
return {ScalarValue::make(inputs.item())};
return {scalar_type.make(inputs.item())};
}
auto outputs = imperative::apply(remove_axis, inputs);
if (is_scalar) {
outputs[0] = ScalarValue::make(outputs[0]);
outputs[0] = scalar_type.make(outputs[0]);
}
return outputs;
}
ValueRefList reduce_rule(
const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
if (inputs.size() == 1) {
return imperative::apply(reduce, inputs);
}
......@@ -91,7 +99,7 @@ ValueRefList reduce_rule(
bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) {
CompNode device = *inputs[0].device();
return {ScalarValue::make(
return {scalar_type.make(
imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])};
}
return imperative::apply(reduce, inputs);
......@@ -99,7 +107,7 @@ ValueRefList reduce_rule(
ValueRefList collective_comm_rule(
const CollectiveComm& collective_comm, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 1);
static std::unordered_set<CollectiveComm::Mode> modes = {
CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN,
......@@ -110,7 +118,7 @@ ValueRefList collective_comm_rule(
return imperative::apply(collective_comm, inputs);
}
if (inputs_mask.item()) {
return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])};
return {scalar_type.make(imperative::apply(collective_comm, inputs[0])[0])};
} else {
return imperative::apply(collective_comm, inputs);
}
......@@ -118,24 +126,27 @@ ValueRefList collective_comm_rule(
ValueRefList param_pack_split_rule(
const ParamPackSplit& param_pack_split, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
auto outputs = imperative::apply(param_pack_split, inputs);
size_t nr_outputs = outputs.size();
mgb_assert(nr_outputs == param_pack_split.shapes.size());
for (size_t i = 0; i < nr_outputs; ++i) {
if (param_pack_split.shapes[i].empty()) {
outputs[i] = ScalarValue::make(outputs[i]);
outputs[i] = scalar_type.make(outputs[i]);
}
}
return outputs;
}
ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) {
return {ScalarValue::make(imperative::apply(dot, inputs)[0])};
ValueRefList dot_rule(
const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
return {scalar_type.make(imperative::apply(dot, inputs)[0])};
}
ValueRefList add_axis_rule(
const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 1);
if (inputs_mask.item()) {
mgb_assert(add_axis.axis[0] == 0);
......@@ -151,7 +162,8 @@ ValueRefList add_axis_rule(
}
ValueRefList remote_recv_rule(
const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
if (remote_recv.shape.empty()) {
std::vector<int32_t> shape = {1};
auto remote_recv_no_scalar = RemoteRecv::make(
......@@ -167,20 +179,21 @@ ValueRefList remote_recv_rule(
ValueRefList check_no_finite_rule(
const CheckNonFinite& check_no_finite, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) {
auto outputs = imperative::apply(check_no_finite, inputs);
mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch");
outputs.back() = ScalarValue::make(outputs.back());
outputs.back() = scalar_type.make(outputs.back());
for (size_t i = 0; i < inputs.size(); ++i) {
if (inputs_mask[i]) {
outputs[i] = ScalarValue::make(outputs[i]);
outputs[i] = scalar_type.make(outputs[i]);
}
}
return outputs;
}
ValueRefList subtensor_rule(
const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() >= 1);
auto input = inputs[0];
bool is_scalar;
......@@ -199,14 +212,14 @@ ValueRefList subtensor_rule(
}
auto outputs = imperative::apply(subtensor, inputs);
if (is_scalar) {
outputs[0] = ScalarValue::make(outputs[0]);
outputs[0] = scalar_type.make(outputs[0]);
}
return outputs;
}
ValueRefList get_var_shape_rule(
const GetVarShape& get_var_shape, Span<ValueRef> inputs,
Span<bool> inputs_mask) {
const GetVarShape& get_var_shape, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
bool all_scalar = true;
mgb_assert(inputs.size() >= 1);
for (auto&& input_mask : inputs_mask) {
......@@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule(
}
ValueRefList reshape_rule(
const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) {
return {ScalarValue::make(imperative::apply(
return {scalar_type.make(imperative::apply(
reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
} else {
return imperative::apply(reshape, inputs);
......@@ -240,11 +254,12 @@ ValueRefList reshape_rule(
}
ValueRefList broadcast_rule(
const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) {
const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask,
const Type<ScalarValue>& scalar_type) {
mgb_assert(inputs.size() == 2);
bool is_scalar = is_scalar_shape(inputs[1]);
if (is_scalar) {
return {ScalarValue::make(imperative::apply(
return {scalar_type.make(imperative::apply(
broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])};
} else {
return imperative::apply(broadcast, inputs);
......@@ -299,11 +314,11 @@ struct ScalarRuleRegistry {
ValueRefList ScalarTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
auto&& input = inputs.item();
bool is_scalar = input.is<ScalarValue>();
bool is_scalar = input.is(m_value_type);
if (!is_scalar) {
return imperative::apply(get_attr, input);
}
auto unwrapped_input = input.cast<ScalarValue>().value();
auto unwrapped_input = input.cast(m_value_type).value();
if (get_attr.attr() == GetAttr::Shape) {
if (!m_empty_shape) {
m_empty_shape = ShapeValue::make();
......@@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation(
ValueRefList unwrapped_inputs(nr_inputs);
SmallVector<bool> inputs_mask(nr_inputs);
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) {
if (auto&& scalar_value = inputs[i].as_ref(m_value_type)) {
unwrapped_inputs[i] = scalar_value->value();
inputs_mask[i] = true;
} else {
......@@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation(
if (auto apply_op = op.as<ApplyOp>()) {
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
if (iter != scalar_rules.end()) {
return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask);
return iter->second(
apply_op->op(), unwrapped_inputs, inputs_mask, m_value_type);
} else {
// TODO: repeat op
return fallback();
......@@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation(
CreateTensor scalar_op(
create_tensor->kind(), create_tensor->device(),
create_tensor->dtype(), scalar_shape);
return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])};
return {m_value_type.make(imperative::apply(scalar_op, inputs)[0])};
} else {
return imperative::apply(op, inputs);
}
......@@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation(
bool is_scalar = inputs_mask[0];
auto outputs = fallback();
if (is_scalar) {
outputs[0] = ScalarValue::make(outputs[0]);
outputs[0] = m_value_type.make(outputs[0]);
}
return outputs;
} else {
......
......@@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation(
SmallVector<TracingValue::ref_t> wrapped_inputs;
SmallVector<size_t> input_ids;
for (auto input : inputs) {
auto tracing_value = input.as_ref<TracingValue>();
auto tracing_value = input.as_ref(m_value_type);
if (!tracing_value) {
tracing_value =
record_var(input, m_capture_as_const, VarKind::External);
......@@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (auto* get_attr = op.as<GetAttr>()) {
auto unwrapped_input = unwrap_var(inputs[0]);
auto outputs = imperative::apply(op, unwrapped_input);
if (auto* tracing_value = inputs[0].as<TracingValue>()) {
if (auto* tracing_value = inputs[0].as(m_value_type)) {
auto& var_info = m_vars[tracing_value->id()];
switch (get_attr->attr()) {
case GetAttr::Shape:
......@@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (auto* trace_mark_var = op.as<TraceMarkVar>()) {
mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input");
auto input = inputs[0];
auto tracing_var = input.as_ref<TracingValue>();
auto tracing_var = input.as_ref(m_value_type);
if (!tracing_var) {
bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" ||
trace_mark_var->mark().substr(0, 6) == "kwarg_";
......@@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (auto* trace_name_var = op.as<RenameValue>()) {
mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input");
auto input = inputs[0];
auto tracing_var = input.as_ref<TracingValue>();
auto tracing_var = input.as_ref(m_value_type);
if (!tracing_var) {
tracing_var = record_var(input, m_capture_as_const, VarKind::External);
} else {
......@@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation(
} else if (op.is<GetName>()) {
mgb_assert(inputs.size() == 1, "GetName expects exactly one input");
auto input = inputs[0];
if (auto tracing_var = input.as_ref<TracingValue>()) {
if (auto tracing_var = input.as_ref(m_value_type)) {
auto name = m_vars[tracing_var->id()].name;
if (!name.empty()) {
return {StringValue::make(name)};
......@@ -425,26 +425,12 @@ void CompiledTransformation::compile() {
}
auto& node = var_accessors[input].node;
if (input_vars.empty() && require_link && mm_io_link.node()) {
/*mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");*/
auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node)
.node();
}
input_vars.push_back(node);
}
/*if (require_link && mm_io_link.node()) {
mgb_assert(
!input_vars.empty(),
"io-mm operator should have at least one input");
auto comp_node = mm_io_link.node()->comp_node();
// auto comp_node = input_vars[0]->comp_node();
input_vars[0] = opr::VirtualDep::make(
{SymbolVar(input_vars[0]), mm_io_link}, comp_node)
.node();
}*/
VarNodeArray output_vars;
if (item.op) {
output_vars = OpDef::apply_on_var_node(*item.op, input_vars);
......@@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
switch (var.kind) {
case VarKind::External: {
trace_assert(
!value.is<TracedValue>(), "expect external node, got internal");
!value.is(m_value_type), "expect external node, got internal");
if (var.bound_data) {
assert_tensor_equal(var.bound_data, value);
} else {
......@@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
}
case VarKind::Internal: {
trace_assert(
value.is<TracedValue>(), "expect internal node, got external");
auto& traced_value = value.cast<TracedValue>();
value.is(m_value_type), "expect internal node, got external");
auto& traced_value = value.cast(m_value_type);
trace_assert(traced_value.id() == id, "input id mismatch");
break;
}
......@@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) {
}
auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t {
auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]);
auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]);
m_weak_values.push_back(traced_value);
return traced_value;
}
......@@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() {
return m_seq[m_pc++];
}
ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const {
if (!m_shape) {
trace_assert(m_accessor->shape_getter, "shape unreadable");
m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter()));
......@@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const {
return m_shape;
}
DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const {
DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const {
return m_var->dtype;
}
CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const {
CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const {
return m_var->device;
}
auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& {
auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& {
return *m_accessor;
}
......@@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op(
ValueRefList CompiledTransformation::apply_get_attr(
const GetAttr& get_attr, Span<ValueRef> inputs) {
if (auto* traced_value = inputs[0].as<TracedValue>()) {
if (auto* traced_value = inputs[0].as(m_value_type)) {
ValueRef output;
auto& var_accessor = traced_value->accessor();
switch (get_attr.attr()) {
......@@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept {
void CompiledTransformation::execute() {
mgb_assert(m_executable != nullptr);
m_graph_executor = std::thread([&] {
try {
m_executable->execute();
m_executable->wait();
} catch (...) {
auto exc = std::current_exception();
set_exception(exc);
}
});
{
MGB_LOCK_GUARD(m_mutex);
m_graph_status = 1;
}
m_cv.notify_all();
}
void CompiledTransformation::wait() {
......@@ -735,8 +717,9 @@ void CompiledTransformation::wait() {
} catch (...) {
}
mgb_assert(m_executable != nullptr);
m_graph_executor.join();
m_graph_executor = {};
std::unique_lock lock{m_mutex};
m_cv.wait(lock, [&] { return m_graph_status == 0; });
lock.unlock();
for (auto&& box : m_boxes) {
box->reset();
}
......
......@@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const {
return m_storage;
}
const Value* ValueRef::as(size_t typecode) const {
const Value* ValueRef::as(const IType& type) const {
auto&& storage = this->storage();
if (storage->m_typecode != typecode) {
if (storage->type() != type) {
return nullptr;
}
return static_cast<Value*>(storage.get());
}
bool ValueRef::is(size_t typecode) const {
return this->storage()->m_typecode == typecode;
bool ValueRef::is(const IType& type) const {
return this->storage()->type() == type;
}
TypedValueRef<DeviceValue> ValueRef::dev_tensor() const {
......@@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const {
if (!m_storage) {
return "null";
}
auto& types = Value::registered_types();
mgb_assert(types.size() > m_storage->m_typecode);
return types[m_storage->m_typecode].name();
return m_storage->type().name();
}
bool ValueRef::watching() const {
......@@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() {
return {strong_storage};
}
Value::Value(size_t typecode) : m_typecode{typecode} {
Value::Value() {
m_id = nr_values++;
}
......@@ -147,17 +145,6 @@ Value::~Value() {
}
}
size_t Value::register_type(std::type_index type) {
auto& types = const_cast<std::vector<std::type_index>&>(registered_types());
types.push_back(type);
return types.size() - 1;
}
const std::vector<std::type_index>& Value::registered_types() {
static std::vector<std::type_index> sm_registered_types;
return sm_registered_types;
}
void Value::register_value(ValueRef value) {
registered_values[value.id()] = ValueWeakRef(value);
}
......@@ -188,7 +175,7 @@ std::vector<ValueRef> Value::end_record_values() {
}
void Value::try_rethrow() {
if (m_typecode == ErrorValue::TYPE_CODE) {
if (type() == PrimitiveType<ErrorValue>::instance) {
auto message = static_cast<ErrorValue*>(this)->message();
mgb_throw(MegBrainError, "invalid value: %s", message.c_str());
}
......@@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) {
m_size = nr_elems;
if (m_size > 0) {
if (m_size == 1) {
m_data = inline_storage();
m_data = new (inline_storage()) ValueRef();
} else {
auto& context = Transformation::get_context();
m_data = context.allocator.allocate(m_size);
}
for (size_t i = 0; i < m_size; ++i) {
new (m_data + i) ValueRef();
m_data = new ValueRef[m_size];
}
} else {
m_data = nullptr;
......@@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) {
init(nr_elems);
}
/*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
: ValueRefList(values.begin(), values.end()) {}*/
ValueRefList::ValueRefList(const ValueRefList& rhs)
: ValueRefList(rhs.cbegin(), rhs.cend()) {}
......@@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() {
}
void ValueRefList::clear() {
for (size_t i = 0; i < m_size; ++i) {
m_data[i].~ValueRef();
}
if (m_data) {
if (m_size != 1) {
Transformation::get_context().allocator.deallocate(m_data, m_size);
delete[] m_data;
} else {
mgb_assert(m_data == inline_storage());
m_data->~ValueRef();
}
}
m_data = nullptr;
......
......@@ -25,79 +25,68 @@ class GradKey;
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
class ShapeValue final
: public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> {
class ShapeValue final : public PrimitiveValue<ShapeValue, ValueShape> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class CompNodeValue final
: public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> {
class CompNodeValue final : public PrimitiveValue<CompNodeValue, CompNode> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
// TODO: override factory method
class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> {
class Boolean {
private:
std::optional<bool> m_value;
bool m_value;
public:
BoolValue(bool value) : m_value{value} {}
operator bool() const { return *m_value; }
Boolean() = default;
Boolean(bool value) : m_value(value) {}
std::string to_string() const override;
operator bool() const { return m_value; }
};
void clear() override { m_value.reset(); }
// TODO: override factory method
class BoolValue final : public PrimitiveValue<BoolValue, Boolean> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class HostStorage final
: public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> {
class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class DeviceStorage final
: public MixinValueImpl<
DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> {
class DeviceStorage final : public PrimitiveValue<DeviceStorage, DeviceTensorStorage> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
/**
* \brief like HostTensorND mixin, but allow scalar value
*
*/
class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> {
class HostTensor {
private:
DType m_dtype;
ValueShape m_shape;
HostTensorStorage m_storage;
public:
HostValue(DType dtype, ValueShape shape, HostTensorStorage storage)
HostTensor() = default;
HostTensor(DType dtype, ValueShape shape, HostTensorStorage storage)
: m_dtype(dtype), m_shape(shape), m_storage(storage) {}
HostValue(HostTensorND value)
: HostValue(
HostTensor(HostTensorND value)
: HostTensor(
value.dtype(), ValueShape::from(value.shape()), value.storage()) {
}
std::string to_string() const override;
void clear() override {
m_dtype = {};
m_shape = {};
m_storage = {};
}
DType dtype() const { return m_dtype; }
const ValueShape& shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); }
......@@ -112,31 +101,31 @@ public:
};
/**
* \brief like DeviceTensorND mixin, but allow scalar value
* \brief like HostTensorND mixin, but allow scalar value
*
*/
class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> {
class HostValue final : public PrimitiveValue<HostValue, HostTensor> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class DeviceTensor {
private:
DType m_dtype;
ValueShape m_shape;
DeviceTensorStorage m_storage;
public:
DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
DeviceTensor() = default;
DeviceTensor(DType dtype, ValueShape shape, DeviceTensorStorage storage)
: m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
DeviceValue(const DeviceTensorND& value)
: DeviceValue(
DeviceTensor(const DeviceTensorND& value)
: DeviceTensor(
value.dtype(), ValueShape::from(value.shape()), value.storage()) {
}
std::string to_string() const override;
void clear() override {
m_dtype = {};
m_shape = {};
m_storage = {};
}
DType dtype() const { return m_dtype; }
const ValueShape& shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); }
......@@ -145,26 +134,34 @@ public:
DeviceTensorND as_nd(bool allow_scalar = false) const;
};
class FunctionValue final
: public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> {
/**
* \brief like DeviceTensorND mixin, but allow scalar value
*
*/
class DeviceValue final : public PrimitiveValue<DeviceValue, DeviceTensor> {
public:
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class FunctionValue final : public PrimitiveValue<FunctionValue, GenericFunction> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class DTypeValue final
: public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> {
class DTypeValue final : public PrimitiveValue<DTypeValue, DType> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
class StringValue final
: public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> {
class StringValue final : public PrimitiveValue<StringValue, std::string> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
......@@ -180,10 +177,9 @@ public:
std::string message() const { return m_message; }
};
class ErrorValue final
: public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> {
class ErrorValue final : public PrimitiveValue<ErrorValue, Error> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override;
};
......
......@@ -57,7 +57,7 @@ struct Subgraph {
SmallVector<expr_t> exprs;
template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
SmallVector<T> apply(Span<T> input_vars, F&& f, C&& c) const {
std::unordered_map<size_t, T> idx2var;
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
......@@ -71,8 +71,7 @@ struct Subgraph {
for (auto idx : expr.inputs) {
expr_inputs.push_back(idx2var[idx]);
}
SmallVector<T> expr_outputs =
f(expr.op, std::move(expr_inputs), expr.outputs.size());
SmallVector<T> expr_outputs = f(expr.op, expr_inputs, expr.outputs.size());
mgb_assert(
expr_outputs.size() == expr.outputs.size(), "output size mismatch");
for (size_t i = 0; i < expr_outputs.size(); ++i) {
......@@ -102,9 +101,9 @@ struct EncodedSubgraph {
SmallVector<bool> input_mask;
SmallVector<bool> output_mask;
template <typename TContainer>
TContainer encode_inputs(TContainer inputs) const {
TContainer encoded_inputs;
template <typename T>
SmallVector<T> encode_inputs(Span<T> inputs) const {
SmallVector<T> encoded_inputs;
size_t index = 0;
for (auto&& input : inputs) {
mgb_assert(index < input_mask.size(), "index out of range");
......@@ -116,9 +115,9 @@ struct EncodedSubgraph {
return encoded_inputs;
}
template <typename TContainer>
TContainer encode_outputs(TContainer outputs) const {
TContainer encoded_outputs;
template <typename T>
SmallVector<T> encode_outputs(Span<T> outputs) const {
SmallVector<T> encoded_outputs;
size_t index = 0;
for (auto&& output : outputs) {
mgb_assert(index < output_mask.size(), "index out of range");
......@@ -130,9 +129,9 @@ struct EncodedSubgraph {
return encoded_outputs;
}
template <typename TContainer>
TContainer decode_outputs(TContainer outputs) const {
TContainer decoded_outputs;
template <typename T>
SmallVector<T> decode_outputs(Span<T> outputs) const {
SmallVector<T> decoded_outputs;
size_t index = 0;
for (size_t i = 0; i < output_mask.size(); i++) {
mgb_assert(index < output_mask.size(), "index out of range");
......@@ -150,8 +149,8 @@ struct EncodedSubgraph {
EncodedSubgraph result;
result.input_mask = graph.gen_input_mask();
result.output_mask = graph.gen_output_mask();
graph.inputs = result.encode_inputs(graph.inputs);
graph.outputs = result.encode_outputs(graph.outputs);
graph.inputs = result.encode_inputs<Subgraph::var_t>(graph.inputs);
graph.outputs = result.encode_outputs<Subgraph::var_t>(graph.outputs);
result.graph = graph;
return result;
}
......@@ -179,11 +178,11 @@ struct EncodedSubgraph {
}
template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
auto encoded_inputs = encode_inputs(input_vars);
SmallVector<T> apply(Span<T> input_vars, F&& f, C&& c) const {
auto encoded_inputs = encode_inputs<T>(input_vars);
auto encoded_outputs =
graph.apply(encoded_inputs, std::forward<F>(f), std::forward<C>(c));
return decode_outputs(encoded_outputs);
graph.apply<T>(encoded_inputs, std::forward<F>(f), std::forward<C>(c));
return decode_outputs<T>(encoded_outputs);
}
std::string repr() const;
......@@ -280,4 +279,4 @@ public:
};
} // namespace imperative
} // namespace mgb
\ No newline at end of file
} // namespace mgb
......@@ -18,7 +18,7 @@
namespace mgb::imperative {
struct InterpreterInfo {
class InterpreterValue final : public ObjectValue<InterpreterValue> {
public:
using Handle = interpreter::Interpreter::Handle;
using Channel = interpreter::Interpreter::Channel;
......@@ -46,8 +46,7 @@ private:
mutable ShapeValue::ref_t m_shape;
public:
InterpreterInfo() = default;
InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {})
InterpreterValue(LocalPtr<RAIIHandle> handle, std::string name = {})
: m_handle(handle), m_name(name) {}
const LocalPtr<RAIIHandle>& handle() const { return m_handle; }
......@@ -57,18 +56,14 @@ public:
ShapeValue::ref_t shape() const;
std::string name() const { return m_name; }
};
class InterpreterValue final
: public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf(
"Handle{ptr=%p, name=%s}", handle().get(),
imperative::quoted(name()).c_str());
}
void clear() override { m_handle = {}; }
};
/**
......@@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation {
public:
using Interpreter = interpreter::Interpreter;
using Handle = Interpreter::Handle;
using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>;
using SharedHandle = LocalPtr<InterpreterValue::RAIIHandle>;
using Channel = Interpreter::Channel;
private:
std::shared_ptr<Channel> m_channel;
ObjectType<InterpreterValue> m_value_type{"InterpreterValue"};
public:
explicit InterpreterTransformation(std::shared_ptr<Channel> channel)
......@@ -105,7 +101,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<InterpreterValue>());
mgb_assert(!value.is(m_value_type));
return value;
}
......
......@@ -34,7 +34,8 @@ struct BackwardGraphWithClosure {
std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs);
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver);
void operator()(
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; }
......@@ -51,7 +52,7 @@ struct CustomBackward;
using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>;
struct CustomBackward {
using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>;
using BackwardFn = std::function<SmallVector<ValueRef>(Span<ValueRef>)>;
using BackwardRule = std::function<std::optional<ValueRefList>(
const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>;
BackwardFn m_backward;
......@@ -62,7 +63,8 @@ struct CustomBackward {
SmallVector<OutputAttr> m_output_attrs;
public:
void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver);
void operator()(
Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver);
bool input_has_grad(size_t i) { return m_input_has_grad[i]; }
bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; }
......@@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const {
return &m_fn->m_slots[m_index];
}
class GradValue final : public ValueImpl<GradValue, ValueKind::Object> {
class GradValue final : public ObjectValue<GradValue> {
private:
ValueRef m_value;
std::shared_ptr<GradKey> m_key;
......@@ -187,14 +189,9 @@ public:
std::string to_string() const override;
bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; }
const GradSlotPtr& slot() const { return m_slot; }
const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const {
mgb_assert(m_key == key);
return m_slot;
}
std::shared_ptr<GradKey> key() const { return m_key; }
// std::shared_ptr<GradKey> key() const { return m_key; }
void clear() override {
m_slot = {};
......@@ -216,9 +213,12 @@ private:
std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
bool m_frozen = false;
const Type<GradValue>& m_value_type;
public:
GradKey() { m_tape.reserve(4 * 1024); }
GradKey(const Type<GradValue>& value_type) : m_value_type(value_type) {
m_tape.reserve(4 * 1024);
}
void backward();
GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
......@@ -230,10 +230,9 @@ public:
};
class GradKeyValue final
: public MixinValueImpl<
GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> {
: public PrimitiveValue<GradKeyValue, std::shared_ptr<GradKey>> {
public:
using MixinValueImpl::MixinValueImpl;
using PrimitiveValue::PrimitiveValue;
std::string to_string() const override {
return ssprintf("GradKey{%s}", (*this)->name().c_str());
......@@ -242,26 +241,20 @@ public:
class GradTransformation final : public Transformation {
private:
ObjectType<GradValue> m_value_type{"GradValue"};
std::shared_ptr<GradKey> m_key;
std::vector<GradValue::weak_ref_t> m_weak_values;
size_t m_suppressed = 0;
public:
GradTransformation(std::shared_ptr<GradKey> key) : m_key(key) {}
GradTransformation() { m_key = std::make_shared<GradKey>(m_value_type); }
auto record_grad(GradValue::ref_t tensor) {
m_weak_values.push_back(tensor);
return tensor;
}
bool is_grad_value(const ValueRef& value) {
if (auto* grad_value = value.as<GradValue>()) {
if (grad_value->has_key(m_key)) {
return true;
}
}
return false;
}
bool is_grad_value(const ValueRef& value) { return value.is(m_value_type); }
/**
* \brief test whether value is related to this GradTransformation
......@@ -273,13 +266,7 @@ public:
* \return GradValue::ref_t
*/
const GradValue::ref_t& as_grad_value(const ValueRef& value) {
auto&& grad_value = value.as_ref<GradValue>();
if (grad_value) {
if (grad_value->has_key(m_key)) {
return grad_value;
}
}
return GradValue::ref_t::nil;
return value.as_ref(m_value_type);
}
bool has_key(std::shared_ptr<GradKey> key) {
......@@ -299,6 +286,8 @@ public:
return value;
}
const std::shared_ptr<GradKey>& key() const { return m_key; }
std::string name() const override { return "GradTransformation"; }
GenericFunction make_backward_closure(Span<ValueRef> ys);
......
......@@ -22,32 +22,27 @@
namespace mgb::imperative {
class LazyEvalInfo {
class LazyEvalValue final : public ObjectValue<LazyEvalValue> {
private:
VarNode* m_node = nullptr;
ValueRef m_bound_data;
std::string m_name;
public:
LazyEvalInfo() = default;
LazyEvalInfo(VarNode* node, ValueRef bound_data, std::string name)
LazyEvalValue(VarNode* node, ValueRef bound_data, std::string name)
: m_node(node), m_bound_data(bound_data), m_name(name) {}
VarNode* node() const { return m_node; }
ValueRef bound_data() const { return m_bound_data; }
std::string name() const { return m_name; }
};
class LazyEvalValue final
: public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf(
"LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str());
}
void clear() override {}
};
/**
......@@ -67,6 +62,7 @@ private:
std::vector<LazyEvalValue::weak_ref_t> m_weak_vars;
SymbolVar m_io_link = nullptr;
std::exception_ptr m_graph_exc;
ObjectType<LazyEvalValue> m_value_type{"LazyEvalValue"};
public:
LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) {
......@@ -75,7 +71,7 @@ public:
LazyEvalValue::ref_t record_var(
VarNode* node, ValueRef bound_data = {}, std::string name = {}) {
auto lazy_eval_val = LazyEvalValue::make(node, bound_data, name);
auto lazy_eval_val = m_value_type.make(node, bound_data, name);
m_weak_vars.push_back(lazy_eval_val);
return lazy_eval_val;
}
......@@ -86,7 +82,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<LazyEvalValue>());
mgb_assert(!value.is(m_value_type));
return value;
}
......
......@@ -17,7 +17,7 @@
namespace mgb::imperative {
class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> {
class ScalarValue final : public ObjectValue<ScalarValue> {
private:
ValueRef m_value;
......@@ -47,17 +47,21 @@ public:
class ScalarTransformation final : public Transformation {
private:
ShapeValue::ref_t m_empty_shape; // []
ObjectType<ScalarValue> m_value_type{"ScalarValue"};
public:
ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs);
ValueRefList apply_transformation(
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<ScalarValue>());
mgb_assert(!value.is(m_value_type));
return value;
}
std::string name() const override { return "ScalarTransformation"; }
const Type<ScalarValue>& value_type() const { return m_value_type; }
};
} // namespace mgb::imperative
......@@ -22,7 +22,7 @@
namespace mgb::imperative {
class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> {
class SymbolValue final : public ObjectValue<SymbolValue> {
private:
VarNode* m_node = nullptr;
......@@ -47,6 +47,7 @@ public:
class SymbolTransformation final : public Transformation {
private:
ComputingGraph* m_graph = nullptr;
ObjectType<SymbolValue> m_value_type{"SymbolValue"};
public:
SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {}
......@@ -55,12 +56,12 @@ public:
if (auto* apply_op = op.as<ApplyOp>()) {
SmallVector<VarNode*> input_nodes;
for (auto&& input : inputs) {
input_nodes.push_back(input.cast<SymbolValue>().node());
input_nodes.push_back(input.cast(m_value_type).node());
}
auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes);
ValueRefList outputs(output_nodes.size());
for (size_t i = 0; i < output_nodes.size(); ++i) {
outputs[i] = SymbolValue::make(output_nodes[i]);
outputs[i] = m_value_type.make(output_nodes[i]);
}
return outputs;
} else if (auto* create_tensor = op.as<CreateTensor>()) {
......@@ -69,9 +70,9 @@ public:
args.kind == CreateTensor::Const,
"only const value is allowed here");
auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node();
return {SymbolValue::make(node)};
return {m_value_type.make(node)};
} else if (auto* get_attr = op.as<GetAttr>()) {
auto* node = inputs.as_array<1>()[0].cast<SymbolValue>().node();
auto* node = inputs.item().cast(m_value_type).node();
switch (get_attr->attr()) {
case GetAttr::DType:
return {DTypeValue::make(node->dtype())};
......@@ -121,11 +122,13 @@ public:
}
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<SymbolValue>(), "SymbolValue doesn't support unwrap");
mgb_assert(!value.is(m_value_type), "SymbolValue doesn't support unwrap");
return value;
}
std::string name() const override { return "SymbolTransformation"; }
const Type<SymbolValue>& value_type() const { return m_value_type; }
};
} // namespace mgb::imperative
......@@ -100,22 +100,15 @@ public:
}
};
class TracingInfo {
class TracingValue final : public ObjectValue<TracingValue> {
private:
ValueRef m_value = {};
size_t m_id = 0;
public:
TracingInfo() = default;
TracingInfo(ValueRef value, size_t id) : m_value(value), m_id(id) {}
TracingValue(ValueRef value, size_t id) : m_value(value), m_id(id) {}
ValueRef value() const { return m_value; }
size_t id() const { return m_id; }
};
class TracingValue final
: public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf(
......@@ -126,6 +119,8 @@ public:
void on_watch() override { value().watch(); }
void on_unwatch() override { value().unwatch(); }
void clear() override { m_value = {}; }
};
/**
......@@ -146,6 +141,7 @@ private:
std::vector<TracingValue::weak_ref_t> m_weak_vars;
bool m_capture_as_const = false;
bool m_record_input_shapes = false;
ObjectType<TracingValue> m_value_type{"TracingValue"};
public:
TracingTransformation(bool capture_as_const, bool record_input_shapes)
......@@ -162,7 +158,7 @@ public:
*/
TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) {
size_t id = m_vars.size();
auto wrapped_value = TracingValue::make(value, id);
auto wrapped_value = m_value_type.make(value, id);
m_vars.push_back({id, value.dtype(), value.device()});
auto& var = m_vars.back();
if (capture) {
......@@ -179,7 +175,7 @@ public:
return wrapped_value;
}
ValueRef unwrap_var(ValueRef value) {
if (auto* tracing_value = value.as<TracingValue>()) {
if (auto* tracing_value = value.as(m_value_type)) {
return tracing_value->value();
}
return value;
......@@ -189,7 +185,7 @@ public:
const Operator& op, Span<ValueRef> inputs) override;
ValueRef unwrap(ValueRef value) override {
if (auto* tracing_value = value.as<TracingValue>()) {
if (auto* tracing_value = value.as(m_value_type)) {
return tracing_value->value();
}
return value;
......@@ -234,7 +230,7 @@ public:
std::function<void(std::exception_ptr)> exc_setter;
};
class TracedInfo {
class TracedValue final : public ObjectValue<TracedValue> {
private:
size_t m_id = 0;
VarInfo* m_var = nullptr;
......@@ -244,8 +240,7 @@ public:
mutable CompNodeValue::ref_t m_comp_node;
public:
TracedInfo() = default;
TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor)
TracedValue(size_t id, VarInfo* var, VarAccessor* accessor)
: m_id(id), m_var(var), m_accessor(accessor) {}
size_t id() const { return m_id; }
ShapeValue::ref_t shape() const;
......@@ -256,16 +251,12 @@ public:
void set_exception(std::exception_ptr exc) const {
m_accessor->exc_setter(exc);
}
};
class TracedValue final
: public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override {
return ssprintf("TracedValue{\"id\"=%zu}", id());
}
void clear() override {}
};
private:
......@@ -280,9 +271,12 @@ private:
std::function<bool(ValueRef, ValueRef)> m_value_comparator;
bool m_input_shape_static;
std::mutex m_mutex;
std::condition_variable m_cv;
std::exception_ptr m_graph_exc;
int m_graph_status = 0; // 0 = stop, 1 = running, 2 = finalizing
std::vector<std::shared_ptr<BoxBase>> m_boxes;
ComputingGraph::OutputSpec m_output_spec;
ObjectType<TracedValue> m_value_type{"TracedValue"};
public:
CompiledTransformation(TraceResult result, bool input_shape_static)
......@@ -292,6 +286,27 @@ public:
m_graph = ComputingGraph::make();
options().no_force_inplace = true;
options().async_exec_level = 0b100;
m_graph_executor = std::thread([&] {
while (true) {
std::unique_lock lock{m_mutex};
m_cv.wait(lock, [&] { return m_graph_status != 0; });
lock.unlock();
if (m_graph_status == 2) {
break;
}
try {
m_executable->execute();
m_executable->wait();
} catch (...) {
auto exc = std::current_exception();
set_exception(exc);
}
lock.lock();
m_graph_status = 0;
lock.unlock();
m_cv.notify_all();
}
});
}
ComputingGraph& graph() { return *m_graph; }
......@@ -350,7 +365,7 @@ public:
void on_unregister() noexcept override;
ValueRef unwrap(ValueRef value) override {
mgb_assert(!value.is<TracedValue>());
mgb_assert(!value.is(m_value_type));
return value;
}
......@@ -368,6 +383,15 @@ public:
m_boxes.push_back(box);
return box;
}
~CompiledTransformation() {
{
MGB_LOCK_GUARD(m_mutex);
m_graph_status = 2;
}
m_cv.notify_all();
m_graph_executor.join();
}
};
} // namespace mgb::imperative
......@@ -11,7 +11,9 @@
#pragma once
#include <optional>
#include <typeindex>
#include <vector>
#include "megbrain/utils/mempool.h"
#include "megbrain/utils/metahelper.h"
......
......@@ -34,7 +34,7 @@ public:
Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {}
Span(const T* begin, size_t size) : Span(begin, begin + size) {}
template <typename TContainer>
Span(TContainer& container) : Span(container.data(), container.size()) {}
Span(const TContainer& container) : Span(container.data(), container.size()) {}
const T* begin() const { return m_begin; }
const T* end() const { return m_end; }
const T* data() const { return m_begin; }
......
......@@ -2,7 +2,10 @@
#include <chrono>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
namespace mgb {
......@@ -18,7 +21,7 @@ public:
private:
clock_t::duration m_duration = clock_t::duration{0};
size_t m_timing = 0;
const char* m_name = nullptr;
std::string m_name;
uint64_t m_count = 0;
size_t m_enabled = 1;
bool m_default_enabled = true;
......@@ -42,7 +45,8 @@ private:
}
if (timer.m_enabled) {
if (!--timer.m_timing) {
timer.m_duration += (clock_t::now() - start);
auto duration = (clock_t::now() - start);
timer.m_duration += duration;
}
timer.m_count++;
}
......@@ -67,13 +71,10 @@ private:
}
};
using TimeScope = TimeScopeRecursive;
public:
Timer(const char* name, bool default_enabled);
Timer(std::string name, bool default_enabled = true);
const char* name() { return m_name; }
auto time_scope() { return TimeScope(*this); }
std::string name() { return m_name; }
auto time_scope_recursive() { return TimeScopeRecursive(*this); };
auto enable_scope() { return EnableScope(*this); }
void reset() {
......@@ -88,7 +89,14 @@ public:
} // namespace stats
struct Stats {
static inline std::vector<stats::Timer*> sm_timers;
struct TimerNode {
std::map<std::string, std::unique_ptr<TimerNode>> children;
stats::Timer* timer = nullptr;
TimerNode() {}
};
static inline TimerNode sm_root;
// register your timers here
// for example:
......@@ -97,33 +105,84 @@ struct Stats {
//
// then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code
static void print() {
std::vector<const char*> unused_timers;
for (auto* timer : sm_timers) {
if (timer->count() == 0) {
unused_timers.push_back(timer->name());
} else {
printf("%s costs %ld ns, happens %ld times\n", timer->name(),
timer->get().count(), timer->count());
static std::pair<long, long> print_node(
std::string name, TimerNode& node, size_t indent = 0) {
auto print_indent = [&] {
for (size_t i = 0; i < indent; ++i) {
printf(" ");
}
};
long ns = 0, count = 0;
if (auto* timer = node.timer) {
print_indent();
printf("%s costs %'ld ns, hits %'ld times\n", name.c_str(),
(long)timer->get().count(), (long)timer->count());
ns = timer->get().count();
count = timer->count();
}
if (!node.children.empty()) {
bool collect_children = node.timer == nullptr;
if (collect_children) {
print_indent();
printf("%s:\n", name.c_str());
}
long ns = 0, count = 0;
for (auto&& child : node.children) {
auto [child_ns, child_count] =
print_node(child.first, *child.second, indent + 4);
if (collect_children) {
ns += child_ns;
count += child_count;
}
}
if (collect_children) {
print_indent();
printf("total costs %'ld ns, hits %'ld times\n", ns, count);
}
}
return {ns, count};
}
if (!unused_timers.empty()) {
printf("%zu timers unused\n", unused_timers.size());
static void print() {
for (auto&& child : sm_root.children) {
print_node(child.first, *child.second);
}
}
static void reset() {
for (auto* timer : sm_timers) {
timer->reset();
}
auto reset_node = [](TimerNode& node, auto&& reset_node) -> void {
if (auto* timer = node.timer) {
timer->reset();
}
for (auto&& child : node.children) {
reset_node(*child.second, reset_node);
}
};
reset_node(sm_root, reset_node);
}
};
inline stats::Timer::Timer(const char* name, bool default_enabled)
inline stats::Timer::Timer(std::string name, bool default_enabled)
: m_name(name), m_default_enabled(default_enabled) {
Stats::sm_timers.push_back(this);
std::vector<std::string> terms;
Stats::TimerNode* node = &Stats::sm_root;
while (true) {
auto pos = name.find(".");
if (pos == std::string::npos) {
auto& child = node->children[name];
child = std::make_unique<Stats::TimerNode>();
node = child.get();
node->timer = this;
break;
} else {
auto& child = node->children[name.substr(0, pos)];
if (!child) {
child = std::make_unique<Stats::TimerNode>();
}
node = child.get();
name = name.substr(pos + 1);
}
}
}
#if MGE_ENABLE_STATS
......
......@@ -50,18 +50,70 @@ class Operator;
class ValueRefList;
/**
* \brief base class of all value types
*/
class IType : public NonCopyableObj {
private:
std::string m_name;
// TODO: count values, or make an linkedlist
public:
IType(std::string name) : m_name(std::move(name)) {}
const std::string& name() const { return m_name; }
bool operator==(const IType& rhs) const { return this == &rhs; }
bool operator!=(const IType& rhs) const { return this != &rhs; }
};
/**
* \brief type of values.
*
* \tparam T ctype of value
*/
template <typename T>
class Type : public IType {
protected:
Type(std::string name) : IType(std::move(name)) {}
// TODO: each type owns an allocator
public:
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
TypedValueRef<T> make(TArgs&&... args) const;
};
/**
* \brief type of primitive values.
*
* \tparam T ctype of value
*/
template <typename T>
class Type {
class PrimitiveType : public Type<T> {
private:
const size_t m_code = T::TYPE_CODE;
PrimitiveType();
public:
inline size_t code() const { return m_code; }
static inline PrimitiveType instance;
};
enum class ValueKind {
Primitive,
Object,
/**
* \brief type of object values.
*
* \tparam T ctype of value
*/
template <typename T>
class ObjectType : public Type<T> {
public:
ObjectType(std::string name) : Type<T>(name) {}
};
/**
......@@ -71,9 +123,8 @@ enum class ValueKind {
* and only the tail node is valid. ValueRef stores a value node, and it may be
* an invalid internal node. When you dereference it, it will check its successor,
* automatically find the tail node and return. This list would be modified to reduce
* list length by change value's successor, but a ValueRef always has steady m_storage
* when not explicitly modified.
* So we use m_storage to identify a ValueRef ( hash / equility / id ).
* list length by change value's successor, but a steady id was kept in ValueRef
* so we can use it for identify a ValueRef ( hash / equility / id ).
*/
class ValueRef {
public:
......@@ -93,9 +144,7 @@ private:
*/
storage_t& storage() const;
const Value* as(size_t typecode) const;
bool is(size_t typecode) const;
const Value* as(const IType& type) const;
public:
ValueRef() = default;
......@@ -103,45 +152,76 @@ public:
/**
* \brief whether value is instance of target type or not
*
* \tparam TValue target type
* \return true if type of value is TValue
* \return false if empty or type of value is not TValue
* \param type target type
* \return true if type of value is instance of type
* \return false if empty or type of value is not instance of type
*/
template <typename TValue>
inline bool is(Type<TValue> type = {}) const;
bool is(const IType& type) const;
/**
* \brief try cast value as target type
*
* \tparam TValue target type
* \tparam type target type
* \return TValue* raw pointer if success, otherwise nullptr
*/
template <typename TValue>
inline const TValue* as(Type<TValue> type = {}) const;
inline const TValue* as(const Type<TValue>& type) const;
/**
* \brief cast value to target type
*
* \tparam TValue target type
* \param type target type
* \return TValue& reference of value
*/
template <typename TValue>
inline const TValue& cast(Type<TValue> type = {}) const;
inline const TValue& cast(const Type<TValue>& type) const;
/**
* \brief like as(), but returns TypedValueRef instead
*
* \tparam TValue target type
* \param type target type
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template <typename TValue>
inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const;
inline const TypedValueRef<TValue>& as_ref(const Type<TValue>& type) const;
/**
* \brief like cast(), but allow empty value and returns TypedValueRef instead
*
* \param type target type
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template <typename TValue>
inline const TypedValueRef<TValue>& cast_ref(const Type<TValue>& type) const;
template <typename TValue>
inline std::enable_if_t<TValue::is_primitive, bool> is() const {
return is(PrimitiveType<TValue>::instance);
}
template <typename TValue>
inline std::enable_if_t<TValue::is_primitive, const TValue*> as() const {
return as(PrimitiveType<TValue>::instance);
}
template <typename TValue>
inline std::enable_if_t<TValue::is_primitive, const TValue&> cast() const {
return cast(PrimitiveType<TValue>::instance);
}
template <typename TValue>
inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const;
inline std::enable_if_t<TValue::is_primitive, const TypedValueRef<TValue>&> as_ref()
const {
return as_ref(PrimitiveType<TValue>::instance);
}
template <typename TValue>
void on_cast_failure() const;
inline std::enable_if_t<TValue::is_primitive, const TypedValueRef<TValue>&>
cast_ref() const {
return cast_ref(PrimitiveType<TValue>::instance);
}
void on_cast_failure(const IType& type) const;
operator bool() const { return bool(m_storage); }
......@@ -172,8 +252,6 @@ public:
friend class ValueWeakRef;
template <typename>
friend class TypedValueRef;
template <typename, ValueKind>
friend class ValueImpl;
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
};
......@@ -195,7 +273,8 @@ protected:
public:
ValueWeakRef() = default;
ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {}
ValueWeakRef(const ValueRef& value)
: m_id(value.id()), m_storage(value.m_storage) {}
/**
* \brief try promote to ValueRef
......@@ -218,19 +297,15 @@ public:
class Value : public NonCopyableObj {
private:
uint64_t m_id = std::numeric_limits<uint64_t>::max();
size_t m_typecode = 0;
const IType* m_type = nullptr;
ValueRef m_successor;
size_t m_watching = 0;
protected:
Value(size_t typecode);
Value();
public:
size_t typecode() const { return m_typecode; }
const std::type_index type() const { return registered_types()[m_typecode]; }
static size_t register_type(std::type_index type);
static const std::vector<std::type_index>& registered_types();
const IType& type() const { return *m_type; }
static void register_value(ValueRef value);
static ValueRef get_value_by_id(uint64_t id);
......@@ -251,11 +326,12 @@ public:
friend class ValueRef;
friend class ValueWeakRef;
template <typename, ValueKind>
friend class ValueImpl;
template <typename T>
friend class TypedValueRef;
template <typename T>
friend class Type;
~Value();
private:
......@@ -267,30 +343,17 @@ private:
*
* \tparam T type of value
*/
template <typename T, ValueKind Kind>
class ValueImpl : public Value {
template <typename T>
class ObjectValue : public Value {
protected:
ValueImpl() : Value(TYPE_CODE) {}
ObjectValue() {}
public:
using ref_t = TypedValueRef<T>;
using weak_ref_t = TypedValueWeakRef<T>;
static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
static constexpr ValueKind KIND = Kind;
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) {
static_assert(std::is_final_v<T>);
return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...));
}
static constexpr bool is_primitive = false;
static constexpr bool is_object = true;
};
/**
......@@ -299,74 +362,89 @@ public:
* \tparam T type of value
* \tparam TMixin type of mixin class
*/
template <typename T, ValueKind Kind, typename TMixin>
class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin {
template <typename T, typename TMixin>
class PrimitiveValue : public Value, public TMixin {
public:
using ref_t = TypedValueRef<T>;
using weak_ref_t = TypedValueWeakRef<T>;
using TMixin::TMixin;
MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {}
PrimitiveValue(TMixin&& mixin) : TMixin(std::move(mixin)) {}
PrimitiveValue(const TMixin& mixin) : TMixin(mixin) {}
public:
void clear() override final { ((TMixin&)*this) = {}; }
bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; }
/**
* \brief helper function for construct a value
*
* \tparam TArgs types of arguments
* \param args arguments
* \return TypedValueRef<T> reference of value
*/
template <typename... TArgs>
static TypedValueRef<T> make(TArgs&&... args) {
return PrimitiveType<T>::instance.make(std::forward<TArgs&&>(args)...);
}
static constexpr bool is_primitive = true;
static constexpr bool is_object = false;
};
template <typename T>
PrimitiveType<T>::PrimitiveType() : Type<T>(typeid(T).name()) {
static_assert(std::is_base_of_v<Value, T>);
static_assert(!std::is_base_of_v<ObjectValue<T>, T>);
}
inline ValueRef::ValueRef(storage_t storage) {
// mgb_assert(storage);
m_storage = storage;
m_id = m_storage->m_id;
}
template <typename TValue>
inline const TValue* ValueRef::as(Type<TValue> type) const {
// auto _ = Stats::time_value_as.time_scope();
inline const TValue* ValueRef::as(const Type<TValue>& type) const {
static_assert(std::is_base_of_v<Value, TValue>);
return static_cast<const TValue*>(as(type.code()));
return static_cast<const TValue*>(as((const IType&)type));
}
template <typename TValue>
inline const TValue& ValueRef::cast(Type<TValue> type) const {
// auto _ = Stats::time_value_cast.time_scope();
inline const TValue& ValueRef::cast(const Type<TValue>& type) const {
auto* ptr = as<TValue>(type);
if (mgb_unlikely(!ptr)) {
on_cast_failure<TValue>();
on_cast_failure(type);
}
return static_cast<const TValue&>(*ptr);
}
template <typename TValue>
inline bool ValueRef::is(Type<TValue> type) const {
// auto _ = Stats::time_value_is.time_scope();
return is(type.code());
}
template <typename TValue>
inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const {
if (!is<TValue>(type)) {
inline const TypedValueRef<TValue>& ValueRef::as_ref(const Type<TValue>& type) const {
if (!is(type)) {
return TypedValueRef<TValue>::nil;
}
return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
}
template <typename TValue>
inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const {
inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type) const {
if (!m_storage) {
return TypedValueRef<TValue>::nil;
}
if (mgb_unlikely(!is<TValue>(type))) {
on_cast_failure<TValue>();
if (mgb_unlikely(!is(type))) {
on_cast_failure(type);
}
return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
}
template <typename TValue>
void ValueRef::on_cast_failure() const {
inline void ValueRef::on_cast_failure(const IType& type) const {
// if this is ErrorValue, rethrow directly
storage()->try_rethrow();
mgb_assert(
storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s",
typeid(TValue).name(), to_string().c_str());
storage()->type() != type, "expect type %s, got %s", type.name().c_str(),
to_string().c_str());
}
/**
......@@ -382,26 +460,10 @@ private:
public:
TypedValueRef() = default;
const T& operator*() const {
if constexpr (T::KIND == ValueKind::Object) {
return this->template cast<T>();
} else if constexpr (T::KIND == ValueKind::Primitive) {
if (!m_storage) {
on_cast_failure<T>();
}
return static_cast<const T&>(*m_storage);
} else {
static_assert(!std::is_same_v<T, T>);
}
}
const T* operator->() const {
if constexpr (T::KIND == ValueKind::Object) {
return this->template as<T>();
} else if constexpr (T::KIND == ValueKind::Primitive) {
return static_cast<const T*>(m_storage.get());
} else {
static_assert(!std::is_same_v<T, T>);
}
mgb_assert(m_storage, "empty storage");
return static_cast<const T&>(*m_storage);
}
const T* operator->() const { return static_cast<const T*>(m_storage.get()); }
/**
* \brief reset underlying value to another value
......@@ -409,7 +471,7 @@ public:
* \param successor new value
*/
inline void reset(ValueRef successor) {
static_assert(T::KIND == ValueKind::Object);
static_assert(std::is_base_of_v<ObjectValue<T>, T>);
mgb_assert(m_storage);
mgb_assert(!m_storage->m_successor);
if (m_storage->m_watching) {
......@@ -422,25 +484,19 @@ public:
static inline const TypedValueRef nil;
friend class ValueRef;
template <typename, ValueKind>
friend class ValueImpl;
friend class Type<T>;
friend class TypedValueWeakRef<T>;
};
template <typename T>
class TypedValueWeakRef : public ValueWeakRef {
private:
TypedValueWeakRef(const ValueRef& value) : ValueWeakRef(value) {}
TypedValueWeakRef(const ValueWeakRef& value) : ValueWeakRef(value) {}
public:
TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {}
TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {}
TypedValueRef<T> lock() {
auto value = ValueWeakRef::lock();
if (value) {
return value.template as_ref<T>();
} else {
return {};
}
}
TypedValueWeakRef(const TypedValueRef<T>& value) : ValueWeakRef(value) {}
TypedValueRef<T> lock() { return (TypedValueRef<T>)ValueWeakRef::lock(); }
};
// TODO: add proxy value type, which is meant to be reset in the end
......@@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s
m_data[0] = std::move(item);
}
/*class ValueRefList : public SmallVector<ValueRef, 1> {
public:
using SmallVector::SmallVector;
};*/
template <typename T>
template <typename... TArgs>
TypedValueRef<T> Type<T>::make(TArgs&&... args) const {
static_assert(std::is_final_v<T>);
auto storage = LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...);
storage->m_type = this;
return ValueRef::make(std::move(storage));
}
} // namespace imperative
} // namespace mgb
......
......@@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
inputs.clear();
auto input_grads = result.graph.apply(
auto input_grads = result.graph.apply<TensorPtr>(
backward_graph_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size());
......@@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
inputs.clear();
auto input_grads = result.graph.apply(
auto input_grads = result.graph.apply<TensorPtr>(
backward_graph_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; });
mgb_assert(input_grads.size() == input_has_grad.size());
......@@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads = expand_grads(
bg.output_mask,
bg.graph.apply(
bg.graph.apply<TensorPtr>(
backward_graph_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }));
auto precomp = obg.precomp.apply(
auto precomp = obg.precomp.apply<TensorPtr>(
SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; });
ASSERT_EQ(precomp.size(), 2);
......@@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(
obg.input_has_grad,
obg.backward.apply(
obg.backward.apply<TensorPtr>(
backward_inputs, apply_shared_on_physical_tensor,
[&](auto&& x) { return x; }));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册