/**
 * \file imperative/src/impl/proxy_graph.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./blob_manager_impl.h"
#include "./proxy_graph.h"
#include "megbrain/graph/static_infer.h"
#include "megbrain/graph/operator_node.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/backward_graph.h"

#if __cplusplus >= 201703L
#include <optional>
#endif

namespace mgb {
namespace imperative {

using cg::OperatorNodeBase;

/*!
 * Compile-time selection between two forwarded arguments: yields `t` when
 * `p` is true and `f` otherwise. Used by StaticInferManager::do_infer to pick
 * the shape- or value-descriptor list with a single template parameter.
 */
template <bool p, typename T, typename F>
constexpr auto&& select(T&& t, F&& f) {
    if constexpr (p) {
        return std::forward<T>(t);
    } else {
        return std::forward<F>(f);
    }
}

/*!
 * Source operator standing in for one input of the proxied operator.
 *
 * It never computes anything (the execution-related hooks assert). It wraps
 * either a physical Tensor (m_tensor) or, for logical inference, only a
 * LogicalTensorDesc, optionally carrying a CPU-side value for static
 * inference.
 */
MGB_DEFINE_OPR_CLASS(
        ProxyGraph::InputPlaceholder,
        cg::OperatorNodeBase) // {

    void on_output_comp_node_stream_changed() override {
        mgb_assert(0);
    }
    // TODO: consider implementing the following initialization methods,
    // so InputPlaceholder can be initialized correctly during
    // operator insertion
    void init_output_comp_node() override {
    }
    void init_output_format() override {
    }
    void init_output_dtype() override {
    }
    void init_output_static_infer_desc() override {
    }
    void init_output_mem_plan(bool dynamic) override {
        MGB_MARK_USED_VAR(dynamic);
        mgb_assert(0);
    }
    void do_execute(ExecEnv &env) override {
        mgb_assert(0);
    }

public:
    Tensor* m_tensor;

    //! static_infer_value, if given, must live on CompNode::default_cpu()
    InputPlaceholder(ComputingGraph& graph, Tensor* tensor = nullptr,
                     const DeviceTensorND& static_infer_value = {})
            : Super(&graph, {}, "device_value", {}), m_tensor(tensor),
              m_static_infer_value(static_infer_value) {
        mgb_assert(m_static_infer_value.empty() ||
                   m_static_infer_value.comp_node() == CompNode::default_cpu());
        add_output(None)->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC);
        // never dedup
        add_equivalence_component<ScalarHash<void*>>(this);
    }

    //! Insert a placeholder bound to a physical tensor; the output var
    //! directly references the tensor's device storage.
    static SymbolVar make(ComputingGraph& graph, Tensor& tensor) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, &tensor));
        auto var = opr->output(0);
        auto&& dev_tensor = tensor.dev_tensor();
        var->m_comp_node = dev_tensor.comp_node();
        var->m_shape = dev_tensor.shape();
        if (dev_tensor.empty()) {
            // give empty tensors a contiguous-strided layout before binding
            auto layout = dev_tensor.layout();
            layout.init_contiguous_stride();
            dev_tensor.reset(dev_tensor.storage(), layout);
        }
        var->m_dev_tensor = dev_tensor;
        var->m_mem_plan.reset_from_owner_var().chunk()
                .mem_alloc_status.set_from_owner_var();
        return var;
    }

    //! Insert a placeholder described only by layout/comp node (no storage).
    static SymbolVar make(ComputingGraph& graph, const LogicalTensorDesc& desc) {
        auto opr = graph.insert_opr(
                std::make_unique<InputPlaceholder>(graph, nullptr, desc.value));
        auto var = opr->output(0);
        var->m_comp_node = desc.comp_node;
        var->m_shape = desc.layout;
        var->m_dev_tensor.reset({}, TensorLayout(desc.layout.dtype));
        return var;
    }

    /*!
     * Return a CPU value usable for static inference, or nullptr.
     *
     * Uses the value given at construction if any; otherwise, if a tensor is
     * bound and either may_sync is set or try_get_value() succeeds, fetches
     * the tensor's host value and caches it.
     */
    const DeviceTensorND* get_static_infer_value(bool may_sync) {
        if (!m_static_infer_value.empty()) {
            return &m_static_infer_value;
        }
        if (m_tensor && (may_sync || m_tensor->try_get_value())) {
            auto&& hv = m_tensor->get_value();
            mgb_assert(!hv.empty());
            m_static_infer_value = hv.proxy_to_default_cpu();
            // steal ownership from shared_ptr: re-point the storage through an
            // aliasing shared_ptr owned by a static dummy, so the cached value
            // holds no reference on the host buffer itself. NOTE(review): the
            // buffer is presumably kept alive by m_tensor's host value.
            using SP = std::shared_ptr<dt_byte>;
            auto& sp = const_cast<SP&>(m_static_infer_value.storage().raw_storage());
            static auto dummy = std::make_shared<dt_byte>();
            sp = SP(dummy, sp.get());
            return &m_static_infer_value;
        }
        return nullptr;
    }

