/** * \file src/core/include/megbrain/imperative.h * * This file is part of MegBrain, a deep learning framework developed by Megvii. * * \copyright Copyright (c) 2014-2019 Megvii Inc. All rights reserved. * */ #pragma once #include "megbrain/graph.h" #include "megbrain/imperative/physical_tensor.h" namespace mgb { namespace imperative { class OpDef; struct OpTrait; struct BackwardGraphResult { std::shared_ptr backward; std::vector save_for_backward; std::vector input_has_grad; }; class OpDef : public Hashable { mutable const OpTrait* m_trait = nullptr; public: virtual ~OpDef() = default; virtual std::shared_ptr copy() const = 0; static std::shared_ptr make_from_op_node( cg::OperatorNodeBase* node); static SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs); static void exec( const OpDef& def, const SmallVector& inputs, const SmallVector& outputs); static cg::OperatorNodeBase* apply_on_var_node( const OpDef& def, const VarNodeArray& inputs); static SmallVector infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs); static SmallVector infer_output_attrs( const OpDef& def, const SmallVector& inputs); static BackwardGraphResult make_backward_graph( const OpDef& def, const SmallVector& inputs, const SmallVector& input_requires_grad, const SmallVector& output_has_grad); const OpTrait* trait() const; virtual size_t hash() const { mgb_throw(MegBrainError, "not implemented"); } virtual bool is_same_st(const Hashable&) const { mgb_throw(MegBrainError, "not implemented"); } }; template class OpDefImplBase : public OpDef { public: virtual std::shared_ptr copy() const override { return std::shared_ptr(new T(this->cast_final_safe())); } template static std::shared_ptr make(const Args& ...args) { return std::shared_ptr(new T(args...)); } }; } // namespace imperative } // namespace mgb // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}