From 14d8b709e1f0f2dc88175dee97808709d6706308 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 18 Dec 2020 12:50:50 +0800 Subject: [PATCH] perf(mge/imperative): add mini graph to partially replace proxy graph GitOrigin-RevId: 73e2529ba53ccb6c0607f52aee40e69e2c289343 --- imperative/src/impl/interpreter_impl.cpp | 5 +- imperative/src/impl/proxy_graph/common.h | 10 + imperative/src/impl/proxy_graph/mini_graph.h | 617 ++++++++++++++++++ .../src/impl/proxy_graph/proxy_graph.cpp | 27 + .../src/impl/proxy_graph/proxy_graph_base.h | 118 ++++ imperative/src/impl/proxy_graph_detail.cpp | 10 +- .../megbrain/imperative/physical_tensor.h | 8 + .../include/megbrain/graph/static_infer.h | 7 +- src/core/include/megbrain/graph/var_node.h | 6 +- 9 files changed, 799 insertions(+), 9 deletions(-) create mode 100644 imperative/src/impl/proxy_graph/common.h create mode 100644 imperative/src/impl/proxy_graph/mini_graph.h create mode 100644 imperative/src/impl/proxy_graph/proxy_graph.cpp create mode 100644 imperative/src/impl/proxy_graph/proxy_graph_base.h diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp index 650b846b4..223ce678b 100644 --- a/imperative/src/impl/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter_impl.cpp @@ -258,6 +258,9 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = MGB_LOCK_GUARD(m_mutex); dest->value_fetched = ptr->value_fetched(); // update tensor desc for static infer + // if (dest->desc.layout.ndim) { + // mgb_assert(dest->desc.layout.eq_shape(ptr->layout())); + // } dest->desc.layout = ptr->layout(); dest->desc.comp_node = ptr->comp_node(); dest->ptr = std::move(ptr); @@ -363,7 +366,7 @@ void ChannelImpl::regenerate(TensorInfo* info, bool must_drop = false) { } inputs.push_back(i->ptr); } - auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs); + auto outputs = OpDef::apply_on_physical_tensor(*path.op, inputs); for (size_t i = 0; i < outputs.size(); i ++) { auto out_ptr = path.outputs[i].lock(); if (out_ptr) { diff --git a/imperative/src/impl/proxy_graph/common.h b/imperative/src/impl/proxy_graph/common.h new file mode 100644 index 000000000..846279a3c --- /dev/null +++ b/imperative/src/impl/proxy_graph/common.h @@ -0,0 +1,10 @@ +namespace mgb::imperative::proxy_graph { + +// a "namespace" struct to simplify friend declaration, +// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph +struct ProxyGraph { + struct InputPlaceholder; + struct MiniGraph; +}; + +} // namespace mgb::imperative::proxy_graph diff --git a/imperative/src/impl/proxy_graph/mini_graph.h b/imperative/src/impl/proxy_graph/mini_graph.h new file mode 100644 index 000000000..a3866b4f5 --- /dev/null +++ b/imperative/src/impl/proxy_graph/mini_graph.h @@ -0,0 +1,617 @@ +#include "megbrain/graph/operator_node.h" +#include "megbrain/imperative/physical_tensor.h" +#include "megbrain/imperative/op_def.h" + +#include "./common.h" +#include "./proxy_graph_base.h" + +#include +#include "range/v3/all.hpp" + + +namespace mgb::imperative::proxy_graph { + +using cg::OperatorNodeBase; + + +template +std::pair find_index(const C& container, const E& item) { + auto&& it = std::find(container.begin(), container.end(), item); + return {it != container.end(), it - container.begin()}; +} + + +template class TensorAdaptor; + +template +using enable_if_same_upto_cv_t = std::enable_if_t, std::remove_cv_t>>; + +template +class TensorAdaptor> { + T& wrapped; + template + using maybe_add_const_t = std::conditional_t, const U, U>; + +public: + using type = T; + + TensorAdaptor(T& desc) : wrapped(desc) {} + TensorAdaptor(T* desc) : wrapped(*desc) {} + + DType dtype() {return wrapped.layout.dtype;} + CompNode comp_node() {return wrapped.comp_node;} + maybe_add_const_t& shape() {return wrapped.layout;} + bool has_value() {return wrapped.value.shape_valid();} + auto& value() {return wrapped.value;} + + auto* operator->() {return &wrapped;} +}; + +template +class TensorAdaptor> { + Tensor& wrapped; + +public: + using type = Tensor; + + TensorAdaptor(Tensor& tensor) : wrapped(tensor) {} + TensorAdaptor(Tensor* tensor) : wrapped(*tensor) {} + + DType dtype() {return wrapped.dtype();} + CompNode comp_node() {return wrapped.comp_node();} + const TensorShape& shape() {return wrapped.shape();} + + type* operator->() {return &wrapped;} +}; + +// deduction guides +template TensorAdaptor(T&) -> TensorAdaptor; +template TensorAdaptor(T*) -> TensorAdaptor; + + +// single opr graph, for static inference and execution +// contains static inference descs +class ProxyGraph::MiniGraph { +protected: + struct InferDepItem { + bool is_input : 1; + size_t idx : 63; + cg::static_infer::DepType type; + }; + + enum class InferStatus { + UNKOWN, + READY, + FAILED + }; + + // inference desc and pre-allocated storage for a single var + template + struct InferData { + SmallVector deps; + thin_function infer_func; + + // pre-allocated infer states + InferStatus status = InferStatus::UNKOWN; + cg::static_infer::InpVal inp_val; + T dest; + + void initialize(OperatorNodeBase* opr, const cg::static_infer::DepVal& dep_val, + const thin_function& func) { + mgb_assert(!infer_func); + infer_func = func; + inp_val.val.resize(dep_val.size()); + deps.reserve(dep_val.size()); + + for (auto&& dep : dep_val) { + auto [found, i] = find_index(opr->input(), dep.dest); + if (found) { + deps.push_back({true, i, dep.type}); + } else { + auto [found, i] = find_index(opr->output(), dep.dest); + mgb_assert(found); + deps.push_back({false, i, dep.type}); + } + } + } + + void reset() { + status = InferStatus::UNKOWN; + if constexpr (std::is_same_v) { + dest.ndim = 0; + } else { + static_assert(std::is_same_v); + dest.storage({}); + } + } + }; + + struct OutputData { + InferData shape_infer; + InferData value_infer; + }; + + struct InferSessionBase { + virtual const TensorShape& infer_shape(VarNode*) {mgb_assert(0);} + virtual const TensorShape* infer_shape_fallible(VarNode*) {mgb_assert(0);} + virtual const DeviceTensorND& infer_value(VarNode*) {mgb_assert(0);} + virtual const DeviceTensorND* infer_value_fallible(VarNode*) {mgb_assert(0);} + }; + + OperatorNodeBase* m_opr = nullptr; + SmallVector> opr_ref_keeper; + + size_t run_id = 0; + SmallVector output_data; + SmallVector input_remap; + SmallVector output_remap; + + // pre-allocated buffer for converted inputs + SmallVector> input_value_storage; + + InferSessionBase* m_sess = nullptr; + + template + struct InputAdaptor { + T& wrapped; + SmallVector>& value_storage; + + InputAdaptor(MiniGraph& owner, T& inputs) : wrapped(inputs), value_storage(owner.input_value_storage) {} + ~InputAdaptor() { + for (auto& i : value_storage) { + i.reset(); + } + } + + const TensorShape* shape(size_t i) { + TensorAdaptor tensor(wrapped[i]); + auto& shape = tensor.shape(); + return shape.ndim ? &shape : nullptr; + } + + const DeviceTensorND* value(size_t i, bool sync) { + TensorAdaptor tensor(wrapped[i]); + using tensor_t = std::remove_cv_t; + if constexpr (std::is_same_v) { + auto& storage = value_storage[i]; + if (!storage) { + if (sync) { + return &storage.emplace(tensor->get_value().proxy_to_default_cpu()); + } else { + if (auto* hv = tensor->try_get_value()) { + return &storage.emplace(hv->proxy_to_default_cpu()); + } + return nullptr; + } + } + } else { + auto& value = tensor.value(); + return value.shape_valid() ? &value : nullptr; + } + } + }; + +public: + template + MiniGraph(G& graph, const OpDef& opdef, const I& inputs) : input_value_storage(inputs.size()) { + mgb_assert(!m_opr); + auto _ = graph.scoped_attach(this); + cg::VarNodeArray vinputs(inputs.size()); + for (auto&& [i, t] : ranges::views::enumerate(inputs)) { + auto tensor = TensorAdaptor(t); + opr_ref_keeper.emplace_back(new InputPlaceholder(graph, tensor.dtype(), tensor.comp_node())); + vinputs[i] = opr_ref_keeper.back()->output(0); + } + auto ovars = OpDef::apply_on_var_node(opdef, vinputs); + mgb_assert(m_opr); + output_data.resize(m_opr->output().size()); + for (auto* v : ovars) { + mgb_assert(v->owner_opr() == m_opr); + } + m_opr->init_output_static_infer_desc(); + + // fix permuted input + input_remap.reserve(m_opr->input().size()); + for (auto* v : m_opr->input()) { + auto [found, i] = find_index(vinputs, v); + mgb_assert(found); + input_remap.push_back(i); + } + auto fix_dep_idx = [&](SmallVector& deps) { + for (auto& dep : deps) { + if (dep.is_input) { + dep.idx = input_remap[dep.idx]; + } + } + }; + for (auto& data : output_data) { + fix_dep_idx(data.shape_infer.deps); + fix_dep_idx(data.value_infer.deps); + } + + // fix permuted output + output_remap.reserve(ovars.size()); + for (auto* v : ovars) { + auto [found, i] = find_index(m_opr->output(), v); + mgb_assert(found); + output_remap.push_back(i); + } + } + + // methods for containing graph + + OperatorNodeBase* insert_opr(std::unique_ptr opr_uniqp) { + mgb_assert(!m_opr); + m_opr = opr_uniqp.get(); + mgb_assert(opr_ref_keeper.back()->owner_graph() == m_opr->owner_graph()); + mgb_assert(!m_opr->inserted_in_graph()); + opr_ref_keeper.push_back(std::move(opr_uniqp)); + m_opr->set_inserted_in_graph(); + m_opr->init_output_comp_node(); + m_opr->init_output_dtype(); + return m_opr; + } + + void register_shape_infer(VarNode* varnode, const cg::static_infer::ShapeInferDesc& desc) { + auto [found, i] = find_index(m_opr->output(), varnode); + mgb_assert(found); + output_data[i].shape_infer.initialize(m_opr, desc.deps, desc.infer_func); + } + + void register_value_infer(VarNode* varnode, const cg::static_infer::ValueInferDesc& desc) { + auto [found, i] = find_index(m_opr->output(), varnode); + mgb_assert(found); + output_data[i].value_infer.initialize(m_opr, desc.deps, desc.infer_func); + } + + const TensorShape& infer_shape(VarNode* var) { + return m_sess->infer_shape(var); + } + + const DeviceTensorND& infer_value(VarNode* var) { + return m_sess->infer_value(var); + } + + OperatorNodeBase* opr() { + return m_opr; + } + + // inference routine template for type of input + template + class InferSession : protected InferSessionBase { + MiniGraph& owner; + SmallVector& output_data; + InputAdaptor inputs; + + template + const T* infer(InferData& target, bool sync) { + bool ret; + if (target.status != InferStatus::UNKOWN) { + ret = target.status == InferStatus::READY; + } else { + ret = target.infer_func && do_infer(target, sync); + target.status = ret ? InferStatus::READY : InferStatus::FAILED; + } + return ret ? &target.dest : nullptr; + } + + template + bool do_infer(InferData& target, bool sync) { + for (size_t i = 0; i < target.deps.size(); ++i) { + target.inp_val.run_id = owner.run_id; + auto& dep = target.deps[i]; + if (dep.is_input) { + if (dep.type == cg::static_infer::DepType::SHAPE) { + if (auto* val = inputs.shape(dep.idx)) { + target.inp_val.val[i].m_shape = val; + } else return false; + } else { + if (auto* val = inputs.value(dep.idx, sync)) { + target.inp_val.val[i].m_value = val; + } else return false; + } + } else { + if (dep.type == cg::static_infer::DepType::SHAPE) { + if (auto* val = infer(output_data[dep.idx].shape_infer, sync)) { + target.inp_val.val[i].m_shape = val; + } else return false; + } else { + if (auto* val = infer(output_data[dep.idx].value_infer, sync)) { + target.inp_val.val[i].m_value = val; + } else return false; + } + } + } + return target.infer_func(target.dest, target.inp_val); + } + + // methods for owner mini graph + // corresponding methods of containing ComputingGraph will be redirected here + + const TensorShape& infer_shape(VarNode* var) override { + mgb_assert(owner.m_opr); + auto [found, i] = find_index(owner.m_opr->input(), var); + mgb_assert(found); + i = owner.input_remap[i]; + auto* shape = inputs.shape(i); + mgb_assert(shape); + return *shape; + } + + const DeviceTensorND& infer_value(VarNode* var) override { + mgb_assert(owner.m_opr); + auto [found, i] = find_index(owner.m_opr->input(), var); + mgb_assert(found); + i = owner.input_remap[i]; + auto* value = inputs.value(i, false); + mgb_assert(value); + return *value; + } + + public: + InferSession(MiniGraph& mgraph, I& inputs_) + : owner(mgraph), output_data(mgraph.output_data), inputs(mgraph, inputs_) { + mgraph.run_id++; + mgb_assert(!owner.m_sess); + owner.m_sess = this; + } + ~InferSession() { + owner.m_sess = nullptr; + for (auto& i : output_data) { + i.shape_infer.reset(); + i.value_infer.reset(); + } + } + + const TensorShape* infer_shape(size_t i, bool sync) { + i = owner.output_remap[i]; + return infer(output_data[i].shape_infer, sync); + } + + const DeviceTensorND* infer_value(size_t i, bool sync) { + i = owner.output_remap[i]; + return infer(output_data[i].shape_infer, sync); + } + }; + + template + InferSession infer_session(T& inputs) {return InferSession(*this, inputs);} + + size_t output_size() { + return output_remap.size(); + } + + VarNode* output_var(size_t i) { + i = output_remap[i]; + return m_opr->output(i); + } +}; + + +class CompNodeTracker { + static constexpr size_t bucket_size = 100; + static constexpr size_t bucket_count = 10; + + CompNode comp_node; + std::array, bucket_count> events; + + size_t free_slots = bucket_size; + size_t head = 0; // events[head] is not recorded + size_t tail = 0; // events[tail] is not finished + + void rotate() { + while (tail < head && events[tail % bucket_count]->finished()) { + ++tail; + } + auto& ev = events[head % bucket_count]; + if (head == tail + bucket_count) { + // do not wait if head == tail + ev->host_wait(); + ++tail; + } + ev->record(); + ++head; + free_slots = bucket_size; + } + +public: + CompNodeTracker(CompNode cn) : comp_node(cn) { + for (auto& e : events) { + e = cn.create_event(); + } + } + + size_t add_opr() { + if (!free_slots) rotate(); + --free_slots; + return head; + } + + size_t progress() { + return tail; + } +}; + + +class ExecMiniGraph : public ProxyGraph::MiniGraph { + union BusyListItem { + size_t finish_time; + OperatorNodeBase* opr; + }; + + SmallVector comp_node_trackers; + std::deque busy_oprs; + SmallVector idle_oprs; + + OperatorNodeBase* acquire_opr() { + mgb_assert(!m_opr); + if (!idle_oprs.empty()) { + m_opr = idle_oprs.back(); + idle_oprs.pop_back(); + return m_opr; + } + mgb_assert(busy_oprs.size() > comp_node_trackers.size()); + bool can_pop = true; + for (auto [item, tracker] : ranges::views::zip(busy_oprs, comp_node_trackers)) { + if (item.finish_time >= tracker->progress()) { + can_pop = false; + break; + } + } + if (can_pop) { + for (auto _ : comp_node_trackers) { + busy_oprs.pop_front(); + } + m_opr = busy_oprs.front().opr; + busy_oprs.pop_front(); + return m_opr; + } + + } + + template + void release_opr() { + if constexpr (in_use) { + for (auto tracker : comp_node_trackers) { + tracker->add_opr(); + } + } + } +}; + + +class ProxyGraphTypeI : public ProxyGraphBase { + class StaticInferManager : public StaticInferManagerBase { + ProxyGraph::MiniGraph* target = nullptr; + + friend class ProxyGraphTypeI; + + public: + void register_shape_infer(VarNode* var, const cg::static_infer::ShapeInferDesc& desc) override { + target->register_shape_infer(var, desc); + }; + void register_value_infer(VarNode* var, const cg::static_infer::ValueInferDesc& desc) override { + target->register_value_infer(var, desc); + }; + cg::static_infer::InferType get_infer_type(VarNode*) override { + return {cg::static_infer::InferType::MISSING_INP, cg::static_infer::InferType::MISSING_INP}; + } + // some poorly written inference func would call infer_{shape,value} + const TensorShape& infer_shape(VarNode* var) override { + return target->infer_shape(var); + } + const DeviceTensorND& infer_value(VarNode* var) override { + return target->infer_value(var); + } + }; + + ProxyGraph::MiniGraph* target = nullptr; + StaticInferManager m_static_infer_manager; + std::unordered_map m_mini_graph_cache; + size_t opr_count = 0; + + static thread_local std::unique_ptr sm_instance; + + friend class ProxyGraph::MiniGraph; + + size_t nr_oprs_in_graph() const override { + return opr_count; + } + + size_t next_node_id() override { + return opr_count; + } + + std::shared_ptr on_comp_node_finalize() override { + sm_instance.reset(); + return {}; + } + + cg::static_infer::StaticInferManager& static_infer_manager() override { + return m_static_infer_manager; + } + + void attach(ProxyGraph::MiniGraph* target_) { + target = target_; + m_static_infer_manager.target = target_; + } + + struct AttachGuard { + ProxyGraphTypeI* owner = nullptr; + ProxyGraph::MiniGraph* target = nullptr; + + AttachGuard(ProxyGraphTypeI* owner_ = nullptr, ProxyGraph::MiniGraph* target_ = nullptr) + : owner(owner_), target(target_) {} + AttachGuard(AttachGuard&) = delete; + AttachGuard& operator=(AttachGuard&) = delete; + AttachGuard(AttachGuard&& rhs) : owner(rhs.owner), target(rhs.target) {rhs.owner = nullptr;} + AttachGuard& operator=(AttachGuard&& rhs) = delete; + ~AttachGuard() {if (owner) owner->attach(target);} + }; + + [[nodiscard]] + AttachGuard scoped_attach(ProxyGraph::MiniGraph* target_) { + attach(target_); + return attach_guard(); + } + + [[nodiscard]] + AttachGuard attach_guard(ProxyGraph::MiniGraph* target_ = nullptr) { + return {this, target_}; + } + +public: + OperatorNodeBase* insert_opr(std::unique_ptr opr_uniqp) override { + return target->insert_opr(std::move(opr_uniqp)); + } + + static ProxyGraphTypeI& inst() { + if (!sm_instance) { + sm_instance.reset(new ProxyGraphTypeI); + } + return *sm_instance; + } + + std::tuple, bool> infer_output_attrs_fallible(const OpDef& def, + const SmallVector& inputs) { + size_t buf_size = 2 * inputs.size() + 1; + size_t buf[buf_size]; + size_t pos = 0; + buf[pos++] = def.hash(); + for (auto&& desc : inputs) { + buf[pos++] = mgb::hash(desc.layout.dtype.handle()); + buf[pos++] = mgb::hash(desc.comp_node); + } + mgb_assert(pos == buf_size); + auto key = XXHash{}.update(buf, buf_size*sizeof(size_t)).digest(); + auto it = m_mini_graph_cache.find(key); + if (it == m_mini_graph_cache.end()) { + auto&& result = m_mini_graph_cache.emplace( + std::piecewise_construct, + std::make_tuple(key), + std::forward_as_tuple(*this, def, inputs)); + mgb_assert(result.second); + it = result.first; + } + auto& minigraph = it->second; + auto _ = scoped_attach(&minigraph); + auto sess = minigraph.infer_session(inputs); + std::tuple, bool> ret; + auto& [descs, noerr] = ret; + descs.reserve(minigraph.output_size()); + for (size_t i = 0; i < minigraph.output_size(); ++i) { + descs.emplace_back(); + auto& desc = descs.back(); + desc.layout.dtype = minigraph.output_var(i)->dtype(); + desc.comp_node = minigraph.output_var(i)->comp_node(); + if (auto* shape = sess.infer_shape(i, false)) { + desc.layout.init_contiguous_stride(*shape); + } else { + noerr = false; + } + } + return ret; + } +}; + +} // namespace mgb::imperative::proxy_graph diff --git a/imperative/src/impl/proxy_graph/proxy_graph.cpp b/imperative/src/impl/proxy_graph/proxy_graph.cpp new file mode 100644 index 000000000..5185e0a65 --- /dev/null +++ b/imperative/src/impl/proxy_graph/proxy_graph.cpp @@ -0,0 +1,27 @@ +#include "./mini_graph.h" +// #include "../proxy_graph.h" + +namespace mgb::imperative::proxy_graph { + MGB_DYN_TYPE_OBJ_FINAL_IMPL(ProxyGraph::InputPlaceholder); + + thread_local std::unique_ptr ProxyGraphTypeI::sm_instance = {}; +} // namespace mgb::imperative::proxy_graph + +namespace mgb::imperative::proxy_graph_detail { + +std::tuple, bool> infer_output_attrs_fallible(const OpDef& def, + const SmallVector& inputs) { + auto ret = proxy_graph::ProxyGraphTypeI::inst().infer_output_attrs_fallible(def, inputs); + // auto ref = ProxyGraph::get_default_graph()->infer_output_attrs_fallible(def, inputs); + // auto& [a, _1] = ret; + // auto& [b, _2] = ref; + // if (a.size() != b.size()) mgb_trap(); + // for (size_t i = 0; i < a.size(); ++i) { + // if (a[i].layout.dtype != b[i].layout.dtype) mgb_trap(); + // if (a[i].comp_node != b[i].comp_node) mgb_trap(); + // if (!a[i].layout.eq_shape(b[i].layout)) mgb_trap(); + // } + return ret; +} + +} // namespace mgb::imperative::proxy_graph_detail diff --git a/imperative/src/impl/proxy_graph/proxy_graph_base.h b/imperative/src/impl/proxy_graph/proxy_graph_base.h new file mode 100644 index 000000000..54b83eead --- /dev/null +++ b/imperative/src/impl/proxy_graph/proxy_graph_base.h @@ -0,0 +1,118 @@ +#include "megbrain/graph/cg.h" + +namespace mgb::imperative::proxy_graph { + +using cg::VarNode; + +struct ExecEnvBase : cg::GraphExecutable::ExecEnv { + void dispatch_on_comp_node(CompNode, Task&& task) override { + task(); + } + + void dispatch_on_comp_node_with_mask(CompNode, Task&&, cg::ExecutionMask*) override {mgb_assert(0);} + void pause_exec() override {mgb_assert(0);} + void resume_exec() override {mgb_assert(0);} +}; + +struct StaticInferManagerBase : cg::static_infer::StaticInferManager { +protected: + void register_shape_infer(VarNode*, const cg::static_infer::ShapeInferDesc&) override {mgb_assert(0);}; + void register_value_infer(VarNode*, const cg::static_infer::ValueInferDesc&) override {mgb_assert(0);}; + cg::static_infer::InferType get_infer_type(VarNode*) override {mgb_assert(0);}; + const TensorShape& infer_shape(VarNode*) override {mgb_assert(0);} + const TensorShape* infer_shape_fallible(VarNode*) override {mgb_assert(0);} + const DeviceTensorND& infer_value(VarNode*) override {mgb_assert(0);} + const DeviceTensorND* infer_value_fallible(VarNode*) override {mgb_assert(0);} + cg::static_infer::DepVal get_rt_static_source_deps(const cg::static_infer::DepElement&) override {mgb_assert(0);} +}; + +struct SeqCompNodeOptimizerBase : cg::SeqCompNodeOptimizer { +protected: + void register_stream_var(VarNode*, StreamPropType) override {} + void register_propagate_function(VarNode*, PropFunction) override {} + StreamPropType stream_prop_type(VarNode*) override {mgb_assert(0);} +}; + +struct ProxyGraphBase : cg::ComputingGraph { +private: + VarReceiverInfo m_var_receiver_info; + SeqCompNodeOptimizerBase m_seq_comp_node_optimizer; + StaticInferManagerBase m_static_infer_manager; + +protected: + MemPool m_var_node_pool; + + ProxyGraphBase() { + options().imperative_proxy_graph = true; + options().no_force_inplace = true; + options().log_level = 0; + m_var_receiver_info.dev_value = 1; + m_var_receiver_info.allow_empty_value = 1; + } + + void* alloc_varnode_storage() override { + return m_var_node_pool.alloc_raw(); + } + + void free_varnode_storage(void* ptr) override { + m_var_node_pool.free_raw(ptr); + } + + const VarReceiverInfo& var_receiver_in_current_comp_seq(const VarNode *var) const override { + return m_var_receiver_info; + } + + cg::static_infer::StaticInferManager& static_infer_manager() override { + return m_static_infer_manager; + } + + cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override { + return m_seq_comp_node_optimizer; + } + + std::shared_ptr on_comp_node_finalize() override { + return {}; + } + + std::unique_ptr compile(const OutputSpec&) override {mgb_assert(0);} + SmallVector> compile_multi_part(const SmallVector&) override {mgb_assert(0);} + cg::AsyncExecutable* current_comp_seq() override {mgb_assert(0);} + std::string get_mem_allocation_info() const override {mgb_assert(0);} + VarNode* find_var_by_id(size_t) const override {mgb_assert(0);} + void share_device_memory_with(ComputingGraph&) override {mgb_assert(0);} + void set_device_memory_allocator(std::shared_ptr) override {mgb_assert(0);} + size_t get_device_memory_size(CompNode) override {mgb_assert(0);} + size_t clear_device_memory() override {mgb_assert(0);} + void set_as_subgraph(ComputingGraph&) override {mgb_assert(0);} + void record_async_error(std::unique_ptr) override {mgb_assert(0);} +}; + +MGB_DEFINE_OPR_CLASS( + ProxyGraph::InputPlaceholder, + cg::OperatorNodeBase) // { + + void on_output_comp_node_stream_changed() override {mgb_assert(0);} + void init_output_comp_node() override {} + void init_output_format() override {} + void init_output_dtype() override {} + void init_output_static_infer_desc() override {} + void init_output_mem_plan(bool) override {mgb_assert(0);} + void do_execute(ExecEnv&) override {mgb_assert(0);} + +public: + InputPlaceholder(cg::ComputingGraph& graph) + : Super(&graph, {}, "placeholder", {}) { + add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC); + // never dedup + add_equivalence_component>(this); + } + + InputPlaceholder(cg::ComputingGraph& graph, DType dtype, CompNode cn) + : InputPlaceholder(graph) { + output(0)->dtype(dtype).comp_node(cn); + } +}; + +using InputPlaceholder = ProxyGraph::InputPlaceholder; + +} // namespace mgb::imperative::proxy_graph diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp index 659e88f91..b97facb16 100644 --- a/imperative/src/impl/proxy_graph_detail.cpp +++ b/imperative/src/impl/proxy_graph_detail.cpp @@ -80,11 +80,11 @@ apply_on_physical_tensor(const OpDef& def, return outputs; } -std::tuple, bool> infer_output_attrs_fallible(const OpDef& def, - const SmallVector& inputs) { - auto&& graph = ProxyGraph::get_default_graph(); - return graph->infer_output_attrs_fallible(def, inputs); -} +// std::tuple, bool> infer_output_attrs_fallible(const OpDef& def, +// const SmallVector& inputs) { +// auto&& graph = ProxyGraph::get_default_graph(); +// return graph->infer_output_attrs_fallible(def, inputs); +// } namespace { diff --git a/imperative/src/include/megbrain/imperative/physical_tensor.h b/imperative/src/include/megbrain/imperative/physical_tensor.h index 8c98c5c95..1b8e18297 100644 --- a/imperative/src/include/megbrain/imperative/physical_tensor.h +++ b/imperative/src/include/megbrain/imperative/physical_tensor.h @@ -89,10 +89,18 @@ public: return m_blob->comp_node(); } + DType dtype() const { + return m_layout.dtype; + } + TensorLayout layout() const { return m_layout; } + const TensorShape& shape() const { + return m_layout; + } + DeviceTensorND dev_tensor(); static TensorPtr make_scalar(DTypeScalar value, CompNode cn); diff --git a/src/core/include/megbrain/graph/static_infer.h b/src/core/include/megbrain/graph/static_infer.h index 93859bde6..3a53ff69e 100644 --- a/src/core/include/megbrain/graph/static_infer.h +++ b/src/core/include/megbrain/graph/static_infer.h @@ -16,7 +16,10 @@ namespace mgb { namespace imperative { - class ProxyGraph; +class ProxyGraph; +namespace proxy_graph { +class ProxyGraph; +} // namespace proxy_graph } // namespace imperative namespace cg { @@ -56,6 +59,7 @@ namespace static_infer { friend class StaticInferManagerImpl; friend class imperative::ProxyGraph; + friend class imperative::proxy_graph::ProxyGraph; public: /*! @@ -342,4 +346,3 @@ using StaticInferInpVal = static_infer::InpVal; } // mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} - diff --git a/src/core/include/megbrain/graph/var_node.h b/src/core/include/megbrain/graph/var_node.h index cc1c9b078..4f329d710 100644 --- a/src/core/include/megbrain/graph/var_node.h +++ b/src/core/include/megbrain/graph/var_node.h @@ -23,7 +23,10 @@ namespace mgb { namespace imperative { - class ProxyGraph; +class ProxyGraph; +namespace proxy_graph { +class ProxyGraph; +} } // namespace imperative namespace cg { @@ -587,6 +590,7 @@ class VarNode final: public GraphNodeBase { friend class EagerEvalManager; friend class MemAllocPlan; friend class imperative::ProxyGraph; + friend class imperative::proxy_graph::ProxyGraph; }; enum class VarNode::Flag : uint32_t { -- GitLab