private:
    DeviceTensorND m_static_infer_value;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(
        ProxyGraph::InputPlaceholder);

/*!
 * Execution environment that runs every dispatched task immediately on the
 * calling thread; execution masks are rejected.
 */
class ProxyGraph::ExecEnv final : public cg::GraphExecutable::ExecEnv {

public:
    void dispatch_on_comp_node(CompNode, Task&& task) override {
        task();
    }

    void dispatch_on_comp_node_with_mask(CompNode, Task&& task,
                                         cg::ExecutionMask* mask) override {
        mgb_throw_if(mask, GraphError,
                     "ExecutionMask not supported in imperative mode");
        task();
    }

    void pause_exec() override {}

    void resume_exec() override {}
};

/*!
 * Minimal static-inference manager that only understands the outputs of the
 * owner's current operator (ProxyGraph::m_cur_opr) and InputPlaceholder vars.
 * Inferred shapes/values are memoized per output in inferred_outputs until
 * the current operator changes.
 */
class ProxyGraph::StaticInferManager : public cg::static_infer::StaticInferManager {
public:
    using Tag = cg::static_infer::Tag;
    using ShapeInferDesc = cg::static_infer::ShapeInferDesc;
    using ValueInferDesc = cg::static_infer::ValueInferDesc;
    using InferType = cg::static_infer::InferType;
    using DepVal = cg::static_infer::DepVal;
    using DepElement = cg::static_infer::DepElement;
    using DepType = cg::static_infer::DepType;
    using InpElement = cg::static_infer::InpElement;

    struct Result {
        TensorShape shape;
        DeviceTensorND value;
    };

    ProxyGraph* owner;
    cg::OperatorNodeBase* cur_opr = nullptr;
    std::vector<std::optional<ShapeInferDesc>> shape_descs;
    std::vector<std::optional<ValueInferDesc>> value_descs;
    std::vector<Result> inferred_outputs;

    StaticInferManager(ProxyGraph* owner_) : owner(owner_) {}

    //! index of var among cur_opr's outputs (asserts that it is one)
    size_t locate_output(VarNode* var) {
        mgb_assert(cur_opr);
        auto&& output_vars = cur_opr->output();
        mgb_assert(shape_descs.size() == output_vars.size());
        auto&& it = std::find(output_vars.begin(), output_vars.end(), var);
        mgb_assert(it != output_vars.end());
        return it - output_vars.begin();
    }

    void register_shape_infer(Tag dest, const ShapeInferDesc &desc) override {
        auto i = locate_output(dest);
        mgb_assert(!shape_descs[i]);
        shape_descs[i].emplace(desc);
    }

    void register_value_infer(Tag dest, const ValueInferDesc &desc) override {
        auto i = locate_output(dest);
        mgb_assert(!value_descs[i]);
        value_descs[i].emplace(desc);
    }

    InferType get_infer_type(Tag var) override {
        // may be called during get_proxy_opr or make_backward_graph
        // don't let opr apply any immediate optimization
        return {InferType::MISSING_INP, InferType::MISSING_INP};

        // NOTE(review): everything below is intentionally unreachable because
        // of the early return above; it is kept as the reference logic should
        // the early return ever be lifted.
        if (auto opr = var->owner_opr()->try_cast_final<InputPlaceholder>()) {
            return {var->shape().ndim ? InferType::CONST : InferType::MISSING_INP,
                    opr->m_tensor ? InferType::CONST : InferType::MISSING_INP};
        }

        if (cur_opr) {
            auto&& outputs = cur_opr->output();
            auto&& it = std::find(outputs.begin(), outputs.end(), var);
            if (it != outputs.end()) {
                return {infer_shape_fallible(var) ? InferType::CONST : InferType::MISSING_INP,
                        // value inference could be expensive
                        InferType::MISSING_INP};
            }
        }

        return {InferType::MISSING_INP, InferType::MISSING_INP};
    }

