Commit 4fa61620 authored by Megvii Engine Team

perf(dispatch): improve performance of dispatch system

GitOrigin-RevId: 860028e1af63936e7b4edefbed90d8244e7cb8d2
Parent ca001777
......@@ -13,6 +13,7 @@
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "./tensor.h"
......
......@@ -21,6 +21,7 @@
#include "megbrain/imperative/transformations/symbol.h"
#include "megbrain/imperative/transformations/trace.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/opr/io.h"
#include "megbrain/plugin/profiler.h"
......@@ -52,8 +53,48 @@ namespace mgb::imperative::python {
namespace {
WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map;
struct SymbolVarContext {
TransformationContext context;
cg::ComputingGraph* graph;
SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) {
Transformation::swap_context(context);
}
void init() {
std::make_shared<SymbolTransformation>(graph)->register_at(
Transformation::top());
std::make_shared<ScalarTransformation>()->register_at(Transformation::top());
}
~SymbolVarContext() { Transformation::swap_context(context); }
};
ValueRef symvar2val(py::handle py_symbol_var) {
auto* symbol_var = py_symbol_var.cast<PySymbolVar*>();
ValueRef value = SymbolValue::make(symbol_var->m_node);
if (symbol_var->is_scalar) {
value = ScalarValue::make(value);
}
return value;
}
py::object val2symvar(py::handle typeobj, ValueRef value) {
bool is_scalar = false;
if (auto* scalar_value = value.as<ScalarValue>()) {
value = scalar_value->value();
is_scalar = true;
}
auto* node = value.cast<SymbolValue>().node();
auto py_symbol_var =
typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic));
py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar;
return py_symbol_var;
}
} // namespace
interpreter::Interpreter::Channel* interpreter_for_py = nullptr;
PyTypeObject* py_tensor_type = nullptr;
PyObject *cpp_use_symbolic_shape, *cpp_astensor1d;
......@@ -91,36 +132,17 @@ PyObject* py_apply(
if (py::isinstance<PySymbolVar>(py::handle(args[0]))) {
// swap to a special context to reuse scalar handle
TransformationContext symbol_var_context;
Transformation::swap_context(symbol_var_context);
CleanupGuard _{[&] { Transformation::swap_context(symbol_var_context); }};
auto* graph =
py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph();
std::make_shared<SymbolTransformation>(graph)->register_at(
Transformation::top());
std::make_shared<ScalarTransformation>()->register_at(
Transformation::top());
SymbolVarContext context(
py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph());
context.init();
for (size_t i = 0; i < nargs; ++i) {
auto* py_input = py::handle(args[i]).cast<PySymbolVar*>();
ValueRef input = SymbolValue::make(py_input->m_node);
if (py_input->is_scalar) {
input = ScalarValue::make(input);
}
tensors[i] = input;
tensors[i] = symvar2val(args[i]);
}
auto outputs = imperative::apply(*op, tensors);
auto ret = pybind11::tuple(outputs.size());
auto typeobj = py::handle(args[0]).get_type();
for (size_t i = 0; i < outputs.size(); ++i) {
bool is_scalar = false;
if (auto* scalar_value = outputs[i].as<ScalarValue>()) {
outputs[i] = scalar_value->value();
is_scalar = true;
}
auto* node = outputs[i].cast<SymbolValue>().node();
ret[i] = typeobj(
pybind11::cast(node, pybind11::return_value_policy::automatic));
py::handle(ret[i]).cast<PySymbolVar*>()->is_scalar = is_scalar;
ret[i] = val2symvar(typeobj, outputs[i]);
}
return ret.release().ptr();
}
......@@ -1537,17 +1559,29 @@ void init_tensor(py::module m) {
}
});
m.def("reduce_to_scalar", [](py::object op, py::object tensor) {
auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto make_scalar_shape = [&](CompNode device) {
return imperative::apply(
CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
HostStorage::make(device))[0];
m.def("reduce_to_scalar", [](py::object op, py::object tensor) -> py::object {
auto reduce_to_scalar = [](const OpDef& op, const ValueRef& input) {
auto make_scalar_shape = [&](CompNode device) {
return imperative::apply(
CreateTensor(CreateTensor::Const, device, dtype::Int32(), {0}),
HostStorage::make(device))[0];
};
return imperative::apply(op, input, make_scalar_shape(*input.device()))[0];
};
auto output = imperative::apply(
*op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data(),
make_scalar_shape(tw->m_tensor->comp_node()))[0];
return TensorWrapper::make(py_tensor_type, output);
if (py::isinstance<PySymbolVar>(tensor)) {
auto* graph = tensor.cast<PySymbolVar*>()->m_node->owner_graph();
SymbolVarContext context(graph);
context.init();
auto output = reduce_to_scalar(
*op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor));
auto typeobj = tensor.get_type();
return val2symvar(typeobj, output);
} else {
auto* tw = TensorWrapper::try_cast(tensor.ptr());
auto output = reduce_to_scalar(
*op.cast<std::shared_ptr<OpDef>>(), tw->m_tensor->data());
return TensorWrapper::make(py_tensor_type, output);
}
});
m.def("name_tensor", [](std::string name, py::object tensor) {
......@@ -1557,7 +1591,7 @@ void init_tensor(py::module m) {
});
m.def("is_grad_attached", [](std::vector<py::object> tensors) -> bool {
ValueRefList values(tensors.size());
SmallVector<ValueRef> values(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
}
......@@ -1570,17 +1604,16 @@ void init_tensor(py::module m) {
});
m.def("get_grad_key", [](std::vector<py::object> tensors) -> py::object {
ValueRefList values(tensors.size());
SmallVector<ValueRef> values(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
values[i] = tensors[i].cast<TensorWrapper>().m_tensor->data();
}
auto outputs = imperative::apply(GetGradKey(), values);
if (auto* grad_key_val = outputs[0].as<GradKeyValue>()) {
return py::reinterpret_borrow<py::object>(
GradKeyWrapper::wrap_t::pycast(GradKeyWrapper::get(*grad_key_val)));
} else {
auto output = imperative::apply(GetGradKey(), values)[0];
if (!output) {
return py::none();
}
return py::reinterpret_borrow<py::object>(GradKeyWrapper::wrap_t::pycast(
GradKeyWrapper::get(output.cast<GradKeyValue>())));
});
m.def("set_grad", [](py::object py_key, py::function backward_fn,
......@@ -1612,7 +1645,7 @@ void init_tensor(py::module m) {
}
return input_grads;
};
ValueRefList values(inputs.size() + outputs.size());
SmallVector<ValueRef> values(inputs.size() + outputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
values[i] = inputs[i].cast<TensorWrapper>().m_tensor->data();
}
......@@ -1669,6 +1702,10 @@ void init_tensor(py::module m) {
return reprs;
});
m.def("print_stats", [] { imperative::Stats::print(); });
m.def("reset_stats", [] { imperative::Stats::reset(); });
py::register_exception<TraceError>(m, "TraceError");
}
......
......@@ -67,7 +67,8 @@ struct TransformationManager {
}
};
class PyValue final : public MixinValueImpl<PyValue, pybind11::object> {
class PyValue final
: public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> {
public:
using MixinValueImpl::MixinValueImpl;
......
......@@ -14,13 +14,9 @@
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/helper.h"
#include "megbrain/imperative/utils/map.h"
#include "megbrain/imperative/utils/stats.h"
namespace mgb {
void imperative_log_profile_begin(const char* message);
void imperative_log_profile(const char* message);
void imperative_log_profile_end(const char* message);
namespace imperative {
namespace {
......
......@@ -19,6 +19,7 @@
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"
#include "megbrain/imperative/utils/to_string.h"
#include "../blob_manager_impl.h"
......
#include "megbrain/imperative/transformation.h"
#include "megbrain/imperative/utils/stats.h"
namespace mgb {
namespace imperative {
......
......@@ -11,6 +11,7 @@
#include "megbrain/imperative/transformations/eval.h"
#include "megbrain/imperative/transformations/grad.h"
#include "megbrain/imperative/utils/stats.h"
namespace mgb {
namespace imperative {
......@@ -40,9 +41,6 @@ ShapeValue::ref_t InterpreterInfo::shape() const {
ValueRefList InterpreterTransformation::apply_op(
const ApplyOp& apply_op, Span<ValueRef> inputs) {
if (apply_op.op().same_type<FastpathCopy>()) {
return {inputs[0]};
}
SmallVector<Handle> input_handles;
SmallVector<Handle> output_handles;
CleanupGuard _{[&] {
......@@ -111,7 +109,11 @@ ValueRefList InterpreterTransformation::apply_create_tensor(
ValueRefList InterpreterTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
if (auto* op_val = op.as<ApplyOp>()) {
return apply_op(*op_val, inputs);
if (op_val->op().same_type<FastpathCopy>()) {
return inputs[0];
} else {
return apply_op(*op_val, inputs);
}
} else if (auto* get_attr = op.as<GetAttr>()) {
return apply_get_attr(*get_attr, inputs);
} else if (auto* create_tensor = op.as<CreateTensor>()) {
......
......@@ -11,8 +11,11 @@
#include "megbrain/imperative/transformations/grad.h"
#include <variant>
#include "megbrain/imperative/graph_cache.h"
#include "megbrain/imperative/resource_manager.h"
#include "megbrain/imperative/utils/stats.h"
#include <range/v3/all.hpp>
......@@ -20,20 +23,21 @@ namespace mgb {
namespace imperative {
static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
const OpDef& op, Span<ValueRef> inputs, Span<ValueRef> outputs,
Span<bool> inputs_require_grad) {
// hash
using OptimizedBackwardGraphCache = OpMethResultCache<
std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
thread_local auto& cache =
*ResourceManager::create_local<OptimizedBackwardGraphCache>();
OptimizedBackwardGraphCache::key_t cache_key{op};
OptimizedBackwardGraphCache::key_t cache_key{op.shared_from_this()};
SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
cache_key.extra<0>() = inputs_require_grad.copy_into<SmallVector<bool>>();
input_descs.resize(inputs.size());
// some overhead, consider simplifying LogicalTensorDesc
for (size_t i = 0; i < inputs.size(); ++i) {
input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
input_descs[i].layout.dtype = *inputs[i].dtype();
input_descs[i].comp_node = *inputs[i].device();
}
auto iter = cache.find(cache_key);
......@@ -45,7 +49,7 @@ static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_gra
SmallVector<bool> output_has_grad(outputs.size(), true);
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = OpDef::make_backward_graph(
*op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
if (!bg.graph.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
......@@ -235,7 +239,7 @@ GradValue::ref_t GradKey::attach(
} else {
GradSlotPtr grad_slot;
auto& grad_fn = grad_slot.m_fn;
grad_fn = std::make_shared<GradFn>();
grad_fn = LocalPtr<GradFn>::make();
grad_fn->m_key = shared_from_this();
grad_fn->m_slots.resize(1);
grad_slot.m_index = 0;
......@@ -260,17 +264,21 @@ ValueRefList GradTransformation::apply_transformation(
const Operator& op, Span<ValueRef> inputs) {
auto fallback = [&] {
ValueRefList unwrapped_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto grad_value = as_grad_value(inputs[i])) {
unwrapped_inputs[i] = grad_value->m_value;
} else {
unwrapped_inputs[i] = inputs[i];
{
// overhead
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto&& grad_value = as_grad_value(inputs[i])) {
unwrapped_inputs[i] = grad_value->m_value;
} else {
unwrapped_inputs[i] = inputs[i];
}
}
}
return imperative::apply(op, unwrapped_inputs);
};
if (auto* get_attr = op.as<GetAttr>()) {
if (auto grad_value = as_grad_value(inputs.item())) {
if (op.is<GetAttr>()) {
// overhead
if (auto&& grad_value = as_grad_value(inputs.item())) {
return imperative::apply(op, grad_value->m_value);
} else {
return imperative::apply(op, inputs);
......@@ -281,28 +289,29 @@ ValueRefList GradTransformation::apply_transformation(
}
if (auto* op_val = op.as<ApplyOp>()) {
size_t nr_require_grad = 0;
SmallVector<bool> require_grads;
for (auto&& input : inputs) {
if (is_grad_value(input)) {
SmallVector<bool> require_grads(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
if (is_grad_value(inputs[i])) {
nr_require_grad++;
require_grads.push_back(true);
require_grads[i] = true;
} else {
require_grads.push_back(false);
require_grads[i] = false;
}
}
if (nr_require_grad == 0) {
return imperative::apply(op, inputs);
}
ValueRefList captured_inputs(inputs.size());
SmallVector<ValueRef> captured_inputs(inputs.size());
SmallVector<bool> inputs_require_grad(inputs.size());
// capture the value so that trace can assume the input stays the same
auto capture_value = [](ValueRef value) {
auto capture_value = [](const ValueRef& value) {
// TODO: fastpath copy shouldn't be an OpDef
return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
static auto fastpath_copy = FastpathCopy::make();
return imperative::apply(ApplyOp(*fastpath_copy), value)[0];
};
for (size_t i = 0; i < inputs.size(); ++i) {
auto& input = inputs[i];
if (auto grad_value = as_grad_value(input)) {
if (auto&& grad_value = as_grad_value(input)) {
captured_inputs[i] = capture_value(grad_value->m_value);
inputs_require_grad[i] = true;
} else {
......@@ -310,32 +319,28 @@ ValueRefList GradTransformation::apply_transformation(
inputs_require_grad[i] = false;
}
}
decltype(std::declval<GradFn>().m_backward) backward_storage;
// copying grad_fn->m_backward is expensive
auto grad_fn = LocalPtr<GradFn>::make();
auto& backward_storage = grad_fn->m_backward;
auto outputs = [&] {
auto backward_rule =
CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
if (backward_rule) {
CustomBackward backward;
auto optional_outputs = backward_rule(
op_val->op(), {captured_inputs.data(), captured_inputs.size()},
{inputs_require_grad.data(), inputs_require_grad.size()},
backward);
op_val->op(), captured_inputs, inputs_require_grad, backward);
if (optional_outputs) {
backward_storage = backward;
// backward by rule
return *optional_outputs;
}
}
auto outputs = imperative::apply(
op, {captured_inputs.begin(), captured_inputs.end()});
auto outputs = imperative::apply(op, captured_inputs);
auto backward_graph = make_optimized_backward_graph(
op.cast<ApplyOp>().op().shared_from_this(),
{captured_inputs.begin(), captured_inputs.end()},
{outputs.data(), outputs.size()},
{inputs_require_grad.data(), inputs_require_grad.size()});
op_val->op(), captured_inputs, outputs, inputs_require_grad);
if (backward_graph) {
backward_storage = BackwardGraphWithClosure(
backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
backward_graph, op_val->op().shared_from_this(),
{captured_inputs.begin(), captured_inputs.end()},
{outputs.data(), outputs.size()});
// backward by make_backward_graph
......@@ -348,18 +353,17 @@ ValueRefList GradTransformation::apply_transformation(
if (std::holds_alternative<std::monostate>(backward_storage)) {
return outputs;
}
auto grad_fn = std::make_shared<GradFn>();
grad_fn->m_key = m_key;
grad_fn->m_slots.resize(outputs.size());
grad_fn->m_backward = backward_storage;
mgb_assert(!outputs.empty());
grad_fn->m_dests.reserve(inputs.size());
// clang-format off
std::visit([&](auto& backward) {
auto visitor = [&](auto& backward) {
using T = std::decay_t<decltype(backward)>;
if constexpr (std::is_same_v<T, std::monostate>) {
mgb_throw(AssertionError, "invalid backward");
} else {
// little overhead
for (size_t i = 0; i < inputs.size(); ++i) {
if (backward.input_has_grad(i) && require_grads[i]) {
auto& input_grad_slot =
......@@ -373,19 +377,23 @@ ValueRefList GradTransformation::apply_transformation(
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (backward.output_requires_grad(i)) {
// little overhead: Value::make
auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
outputs[i] = record_grad(grad_value);
}
}
}
}, grad_fn->m_backward);
};
// std::visit may be slightly slower than a direct if-else
std::visit(visitor, backward_storage);
// clang-format on
mgb_assert(!grad_fn->m_slots.empty());
m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
return outputs;
} else if (op.is<CreateTensor>()) {
return imperative::apply(op, inputs);
} else if (auto* attach_grad = op.as<AttachGrad>()) {
}
if (auto* attach_grad = op.as<AttachGrad>()) {
if (!has_key(attach_grad->key())) {
return fallback();
}
......@@ -408,7 +416,7 @@ ValueRefList GradTransformation::apply_transformation(
return {};
} else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
if (has_key(is_attached_to->key())) {
if (auto grad_value = as_grad_value(inputs[0])) {
if (auto&& grad_value = as_grad_value(inputs[0])) {
// TODO: assert grad_fn
return {BoolValue::make(true)};
}
......@@ -416,7 +424,7 @@ ValueRefList GradTransformation::apply_transformation(
return {BoolValue::make(false)};
} else if (auto* set_grad = op.as<SetGrad>()) {
// TODO: merge SetGrad and ApplyOp
auto grad_fn = std::make_shared<GradFn>();
auto grad_fn = LocalPtr<GradFn>::make();
auto& backward =
std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
size_t nr_inputs = set_grad->nr_inputs();
......@@ -433,7 +441,7 @@ ValueRefList GradTransformation::apply_transformation(
grad_fn->m_slots.resize(nr_outputs);
grad_fn->m_dests.reserve(nr_inputs);
for (size_t i = 0; i < nr_inputs; ++i) {
if (auto grad_value = as_grad_value(inputs_[i])) {
if (auto&& grad_value = as_grad_value(inputs_[i])) {
auto& input_grad_slot = grad_value->m_slot;
grad_fn->m_dests.emplace_back(grad_value->m_slot);
grad_fn->m_dests.back().m_producer_record.insert_after(
......@@ -461,21 +469,21 @@ ValueRefList GradTransformation::apply_transformation(
}
return {FunctionValue::make(make_backward_closure(inputs))};
} else if (op.is<DetachGrad>()) {
if (auto grad_value = as_grad_value(inputs[0])) {
if (auto&& grad_value = as_grad_value(inputs[0])) {
return {grad_value->m_value};
} else {
return {inputs[0]};
}
} else if (op.is<GetGradKey>()) {
for (auto&& input : inputs) {
if (auto grad_value = as_grad_value(input)) {
if (auto&& grad_value = as_grad_value(input)) {
return {GradKeyValue::make(grad_value->m_key)};
}
}
return imperative::apply(op, inputs);
} else if (op.kind() == Operator::IdentityLike) {
mgb_assert(inputs.size() == 1);
if (auto grad_value = as_grad_value(inputs[0])) {
if (auto&& grad_value = as_grad_value(inputs[0])) {
auto output = imperative::apply(op, grad_value->m_value)[0];
auto grad_output = GradValue::make(
output, grad_value->key(), grad_value->slot_for(m_key));
......@@ -493,7 +501,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
auto grad_key = m_key;
std::vector<GradSlotPtr> y_slots;
for (auto&& y : ys) {
if (auto grad_value = as_grad_value(y)) {
if (auto&& grad_value = as_grad_value(y)) {
y_slots.push_back(grad_value->slot_for(grad_key));
} else {
y_slots.emplace_back();
......
......@@ -13,6 +13,7 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/utils/stats.h"
namespace mgb {
namespace imperative {
......@@ -185,7 +186,7 @@ ValueRefList subtensor_rule(
bool is_scalar;
mgb_assert(!inputs_mask[0], "subtensor shouldn't have scalar input");
if (auto shape = input.shape()) {
size_t ndim = input.shape()->ndim;
size_t ndim = shape->ndim;
for (auto&& [axis, begin, end, step, idx] : subtensor.items) {
if (idx) {
ndim--;
......@@ -193,6 +194,7 @@ ValueRefList subtensor_rule(
}
is_scalar = ndim == 0;
} else {
// assume not scalar
is_scalar = false;
}
auto outputs = imperative::apply(subtensor, inputs);
......@@ -341,12 +343,16 @@ ValueRefList ScalarTransformation::apply_transformation(
if (auto* get_attr = op.as<GetAttr>()) {
// fastpath for GetAttr
return apply_get_attr(*get_attr, inputs);
} else if (auto* apply_op = op.as<ApplyOp>()) {
if (apply_op->op().same_type<FastpathCopy>()) {
return inputs[0];
}
}
size_t nr_inputs = inputs.size();
ValueRefList unwrapped_inputs(nr_inputs);
bool inputs_mask[nr_inputs];
SmallVector<bool> inputs_mask(nr_inputs);
for (size_t i = 0; i < inputs.size(); ++i) {
if (auto scalar_value = inputs[i].as_ref<ScalarValue>()) {
if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) {
unwrapped_inputs[i] = scalar_value->value();
inputs_mask[i] = true;
} else {
......@@ -358,8 +364,7 @@ ValueRefList ScalarTransformation::apply_transformation(
if (auto apply_op = op.as<ApplyOp>()) {
auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo());
if (iter != scalar_rules.end()) {
return iter->second(
apply_op->op(), unwrapped_inputs, {inputs_mask, nr_inputs});
return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask);
} else {
// TODO: repeat op
return fallback();
......
......@@ -215,8 +215,8 @@ ValueRefList::ValueRefList(size_t nr_elems) {
init(nr_elems);
}
ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
: ValueRefList(values.begin(), values.end()) {}
/*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values)
: ValueRefList(values.begin(), values.end()) {}*/
ValueRefList::ValueRefList(const ValueRefList& rhs)
: ValueRefList(rhs.cbegin(), rhs.cend()) {}
......
......@@ -25,14 +25,16 @@ class GradKey;
using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>;
class ShapeValue final : public MixinValueImpl<ShapeValue, ValueShape> {
class ShapeValue final
: public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override;
};
class CompNodeValue final : public MixinValueImpl<CompNodeValue, CompNode> {
class CompNodeValue final
: public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> {
public:
using MixinValueImpl::MixinValueImpl;
......@@ -40,7 +42,7 @@ public:
};
// TODO: override factory method
class BoolValue final : public ValueImpl<BoolValue> {
class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> {
private:
std::optional<bool> m_value;
......@@ -53,14 +55,17 @@ public:
void clear() override { m_value.reset(); }
};
class HostStorage final : public MixinValueImpl<HostStorage, HostTensorStorage> {
class HostStorage final
: public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override;
};
class DeviceStorage final : public MixinValueImpl<DeviceStorage, DeviceTensorStorage> {
class DeviceStorage final
: public MixinValueImpl<
DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> {
public:
using MixinValueImpl::MixinValueImpl;
......@@ -71,7 +76,7 @@ public:
* \brief like the HostTensorND mixin, but allows scalar values
*
*/
class HostValue final : public ValueImpl<HostValue> {
class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> {
private:
DType m_dtype;
ValueShape m_shape;
......@@ -94,9 +99,9 @@ public:
}
DType dtype() const { return m_dtype; }
ValueShape shape() const { return m_shape; }
const ValueShape& shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); }
HostTensorStorage storage() const { return m_storage; }
const HostTensorStorage& storage() const { return m_storage; }
DTypeScalar item() const {
mgb_assert(m_shape.is_scalar());
return DTypeScalar::make_from_raw(m_dtype, m_storage.ptr());
......@@ -109,7 +114,7 @@ public:
* \brief like the DeviceTensorND mixin, but allows scalar values
*
*/
class DeviceValue final : public ValueImpl<DeviceValue> {
class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> {
private:
DType m_dtype;
ValueShape m_shape;
......@@ -117,8 +122,8 @@ private:
public:
DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage)
: m_dtype(dtype), m_shape(shape), m_storage(storage) {}
DeviceValue(DeviceTensorND value)
: m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {}
DeviceValue(const DeviceTensorND& value)
: DeviceValue(
value.dtype(), ValueShape::from(value.shape()), value.storage()) {
}
......@@ -132,28 +137,31 @@ public:
}
DType dtype() const { return m_dtype; }
ValueShape shape() const { return m_shape; }
const ValueShape& shape() const { return m_shape; }
CompNode device() const { return m_storage.comp_node(); }
DeviceTensorStorage storage() const { return m_storage; }
const DeviceTensorStorage& storage() const { return m_storage; }
DeviceTensorND as_nd(bool allow_scalar = false) const;
};
class FunctionValue final : public MixinValueImpl<FunctionValue, GenericFunction> {
class FunctionValue final
: public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override;
};
class DTypeValue final : public MixinValueImpl<DTypeValue, DType> {
class DTypeValue final
: public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> {
public:
using MixinValueImpl::MixinValueImpl;
std::string to_string() const override;
};
class StringValue final : public MixinValueImpl<StringValue, std::string> {
class StringValue final
: public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> {
public:
using MixinValueImpl::MixinValueImpl;
......@@ -171,7 +179,8 @@ public:
std::string message() const { return m_message; }
};
class ErrorValue final : public MixinValueImpl<ErrorValue, Error> {
class ErrorValue final
: public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> {
public:
using MixinValueImpl::MixinValueImpl;
......
......@@ -47,9 +47,14 @@ constexpr bool is_all_value_ref_v =
(... && (std::is_base_of_v<ValueRef, std::decay_t<TArgs>> ||
std::is_same_v<ValueRef, std::decay_t<TArgs>>));
template <typename T>
static ValueRefList apply(T&& op, const ValueRef& arg) {
return imperative::apply(std::forward<T&&>(op), Span<ValueRef>{&arg, 1});
}
template <typename T, typename... TArgs>
static auto apply(T&& op, TArgs&&... args)
-> std::enable_if_t<is_all_value_ref_v<TArgs...>, ValueRefList> {
static auto apply(T&& op, TArgs&&... args) -> std::enable_if_t<
is_all_value_ref_v<TArgs...> && sizeof...(args) != 1, ValueRefList> {
ValueRef args_arr[sizeof...(TArgs)] = {std::forward<TArgs&&>(args)...};
return imperative::apply(
std::forward<T&&>(op),
......
......@@ -54,6 +54,11 @@ struct OpMethArgs {
return extras == rhs.extras;
}
template <size_t i>
auto& extra() {
return std::get<i>(extras);
}
struct hash_t {
size_t operator()(const OpMethArgs& key) const { return key.hash(); }
};
......
......@@ -60,7 +60,7 @@ public:
};
class InterpreterValue final
: public MixinValueImpl<InterpreterValue, InterpreterInfo> {
: public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> {
public:
using MixinValueImpl::MixinValueImpl;
......
......@@ -104,37 +104,15 @@ struct ToStringTrait<GradSlot> {
std::string operator()(const GradSlot& value) const { return value.to_string(); }
};
class GradFn {
private:
std::weak_ptr<GradKey> m_key;
std::vector<GradSlot> m_slots;
std::vector<GradSlotProducerPtr> m_dests;
std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
public:
void clear() {
m_key.reset();
m_slots.clear();
m_dests.clear();
m_backward.emplace<std::monostate>();
}
std::string to_string() const;
friend class GradSlotPtr;
friend class GradKey;
friend class GradTransformation;
};
class GradSlotPtr {
private:
std::shared_ptr<GradFn> m_fn;
LocalPtr<GradFn> m_fn;
size_t m_index = 0;
public:
GradSlotPtr(std::shared_ptr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
GradSlotPtr(LocalPtr<GradFn> fn, size_t index) : m_fn(fn), m_index(index) {}
GradSlotPtr() = default;
GradSlot* operator->() const { return &m_fn->m_slots[m_index]; }
GradSlot* operator->() const;
operator bool() const { return bool(m_fn); }
......@@ -171,7 +149,33 @@ struct ToStringTrait<GradSlotProducerPtr> {
}
};
class GradValue final : public ValueImpl<GradValue> {
class GradFn {
private:
std::weak_ptr<GradKey> m_key;
SmallVector<GradSlot> m_slots;
SmallVector<GradSlotProducerPtr> m_dests;
std::variant<std::monostate, BackwardGraphWithClosure, CustomBackward> m_backward;
public:
void clear() {
m_key.reset();
m_slots.clear();
m_dests.clear();
m_backward.emplace<std::monostate>();
}
std::string to_string() const;
friend class GradSlotPtr;
friend class GradKey;
friend class GradTransformation;
};
inline GradSlot* GradSlotPtr::operator->() const {
return &m_fn->m_slots[m_index];
}
class GradValue final : public ValueImpl<GradValue, ValueKind::Object> {
private:
ValueRef m_value;
std::shared_ptr<GradKey> m_key;
......@@ -179,7 +183,7 @@ private:
public:
GradValue(ValueRef value, std::shared_ptr<GradKey> key, GradSlotPtr slot = {})
: m_value(value), m_key(key), m_slot(slot) {}
: m_value(std::move(value)), m_key(std::move(key)), m_slot(slot) {}
std::string to_string() const override;
......@@ -209,12 +213,13 @@ public:
class GradKey : public std::enable_shared_from_this<GradKey> {
private:
std::string m_name;
std::vector<std::pair<std::weak_ptr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<std::shared_ptr<GradFn>, std::shared_ptr<OpDef>>>
m_frozen_tape;
std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape;
std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape;
bool m_frozen = false;
public:
GradKey() { m_tape.reserve(4 * 1024); }
void backward();
GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback);
const std::string& name() const { return m_name; }
......@@ -225,7 +230,8 @@ public:
};
class GradKeyValue final
: public MixinValueImpl<GradKeyValue, std::shared_ptr<GradKey>> {
: public MixinValueImpl<
GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> {
public:
using MixinValueImpl::MixinValueImpl;
......@@ -248,7 +254,7 @@ public:
return tensor;
}
bool is_grad_value(ValueRef value) {
bool is_grad_value(const ValueRef& value) {
if (auto* grad_value = value.as<GradValue>()) {
if (grad_value->has_key(m_key)) {
return true;
......@@ -266,13 +272,14 @@ public:
* \param value
* \return GradValue::ref_t
*/
GradValue::ref_t as_grad_value(ValueRef value) {
if (auto grad_value = value.as_ref<GradValue>()) {
const GradValue::ref_t& as_grad_value(const ValueRef& value) {
auto&& grad_value = value.as_ref<GradValue>();
if (grad_value) {
if (grad_value->has_key(m_key)) {
return grad_value;
}
}
return {};
return GradValue::ref_t::nil;
}
bool has_key(std::shared_ptr<GradKey> key) {
......
......@@ -39,7 +39,8 @@ public:
std::string name() const { return m_name; }
};
class LazyEvalValue final : public MixinValueImpl<LazyEvalValue, LazyEvalInfo> {
class LazyEvalValue final
: public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> {
public:
using MixinValueImpl::MixinValueImpl;
......
......@@ -17,7 +17,7 @@
namespace mgb::imperative {
class ScalarValue final : public ValueImpl<ScalarValue> {
class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> {
private:
ValueRef m_value;
......
......@@ -22,7 +22,7 @@
namespace mgb::imperative {
class SymbolValue final : public ValueImpl<SymbolValue> {
class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> {
private:
VarNode* m_node = nullptr;
......
......@@ -111,7 +111,8 @@ public:
size_t id() const { return m_id; }
};
class TracingValue final : public MixinValueImpl<TracingValue, TracingInfo> {
class TracingValue final
: public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> {
public:
using MixinValueImpl::MixinValueImpl;
......@@ -256,7 +257,8 @@ public:
}
};
class TracedValue final : public MixinValueImpl<TracedValue, TracedInfo> {
class TracedValue final
: public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> {
public:
using MixinValueImpl::MixinValueImpl;
......
#pragma once
#include <chrono>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
namespace mgb {
namespace imperative {
namespace stats {
#define MGE_ENABLE_STATS 0
class Timer {
public:
using clock_t = std::chrono::system_clock;
private:
clock_t::duration m_duration = clock_t::duration{0};
size_t m_timing = 0;
const char* m_name = nullptr;
uint64_t m_count = 0;
size_t m_enabled = 1;
bool m_default_enabled = true;
struct TimeScopeRecursive {
Timer& timer;
clock_t::time_point start;
bool released = false;
TimeScopeRecursive(Timer& timer) : timer(timer) {
if (timer.m_enabled && !timer.m_timing++) {
start = clock_t::now();
}
}
~TimeScopeRecursive() { release(); }
void release() {
if (released) {
return;
}
if (timer.m_enabled) {
if (!--timer.m_timing) {
timer.m_duration += (clock_t::now() - start);
}
timer.m_count++;
}
released = true;
}
};
struct EnableScope {
Timer& timer;
bool released = false;
EnableScope(Timer& timer) : timer(timer) { timer.m_enabled++; }
~EnableScope() { release(); }
void release() {
if (released) {
return;
}
timer.m_enabled--;
released = true;
}
};
using TimeScope = TimeScopeRecursive;
public:
Timer(const char* name, bool default_enabled);
const char* name() { return m_name; }
auto time_scope() { return TimeScope(*this); }
auto time_scope_recursive() { return TimeScopeRecursive(*this); };
auto enable_scope() { return EnableScope(*this); }
void reset() {
m_duration = clock_t::duration{0};
m_count = 0;
m_enabled = m_default_enabled ? 1 : 0;
}
clock_t::duration get() const { return m_duration; }
uint64_t count() const { return m_count; }
};
} // namespace stats
struct Stats {
static inline std::vector<stats::Timer*> sm_timers;
// register your timers here
// for example:
//
// static inline stats::Timer mytimer{"mytimer", true};
//
// then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code
static void print() {
std::vector<const char*> unused_timers;
for (auto* timer : sm_timers) {
if (timer->count() == 0) {
unused_timers.push_back(timer->name());
} else {
printf("%s costs %ld ns, happens %ld times\n", timer->name(),
timer->get().count(), timer->count());
}
}
if (!unused_timers.empty()) {
printf("%zu timers unused\n", unused_timers.size());
}
}
static void reset() {
for (auto* timer : sm_timers) {
timer->reset();
}
}
};
inline stats::Timer::Timer(const char* name, bool default_enabled)
: m_name(name), m_default_enabled(default_enabled) {
Stats::sm_timers.push_back(this);
}
#if MGE_ENABLE_STATS
#define MGE_TIMER_SCOPE(name) auto name = Stats::name.time_scope()
#define MGE_TIMER_SCOPE_RELEASE(name) name.release()
#define MGE_TIMER_SCOPE_ENABLE(name) auto name = Stats::name.enable_scope()
#else
#define MGE_TIMER_SCOPE(name) (void)0
#define MGE_TIMER_SCOPE_RELEASE(name) (void)0
#define MGE_TIMER_SCOPE_ENABLE(name) (void)0
#endif
} // namespace imperative
} // namespace mgb
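
A minimal usage sketch of the stats utility defined above (illustrative only, not part of the commit): a stats::Timer registers itself into Stats::sm_timers on construction, time_scope() returns an RAII guard that accumulates elapsed wall time and a hit count, and Stats::print()/Stats::reset() report or clear the totals (exposed to Python above as print_stats/reset_stats). Because MGE_ENABLE_STATS defaults to 0, the MGE_TIMER_SCOPE macros compile to no-ops, so this sketch calls time_scope() directly; the timer name example_dispatch is purely illustrative.

// Hedged sketch; "example_dispatch" is an illustrative timer, not one added by this commit.
#include "megbrain/imperative/utils/stats.h"

namespace mgb::imperative {

// Constructing a Timer registers it in Stats::sm_timers; the second argument is
// default_enabled.
static stats::Timer example_dispatch_timer{"example_dispatch", true};

void measured_dispatch_step() {
    // RAII guard: on scope exit it adds the elapsed time to the timer and bumps its
    // hit count (recursive re-entry is timed only once).
    auto scope = example_dispatch_timer.time_scope();
    // ... dispatch work being measured ...
}

}  // namespace mgb::imperative

// Report or clear the collected totals, e.g. via the Python bindings added above
// (print_stats / reset_stats) or directly in C++:
//     mgb::imperative::Stats::print();
//     mgb::imperative::Stats::reset();
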
......@@ -23,6 +23,7 @@
#include "megbrain/imperative/utils/debug.h"
#include "megbrain/imperative/utils/local_ptr.h"
#include "megbrain/imperative/utils/span.h"
#include "megbrain/imperative/utils/stats.h"
namespace mgb {
namespace imperative {
......@@ -58,6 +59,11 @@ public:
inline size_t code() const { return m_code; }
};
enum class ValueKind {
Primitive,
Object,
};
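// NOTE (inferred from the TypedValueRef code below, not stated in the commit):
// Primitive appears to mark plain payload values (shape, dtype, storage, host/device
// data, ...) that TypedValueRef dereferences with a direct static_cast and that are
// never reset() to a successor, while Object marks values that keep the dynamic
// cast/successor machinery (tensor-like, grad, tracing and scalar wrappers).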
/**
* \brief a smart reference to a value
*
......@@ -129,10 +135,10 @@ public:
* \return TypedValueRef<TValue> reference if success, otherwise empty reference
*/
template <typename TValue>
inline TypedValueRef<TValue> as_ref(Type<TValue> type = {}) const;
inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const;
template <typename TValue>
inline TypedValueRef<TValue> cast_ref(Type<TValue> type = {}) const;
inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const;
template <typename TValue>
void on_cast_failure() const;
......@@ -161,14 +167,18 @@ public:
static bool any_watching();
static const ValueRef nil;
friend class ValueWeakRef;
template <typename T>
template <typename>
friend class TypedValueRef;
template <typename T>
template <typename, ValueKind>
friend class ValueImpl;
friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs);
};
inline const ValueRef ValueRef::nil;
template <>
struct ToStringTrait<ValueRef> {
public:
......@@ -241,7 +251,7 @@ public:
friend class ValueRef;
friend class ValueWeakRef;
template <typename T>
template <typename, ValueKind>
friend class ValueImpl;
template <typename T>
friend class TypedValueRef;
......@@ -257,7 +267,7 @@ private:
*
* \tparam T type of value
*/
template <typename T>
template <typename T, ValueKind Kind>
class ValueImpl : public Value {
protected:
ValueImpl() : Value(TYPE_CODE) {}
......@@ -267,6 +277,7 @@ public:
using weak_ref_t = TypedValueWeakRef<T>;
static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }();
static constexpr ValueKind KIND = Kind;
/**
* \brief helper function for constructing a value
......@@ -288,8 +299,8 @@ public:
* \tparam T type of value
* \tparam TMixin type of mixin class
*/
template <typename T, typename TMixin>
class MixinValueImpl : public ValueImpl<T>, public TMixin {
template <typename T, ValueKind Kind, typename TMixin>
class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin {
public:
using TMixin::TMixin;
......@@ -309,12 +320,14 @@ inline ValueRef::ValueRef(storage_t storage) {
template <typename TValue>
inline const TValue* ValueRef::as(Type<TValue> type) const {
static_assert(std::is_base_of_v<ValueImpl<TValue>, TValue>);
// auto _ = Stats::time_value_as.time_scope();
static_assert(std::is_base_of_v<Value, TValue>);
return static_cast<const TValue*>(as(type.code()));
}
template <typename TValue>
inline const TValue& ValueRef::cast(Type<TValue> type) const {
// auto _ = Stats::time_value_cast.time_scope();
auto* ptr = as<TValue>(type);
if (mgb_unlikely(!ptr)) {
on_cast_failure<TValue>();
......@@ -324,26 +337,27 @@ inline const TValue& ValueRef::cast(Type<TValue> type) const {
template <typename TValue>
inline bool ValueRef::is(Type<TValue> type) const {
// auto _ = Stats::time_value_is.time_scope();
return is(type.code());
}
template <typename TValue>
inline TypedValueRef<TValue> ValueRef::as_ref(Type<TValue> type) const {
inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const {
if (!is<TValue>(type)) {
return {};
return TypedValueRef<TValue>::nil;
}
return TypedValueRef<TValue>(*this);
return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
}
template <typename TValue>
inline TypedValueRef<TValue> ValueRef::cast_ref(Type<TValue> type) const {
inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const {
if (!m_storage) {
return {};
return TypedValueRef<TValue>::nil;
}
if (mgb_unlikely(!is<TValue>(type))) {
on_cast_failure<TValue>();
}
return TypedValueRef<TValue>(*this);
return *reinterpret_cast<const TypedValueRef<TValue>*>(this);
}
template <typename TValue>
......@@ -363,12 +377,31 @@ void ValueRef::on_cast_failure() const {
template <typename T>
class TypedValueRef : public ValueRef {
private:
TypedValueRef(ValueRef value) : ValueRef(value) {}
TypedValueRef(ValueRef value) : ValueRef(std::move(value)) {}
public:
TypedValueRef() = default;
const T& operator*() const { return this->template cast<T>(); }
const T* operator->() const { return this->template as<T>(); }
const T& operator*() const {
if constexpr (T::KIND == ValueKind::Object) {
return this->template cast<T>();
} else if constexpr (T::KIND == ValueKind::Primitive) {
if (!m_storage) {
on_cast_failure<T>();
}
return static_cast<const T&>(*m_storage);
} else {
static_assert(!std::is_same_v<T, T>);
}
}
const T* operator->() const {
if constexpr (T::KIND == ValueKind::Object) {
return this->template as<T>();
} else if constexpr (T::KIND == ValueKind::Primitive) {
return static_cast<const T*>(m_storage.get());
} else {
static_assert(!std::is_same_v<T, T>);
}
}
/**
* \brief reset underlying value to another value
......@@ -376,6 +409,7 @@ public:
* \param successor new value
*/
inline void reset(ValueRef successor) {
static_assert(T::KIND == ValueKind::Object);
mgb_assert(m_storage);
mgb_assert(!m_storage->m_successor);
if (m_storage->m_watching) {
......@@ -385,9 +419,11 @@ public:
m_storage->m_successor = ValueRef(successor.storage());
}
static inline const TypedValueRef nil;
friend class ValueRef;
template <typename U>
template <typename, ValueKind>
friend class ValueImpl;
};
......@@ -423,7 +459,7 @@ public:
ValueRefList() = default;
ValueRefList(size_t nr_elems);
ValueRefList(ValueRef item);
ValueRefList(std::initializer_list<ValueRef> values);
// ValueRefList(std::initializer_list<ValueRef> values);
template <typename TIterator>
ValueRefList(TIterator begin, TIterator end);
ValueRefList(const ValueRefList& rhs);
......