/** * \file src/core/impl/imperative/proxy_graph.h * * This file is part of MegBrain, a deep learning framework developed by Megvii. * * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. * */ #pragma once #include "megbrain/imperative.h" #include "megbrain/graph/cg.h" #include "megbrain/graph/grad_impl.h" #include "megbrain/comp_node.h" #include "megbrain/imperative/ops/backward_graph.h" namespace mgb { namespace imperative { class ProxyGraph : public NonCopyableObj { public: static ProxyGraph* get_default_graph(); /********************** Physical Tensor API **********************/ SmallVector infer_output_attrs( const OpDef& opdef, const SmallVector& inputs); void invoke_op( const OpDef& opdef, const SmallVector& inputs, const SmallVector& outputs); BackwardGraphResult make_backward_graph( const OpDef& opdef, const SmallVector& input_descs, const SmallVector& input_requires_grad, const SmallVector& output_has_grad); /********************** Logical Tensor API **********************/ size_t get_opr_output_size( const OpDef& opdef, const SmallVector& inputs); SmallVector infer_output_attrs_fallible( const OpDef& opdef, const SmallVector& inputs); private: ProxyGraph(); class ProxyGraphImpl; class ExecEnv; class StaticInferManager; class SeqCompNodeOptimizer; class InputPlaceholder; struct ProxyGraphInst; struct GradGraph; struct CurOprGuard; void reset(); /********************** Physical Tensor Helper **********************/ void cleanup(); void init_output_tensor( const SmallVector& outputs); cg::OperatorNodeBase* get_proxy_opr( const OpDef& opdef, const SmallVector& inputs); /********************** Logical Tensor Helper **********************/ cg::OperatorNodeBase* get_proxy_opr( const OpDef& opdef, const SmallVector& inputs); cg::VarNodeArray make_input_place_holders( const SmallVector& inputs); /********************** Common Helper **********************/ void do_shape_infer(bool sync_value); TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true); cg::OperatorNodeBase* m_cur_opr = nullptr; std::unique_ptr m_graph; size_t m_max_op_cnt = 1000; std::unique_ptr m_env; std::unique_ptr m_static_infer_manager; std::unique_ptr m_seq_comp_node_optimizer; }; } // namespace imperative } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}