    //! Re-sync with owner->m_cur_opr: on change, drop all cached state and let
    //! the new operator register its inference descriptors.
    void update() {
        if (cur_opr != owner->m_cur_opr) {
            clear();
            cur_opr = owner->m_cur_opr;
            if (cur_opr) {
                auto nout = cur_opr->output().size();
                shape_descs.resize(nout);
                value_descs.resize(nout);
                inferred_outputs.resize(nout);
                cur_opr->init_output_static_infer_desc();
            }
        }
    }

    void clear() {
        cur_opr = nullptr;
        shape_descs.clear();
        value_descs.clear();
        inferred_outputs.clear();
    }

    /*!
     * Infer the shape (is_shape) or value (!is_shape) of dest, recursing into
     * dependencies. Results for cur_opr outputs are memoized. Returns nullptr
     * when some dependency is unavailable (may_sync controls whether input
     * values may be fetched synchronously).
     */
    template <bool is_shape>
    auto do_infer(Tag dest, bool may_sync)
            -> const std::conditional_t<is_shape, TensorShape, DeviceTensorND>* {
        // Some infer_func does not use InpVal passed to them, but
        // call infer_* on their inputs instead, so dest could be an input.
        // It is also possible that an opr call infer_* on its inputs before it
        // is inserted
        if (auto opr = dest->owner_opr()->try_cast_final<InputPlaceholder>()) {
            if constexpr (is_shape) {
                auto* shp = &dest->shape();
                return shp->ndim ? shp : nullptr;
            } else {
                return opr->get_static_infer_value(may_sync);
            }
        }

        mgb_assert(cur_opr);
        mgb_assert(cur_opr->output().size() == shape_descs.size());

        // dest must be an output now
        auto i = locate_output(dest);
        auto& result = inferred_outputs[i];
        auto& desc = select<is_shape>(shape_descs[i], value_descs[i]);

        // return if no need to call infer_func
        if constexpr (is_shape) {
            if (result.shape.ndim != 0) {
                return &result.shape;
            }
        } else {
            if (!result.value.empty()) {
                return &result.value;
            }
        }
        if (!desc) {
            return nullptr;
        }

        // fill args for infer_func
        cg::static_infer::InpVal args{1};
        args.val.reserve(desc->deps.size());
        auto push_shape = [&args](const TensorShape* shape) {
            args.val.emplace_back();
            args.val.back().m_shape = shape;
        };
        auto push_value = [&args](const DeviceTensorND* value) {
            args.val.emplace_back();
            args.val.back().m_value = value;
        };

        for (auto&& dep : desc->deps) {
            if (auto opr = dep.dest->owner_opr()->template try_cast_final<InputPlaceholder>()) {
                if (dep.type == DepType::SHAPE) {
                    if (dep.dest->shape().ndim) {
                        push_shape(&dep.dest->shape());
                    } else {
                        return nullptr;
                    }
                } else {
                    if (auto* p = opr->get_static_infer_value(may_sync)) {
                        push_value(p);
                    } else {
                        return nullptr;
                    }
                }
                continue;
            }

            // dep must be an output
            if (dep.type == DepType::SHAPE) {
                if (auto* p = do_infer<true>(dep.dest, may_sync)) {
                    push_shape(p);
                } else {
                    return nullptr;
                }
            } else {
                if (auto* p = do_infer<false>(dep.dest, may_sync)) {
                    push_value(p);
                } else {
                    return nullptr;
                }
            }
        }

        // call infer_func
        if constexpr (is_shape) {
            if (!desc->infer_func(result.shape, args)) {
                mgb_log_warn("something is missing for shape inference of %s",
                             cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.shape;
        } else {
            if (!desc->infer_func(result.value, args)) {
                mgb_log_warn("something is missing for value inference of %s",
                             cur_opr->dyn_typeinfo()->name);
                return nullptr;
            }
            return &result.value;
        }
    }

    const TensorShape& infer_shape(Tag var) override {
        auto* p = do_infer<true>(var, true);
        mgb_assert(p, "failed to infer shape for %s", var->name().c_str());
        return *p;
    }

    const TensorShape* infer_shape_fallible(Tag var) override {
        return do_infer<true>(var, false);
    }

    const DeviceTensorND& infer_value(Tag var) override {
        auto* p = do_infer<false>(var, true);
        mgb_assert(p, "failed to infer value for %s", var->name().c_str());
        return *p;
    }

    const DeviceTensorND* infer_value_fallible(Tag var) override {
        return do_infer<false>(var, false);
    }

    DepVal get_rt_static_source_deps(const DepElement&) override {mgb_assert(0);}
};

//! No-op stream optimizer: registrations are ignored, queries assert.
class ProxyGraph::SeqCompNodeOptimizer : public cg::SeqCompNodeOptimizer {
    void register_stream_var(VarNode*, StreamPropType) override {}
    void register_propagate_function(VarNode*, PropFunction) override {}
    StreamPropType stream_prop_type(VarNode*) override {mgb_assert(0);}
};

/*!
 * ComputingGraph used only to build operators for imperative execution.
 * Inserted operators are kept alive in m_opr_refkeeper; all compile/execution
 * related graph APIs assert.
 */
class ProxyGraph::ProxyGraphImpl : public cg::ComputingGraph {
    static std::atomic<size_t> m_node_id;
    ProxyGraph* m_owner;
    MemPool<VarNode> m_var_node_pool;
    std::vector<std::unique_ptr<OperatorNodeBase>> m_opr_refkeeper;
    std::mutex m_opr_refkeeper_mtx;
    CompNode::UnorderedSet m_used_comp_node;
    VarReceiverInfo m_var_receiver_info;
public:
    //! Unless already finalized, sync every used comp node except CUDA and
    //! ROCM ones before destruction.
    ~ProxyGraphImpl() {
        mgb_assert(!m_owner->m_cur_opr);
        if (is_finalized()) return;
        for (auto&& i : m_used_comp_node) {
            if (i.device_type() == CompNode::DeviceType::CUDA) continue;
            if (i.device_type() == CompNode::DeviceType::ROCM) continue;
            i.sync();
        }
    }

    ProxyGraphImpl(ProxyGraph* owner) : m_owner(owner) {
        options().imperative_proxy_graph = true;
        options().no_force_inplace = true;
        options().log_level = 0;
        m_var_receiver_info.dev_value = 1;
        m_var_receiver_info.allow_empty_value = 1;
    }

    static std::unique_ptr<ProxyGraphImpl> make(ProxyGraph* owner) {
        return std::make_unique<ProxyGraphImpl>(owner);
    }

    void add_used_comp_node(CompNode cn) {
        m_used_comp_node.insert(cn);
    }

    //! true when the graph is finalized or holds more than m_max_op_cnt oprs
    bool invalid() const {
        return is_finalized() || nr_oprs_in_graph() > m_owner->m_max_op_cnt;
    }

    size_t next_node_id() override {
        return m_node_id.fetch_add(1);
    }

    void* alloc_varnode_storage() override {
        return m_var_node_pool.alloc_raw();
    }

    void free_varnode_storage(void* ptr) override {
        m_var_node_pool.free_raw(ptr);
    }

    OperatorNodeBase* insert_opr(std::unique_ptr<OperatorNodeBase> opr_uniqp) override {
        mgb_assert(!is_finalized());

        auto opr = opr_uniqp.get();

        if (!opr->inserted_in_graph()) {
            m_opr_refkeeper.emplace_back(std::move(opr_uniqp));
            opr->set_inserted_in_graph();
            opr->init_output_comp_node();
            opr->init_output_dtype();
            opr->init_output_format();
        }
        return opr;
    }

    cg::static_infer::StaticInferManager& static_infer_manager() override {
        return *m_owner->m_static_infer_manager;
    }

    cg::SeqCompNodeOptimizer& seq_comp_node_optimizer() override {
        return *m_owner->m_seq_comp_node_optimizer;
    }

    std::shared_ptr<void> on_comp_node_finalize() override {
        MGB_LOCK_GUARD(m_opr_refkeeper_mtx);
        mgb_assert(!m_owner->m_cur_opr);
        // finalize would do sync first
        m_opr_refkeeper.clear();
        return {};
    }

    const VarReceiverInfo& var_receiver_in_current_comp_seq(
            const VarNode *var) const override {
        return m_var_receiver_info;
    }

    size_t nr_oprs_in_graph() const override {return m_opr_refkeeper.size();}

    //! keep only the first error recorded on this thread
    void record_async_error(std::unique_ptr<MegBrainError> async_exc) override {
        if (!ProxyGraph::tm_async_error) {
            std::swap(async_exc, tm_async_error);
        }
    }

    std::unique_ptr<cg::AsyncExecutable> compile(const OutputSpec &out_spec) override {mgb_assert(0);}
    SmallVector<std::unique_ptr<cg::AsyncExecutable>> compile_multi_part(
            const SmallVector<OutputSpec>& out_specs) override {mgb_assert(0);}
    cg::AsyncExecutable* current_comp_seq() override {mgb_assert(0);}
    std::string get_mem_allocation_info() const override {mgb_assert(0);}
    VarNode* find_var_by_id(size_t id) const override {mgb_assert(0);}
    void share_device_memory_with(ComputingGraph &other) override {mgb_assert(0);}
    void set_device_memory_allocator(
            std::shared_ptr<cg::DeviceMemoryAllocator> allocator) override {mgb_assert(0);}
    size_t get_device_memory_size(CompNode cn) override {mgb_assert(0);}
    size_t clear_device_memory() override {mgb_assert(0);}
    void set_as_subgraph(ComputingGraph &par_graph) override {mgb_assert(0);}
};

std::atomic<size_t> ProxyGraph::ProxyGraphImpl::m_node_id = 0;

ProxyGraph::ProxyGraph() :
        m_graph(ProxyGraphImpl::make(this)),
        m_env{new ExecEnv},
        m_static_infer_manager(new StaticInferManager(this)),
        m_seq_comp_node_optimizer(new SeqCompNodeOptimizer()) {
}

//! replace the underlying graph with a fresh one (no operator may be active)
void ProxyGraph::reset() {
    mgb_assert(!m_cur_opr);
    m_graph = ProxyGraphImpl::make(this);
}

//! Per-thread ProxyGraph; its graph is rebuilt whenever it becomes invalid().
ProxyGraph* ProxyGraph::get_default_graph() {
    static thread_local ProxyGraph inst;
    if (inst.m_graph->invalid()) {
        inst.reset();
    }
    return &inst;
}

//! Scope guard: sets owner->m_cur_opr on entry and calls owner->cleanup() on
//! exit.
class ProxyGraph::CurOprGuard {
public:
    CurOprGuard(ProxyGraph* owner, OperatorNodeBase* opr) : m_owner(owner) {
        mgb_assert(!owner->m_cur_opr);
        owner->m_cur_opr = opr;
    }
    CurOprGuard(const CurOprGuard&) = delete;
    ~CurOprGuard() {
        m_owner->cleanup();
    }
private:
    ProxyGraph* m_owner;
};

#define CUR_OPR_GUARD(opr) CurOprGuard MGB_TOKENPASTE2(__cur_opr_guard_, __LINE__)(this, opr)

/*********************** Physical Tensor Impl ***********************/

//! Output descriptors for opdef applied to physical tensors (shape inference
//! may sync on input values).
SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs(
        const OpDef& opdef,
        const SmallVector<Tensor*>& inputs) {
    SmallVector<LogicalTensorDesc> ret;
    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
    do_shape_infer(true);
    for (auto&& i: m_cur_opr->usable_output()) {
        mgb_assert(i->dtype().valid() && i->comp_node().valid());
        mgb_assert(i->shape().ndim || i->contain_flag(VarNode::Flag::NO_SYS_MEM_ALLOC));
        ret.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    return ret;
}

//! Build the proxy operator, bind outputs/workspaces, and execute it.
void ProxyGraph::invoke_op(const OpDef& opdef,
        const SmallVector<Tensor*>& inputs,
        const SmallVector<Tensor*>& outputs,
        const SmallVector<Tensor*>& workspaces) {
    CUR_OPR_GUARD(get_proxy_opr(opdef, inputs));
    init_output_tensor(outputs, workspaces);
    for (auto oup : m_cur_opr->output()) {
        m_graph->add_used_comp_node(oup->comp_node());
    }
    m_cur_opr->execute(*m_env);
}

//! Drop storage references held by the current operator's vars, clear the
//! static-inference cache, and reset m_cur_opr.
void ProxyGraph::cleanup() {
    if (m_cur_opr) {
        for (auto&& i : m_cur_opr->input()) {
            i->m_dev_tensor.storage({});
        }
        for (auto&& i : m_cur_opr->output()) {
            i->m_dev_tensor.storage({});
        }
        m_static_infer_manager->clear();
    }
    m_cur_opr = nullptr;
}

/*!
 * Bind device storage to every output var of the current operator.
 * VOLATILE_CONTENT vars take the next entry of `workspaces` (or a freshly
 * allocated workspace if none were given); all other vars take the next
 * entry of `outputs`. Comp node, shape and dtype must match.
 */
void ProxyGraph::init_output_tensor(const SmallVector<Tensor*>& outputs,
        const SmallVector<Tensor*>& workspaces) {
    // get proxy opr
    auto proxy = m_cur_opr;

    do_shape_infer(true);

    size_t j = 0;
    size_t k = 0;
    for (auto&& var : proxy->output()) {
        auto &&chk = var->m_mem_plan.reset_from_owner_var().chunk();
        if (var->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            // workspace
            if (workspaces.size()) {
                mgb_assert(k < workspaces.size());
                auto && layout = workspaces[k]->layout();
                mgb_assert(var->comp_node() == workspaces[k]->comp_node() &&
                           var->shape().eq_shape(layout) &&
                           var->dtype() == layout.dtype);
                var->m_dev_tensor = workspaces[k]->dev_tensor();
                ++ k;
            } else {
                TensorLayout layout{var->shape(), var->dtype(), var->format()};
                var->m_dev_tensor = BlobManager::inst()->alloc_workspace_with_defrag(var->comp_node(), layout);
            }
        } else {
            mgb_assert(j < outputs.size());
            auto &&tensor = outputs[j];
            auto &&layout = tensor->layout();
            mgb_assert(var->comp_node() == tensor->comp_node() &&
                       var->shape().eq_shape(layout) &&
                       var->dtype() == layout.dtype);
            var->assign_dev_tensor_from_tensor(tensor->dev_tensor());
            ++ j;
        }
        chk.mem_alloc_status.set_from_owner_var();
    }
    mgb_assert(j == outputs.size());
    mgb_assert(k == workspaces.size());

    // Memory forwarding was bypassed in megbrain with graph option
    // imperative_proxy_graph on, here we call mem_plan_fwd_in2out_readonly
    // to initialize some opr(e.g. Subtensor)'s internal state
    // TODO: implement memory forwarding
    proxy->mem_plan_fwd_in2out_readonly();
    {
        // some opr (e.g. Reduce) rely on on_mem_status_changed to set
        // input/output tensor correctly, since we bypass var_node_mem_mgr
        // on_mem_status_changed should be called here
        auto&& cb = proxy->get_opr_event_callback().on_mem_status_changed;
        if (cb.valid()) {
            cb.val()();
        }
    }
}

//! Apply opdef on InputPlaceholders wrapping the given physical tensors and
//! return the resulting (non-placeholder) operator.
cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(
        const OpDef& opdef,
        const SmallVector<Tensor*>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++ i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node();
    }
    auto opr = OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
    mgb_assert(!opr->same_type<InputPlaceholder>());
    for (auto &&i : opr->input()) {
        mgb_assert(i->owner_opr()->same_type<InputPlaceholder>());
    }
    return opr;
}

/*********************** Logical Tensor Impl ***********************/

size_t ProxyGraph::get_opr_output_size(const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    return get_proxy_opr(opdef, inputs)->usable_output().size();
}

/*!
 * Best-effort output descriptors from logical inputs. The bool is true only
 * if every output shape was inferred and the operator is not a Reshape.
 */
std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    auto opr = get_proxy_opr(opdef, inputs);
    CUR_OPR_GUARD(opr);
    SmallVector<LogicalTensorDesc> outputs;
    bool validated = do_shape_infer(false);
    for (auto&& i : opr->usable_output()) {
        outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()});
    }
    bool need_check = opr->same_type<opr::Reshape>();
    return {outputs, validated && !need_check};
}

//! Memory descriptors for outputs and workspaces (VOLATILE_CONTENT vars);
//! each gets a distinct StorageIdentifier starting from 1.
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph::infer_output_mem_desc(
        const OpDef& def,
        const SmallVector<Tensor*>& inputs_tensors,
        const SmallVector<MemoryDesc>& inputs_mems) {
    auto opr = get_proxy_opr(def, inputs_tensors);
    CUR_OPR_GUARD(opr);
    do_shape_infer(true);
    SmallVector<MemoryDesc> outputs;
    SmallVector<MemoryDesc> workspaces;
    size_t cur_id = 0;
    for (auto&& i : opr->output()) {
        if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
            workspaces.push_back({{i->shape(), i->dtype(), i->format()}, 0, i->comp_node(), StorageIdentifier::make(++ cur_id)});
        } else {
            outputs.push_back({{i->shape(), i->dtype()}, 0, i->comp_node(), StorageIdentifier::make(++ cur_id)});
        }
    }
    return {outputs, workspaces};
}

struct ProxyGraph::GradGraph {
    cg::VarNodeArray inputs;
    cg::VarNodeArray outputs;
    cg::VarNodeArray output_grads;
    cg::VarNode* grad;
};

/*!
 * Build an encoded backward graph for opdef.
 *
 * Returns an empty result when no output has a gradient or no input ends up
 * with a (valid, required) gradient. Otherwise output_mask[i] tells whether
 * input i gets a gradient output, and input_mask marks which of
 * (forward inputs, forward outputs, output grads) the backward graph reads.
 */
EncodedSubraph ProxyGraph::make_backward_graph(
        const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& input_descs,
        const SmallVector<bool>& input_requires_grad,
        const SmallVector<bool>& output_has_grad) {
    ThinHashMap<VarNode*, size_t> var2idx;
    auto push = [&var2idx, cnt=1](VarNode* var) mutable {
        // indices start from 1, so an assigned index is never zero
        auto&& ret = var2idx.emplace(var, cnt ++);
        mgb_assert(ret.second, "var %s has been already inserted", var->cname());
        return ret.first->second;
    };
    auto inputs = make_input_place_holders(input_descs);
    auto fwd = OpDef::apply_on_var_node(opdef, inputs)[0]->owner_opr();
    auto&& outputs = fwd->usable_output();
    SmallVector<LogicalTensorDesc> output_descs;
    for (auto&& i : outputs) {
        output_descs.push_back({TensorLayout{i->dtype()}, i->comp_node()});
    }
    auto output_grads = make_input_place_holders(output_descs);
    mgb_assert(output_grads.size() == output_has_grad.size());
    bool any_output_has_grad = false;
    for (size_t i = 0; i < output_grads.size(); ++ i) {
        if (!output_has_grad[i]) {
            output_grads[i] = nullptr;
        } else {
            any_output_has_grad = true;
        }
    }
    if (!any_output_has_grad) {
        return {};
    }
    auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

    EncodedSubraph result;
    auto&& igraph = result.graph;

    size_t nr_backward_graph_inputs = 0;
    // Visit each operator feeding the gradients: constants become graph
    // constants, placeholders and forward outputs become graph inputs, and
    // every other operator becomes an expression.
    auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
                     &nr_backward_graph_inputs](cg::OperatorNodeBase* op) {
        if (auto t = as_tensor(op)) {
            mgb_assert(op->output().size() == 1);
            igraph.constants.emplace_back(push(op->output(0)), std::move(t));
        } else if (op->same_type<InputPlaceholder>()) {
            ++ nr_backward_graph_inputs;
            push(op->output(0));
        } else {
            SmallVector<size_t> inputs, outputs;
            for (auto &&i : op->input()) {
                if (i->owner_opr() == fwd) {
                    if (var2idx.find(i) == var2idx.end()) {
                        ++ nr_backward_graph_inputs;
                        push(i);
                    }
                }
                inputs.push_back(var2idx.at(i));
            }
            for (auto &&i : op->usable_output()) {
                outputs.push_back(push(i));
            }
            igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
        }
    };

    // set backward graph outputs
    cg::DepOprIter iter{gen_expr};
    iter.set_visited(fwd);
    result.output_mask.resize(inputs.size());

    VarNodeArray output_grads_with_unused_var;
    {
        auto iter = output_grads.begin();
        for (auto&& i : fwd->output()) {
            if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                // the var node with VOLATILE_CONTENT(e.g. workspace
                // or an empty var) would not be considered as a normal
                // output, so its grad is always NULL
                output_grads_with_unused_var.push_back(nullptr);
            } else {
                output_grads_with_unused_var.push_back(*iter);
                ++ iter;
            }
        }
        mgb_assert(iter == output_grads.end());
    }

    Maybe<VarNodeArray> grad_results;
    for (size_t i = 0; i < inputs.size(); ++ i) {
        VarNode* grad;
        if (grad_results.valid()) {
            grad = grad_results.val()[i];
        } else {
            mgb_assert(gfunc, "could not find grad function");
            auto res = (*gfunc)(fwd, i, output_grads_with_unused_var);
            if (res.from_single()) {
                grad = res.single();
            } else {
                grad_results.emplace(res.all(fwd));
                grad = grad_results.val()[i];
            }
        }
        // A missing grad, an InvalidGrad (gradient not defined) or an input
        // that does not require grad are all simply masked out.
        if (grad && !grad->owner_opr()->same_type<opr::InvalidGrad>()
                && input_requires_grad[i]) {
            iter.add(grad);
            igraph.outputs.push_back(var2idx.at(grad));
            result.output_mask[i] = true;
        } else {
            result.output_mask[i] = false;
        }
    }
    if (igraph.outputs.empty()) {
        return {};
    }

    // set backward graph inputs
    igraph.inputs.reserve(nr_backward_graph_inputs);
    result.input_mask.reserve(nr_backward_graph_inputs);
    auto write_inputs = [&igraph, &var2idx, &result](const VarNodeArray& vars) {
        for (auto&& i: vars) {
            auto&& iter = var2idx.find(i);
            if (iter != var2idx.end()) {
                igraph.inputs.push_back(iter->second);
                result.input_mask.push_back(true);
            } else {
                result.input_mask.push_back(false);
            }
        }
    };
    write_inputs(inputs);
    write_inputs(outputs);
    write_inputs(output_grads);
    mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);
    return result;
}

//! Apply opdef on InputPlaceholders built from logical descriptors.
cg::OperatorNodeBase* ProxyGraph::get_proxy_opr(const OpDef& opdef,
        const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(!m_cur_opr);
    auto vinputs = make_input_place_holders(inputs);
    return OpDef::apply_on_var_node(opdef, vinputs)[0]->owner_opr();
}

VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTensorDesc>& inputs) {
    VarNodeArray vinputs(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++ i) {
        vinputs[i] = InputPlaceholder::make(*m_graph, inputs[i]).node();
    }
    return vinputs;
}

/*********************** Common Impl ***********************/

/*!
 * Infer shapes of all outputs of m_cur_opr and store them on the vars.
 * With sync_value, failure asserts; otherwise returns false if any output
 * shape could not be inferred.
 */
bool ProxyGraph::do_shape_infer(bool sync_value) {
    m_static_infer_manager->update();

    bool validated = true;
    for (auto* var : m_cur_opr->output()) {
        if (sync_value) {
            var->shape(m_static_infer_manager->infer_shape(var));
        } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) {
            var->shape(*shape);
        } else {
            validated = false;
        }
    }
    return validated;
}

/*!
 * Convert a constant-like operator into a Tensor: ImmutableTensor yields its
 * device value plus a host copy; SharedDeviceTensor shares its device
 * tensor. Returns an empty pointer for any other operator type.
 */
TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) {
    // TODO : maybe some tensor should copy value from origin opr rather than
    // share the RawStorage
    mgb_assert(share, "can't share memory with opr %s", opr->cname());
    if (opr->same_type<opr::ImmutableTensor>()) {
        auto&& dv = opr->cast_final_safe<opr::ImmutableTensor>().value();
        HostTensorND hv(dv.comp_node(), dv.shape(), dv.dtype());
        const DeviceTensorND* cpu_value;
        // get host value
        if (opr->owner_graph() == m_graph.get()) {
            CUR_OPR_GUARD(opr);
            m_static_infer_manager->update();
            cpu_value = m_static_infer_manager->infer_value_fallible(opr->output(0));
        } else {
            cpu_value = opr->owner_graph()->static_infer_manager().infer_value_fallible(opr->output(0));
        }
        mgb_assert(cpu_value);
        mgb_assert(cpu_value->comp_node() == CompNode::default_cpu());
        // default_cpu is synchronous with respect to caller
        hv.proxy_to_default_cpu().copy_from_fixlayout(*cpu_value);
        return Tensor::make(dv, hv);
    } else if (opr->same_type<opr::SharedDeviceTensor>()) {
        return Tensor::make(opr->cast_final_safe<opr::SharedDeviceTensor>().get_dev_tensor());
    } else {
        return {};
    }
}

thread_local std::unique_ptr<MegBrainError> ProxyGraph::tm_async_error;

} // namespace imperative
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}