/**
 * \file src/opr/include/megbrain/opr/tensor_manip.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/graph.h"
#include "megbrain/opr/internal/identical_fwd.h"
#include "megbrain/opr/internal/indexing_helper.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/opr/param_defs.h"

#include "megdnn/oprs.h"

namespace mgb {
namespace opr {

/*!
 * \brief get the shape of a var and store in a 1-dim tensor
 *
 * For multiple inputs, shape would be the broadcasted shape.
 *
 * \param axis if specified, output the shape of only this single axis
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(GetVarShape, cg::SingleCNOperatorNodeBase) // {
    class ShapeDevValueExecDep;

public:
    using Param = megdnn::param::OptionalAxisV1;

    GetVarShape(
            const VarNodeArrayView& inp, Param axis, const OperatorNodeConfig& config);

    static SymbolVar make(
            SymbolVar inp, Param axis = {}, const OperatorNodeConfig& config = {}) {
        return make(SymbolVarArray({inp}), axis, config);
    }

    //! get broadcasted shape
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            const VarNodeArrayView& inp, Param axis = {},
            const OperatorNodeConfig& config = {});

    Param param() const { return m_axis; }

private:
    const Param m_axis;

    //! cached shape, to avoid h2d copy when shape not changed
    bool m_cached_shape_dev_v_synced = false;
    TensorShape m_cached_shape;
    TensorShapeArray m_src_shapes;
    DeviceTensorND m_cached_shape_cpu_v{CompNode::default_cpu()}, m_cached_shape_dev_v;

    //! update m_cached_shape from m_src_shapes
    MGE_WIN_DECLSPEC_FUC void update_cached_shape();

    //! update m_cached_shape for static infer
    MGE_WIN_DECLSPEC_FUC void update_for_static_infer(
            const cg::static_infer::InpVal& inp);

    MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;
    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC void record_execute_deps(ExecDependencyArray& deps) override;
};
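/*
 * Minimal usage sketch for GetVarShape (not part of the original header; `x`
 * and `y` are hypothetical SymbolVars):
 *
 *      // 1-dim tensor holding the full shape of x
 *      auto shp = opr::GetVarShape::make(x);
 *      // extent of axis 0 of the shape broadcasted from x and y
 *      auto dim0 = opr::GetVarShape::make({x, y}, {0});
 */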
namespace intl {

/*!
 * \brief base class for reshape and broadcast
 */
MGB_DEFINE_CLS_WITH_SUPER(
        ReshapeBrdcastHelper, ReadonlyFwdHelper<OutshapeBySymvarSCNOprBase>) // {
    bool m_incompatible_inp_layout = false;

    MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override final;
    MGE_WIN_DECLSPEC_FUC void outshape_by_symvar_do_get_output_shape(
            TensorShape& dest, const ShapeInferInfo& shpinfo) override final;
    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;
    MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override final;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;

protected:
    using Super::Super;

    void reshapebrdcast_init(VarNode* inp, VarNode* tshp);

    /*!
     * \brief get dest layout
     *
     * Invalid TensorLayout can be returned if reshape fails
     */
    virtual Maybe<TensorLayout> reshapebrdcast_get_dest_layout(
            const TensorLayout& src, const TensorShape& tshape) const = 0;

    /*!
     * \brief whether output shape depends on input shape; if true,
     *      reshapebrdcast_get_dest_layout() would be called to get output
     *      shape; otherwise output shape would be value of input(1)
     */
    virtual bool reshapebrdcast_output_shape_need_input_shape() const = 0;
};

}  // namespace intl

/*!
 * \brief reshape a tensor in-place, without changing total span
 * \param unspec_axis the axis whose shape is not specified in the target
 *      shape, but should be calculated from the total number of elements and
 *      the other dims of the dest shape
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Reshape, intl::ReshapeBrdcastHelper) // {
public:
    using Param = megdnn::param::OptionalAxisV1;

    MGE_WIN_DECLSPEC_FUC Reshape(
            VarNode* inp, VarNode* tshp, Param unspec_axis,
            const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, SymbolVar tshp, Param unspec_axis = {},
            const OperatorNodeConfig& config = {});

    static SymbolVar make(
            SymbolVar inp, const TensorShape& target_shape, Param unspec_axis = {},
            const OperatorNodeConfig& config = {}) {
        return make(
                inp, cg::var_from_tensor_shape(inp, target_shape), unspec_axis, config);
    }

    Param param() const { return m_unspec_axis; }

private:
    Param m_unspec_axis;

    Maybe<TensorLayout> reshapebrdcast_get_dest_layout(
            const TensorLayout& src, const TensorShape& tshape) const override;

    bool reshapebrdcast_output_shape_need_input_shape() const override;
};

/*!
 * \brief broadcast tensor value along axes whose shape is 1
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Broadcast, intl::ReshapeBrdcastHelper) // {
    Maybe<TensorLayout> reshapebrdcast_get_dest_layout(
            const TensorLayout& src, const TensorShape& tshape) const override;

    bool reshapebrdcast_output_shape_need_input_shape() const override;

public:
    MGE_WIN_DECLSPEC_FUC Broadcast(
            VarNode* inp, VarNode* tshp, const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, SymbolVar tshp, const OperatorNodeConfig& config = {});

    static SymbolVar make(
            SymbolVar inp, const TensorShape& target_shape,
            const OperatorNodeConfig& config = {}) {
        return make(inp, cg::var_from_tensor_shape(inp, target_shape), config);
    }

    // used for serialization
    using Param = megdnn::param::Empty;
    Param param() const { return {}; }
    static SymbolVar make(
            SymbolVar inp, SymbolVar tshp, const Param&,
            const OperatorNodeConfig& config) {
        return make(inp, tshp, config);
    }
};
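/*
 * Minimal usage sketch for Reshape / Broadcast (not part of the original
 * header; `x` is a hypothetical SymbolVar of shape (2, 3) and `y` one of
 * shape (1, 3)):
 *
 *      // plain reshape; the total number of elements must stay 6
 *      auto r = opr::Reshape::make(x, TensorShape{3, 2});
 *      // unspec_axis = 1: dim 1 is deduced as 6 / 2 = 3, giving (2, 3)
 *      auto r1 = opr::Reshape::make(x, TensorShape{2, 1}, {1});
 *      // only shape-1 axes may be expanded: (1, 3) -> (2, 3)
 *      auto b = opr::Broadcast::make(y, TensorShape{2, 3});
 */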
namespace intl {

/*!
 * \brief base class for oprs that manipulate axis
 */
MGB_DEFINE_CLS_WITH_SUPER(
        AxisManipOprBase, ReadonlyFwdHelper<cg::SingleCNOperatorNodeBase>) // {
    MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override final;
    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override final;
    MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;

protected:
    using Super::Super;

    virtual TensorLayout axis_manip_get_output_layout(
            const TensorLayout& inp_layout) const = 0;

    void axis_manip_init(VarNode* inp);
};

}  // namespace intl

/*!
 * \brief dimshuffle a tensor in-place, without changing total span
 * \param pattern non-negative integers refer to the corresponding input
 *      dimension; -1 refers to a new dimension
 * \param ndim number of input dimensions; 0 to be inferred from pattern
 *
 * Note that dimensions with shape-1 could be dropped
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Dimshuffle, intl::AxisManipOprBase) // {
    std::vector<int> m_pattern;
    size_t m_inp_ndim;

    TensorLayout axis_manip_get_output_layout(
            const TensorLayout& inp_layout) const override;

public:
    MGE_WIN_DECLSPEC_FUC Dimshuffle(
            VarNode* inp, const std::vector<int>& pattern, size_t ndim,
            const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, const std::vector<int>& pattern, size_t ndim = 0,
            const OperatorNodeConfig& config = {});

    MGE_WIN_DECLSPEC_FUC VarNode* grad(
            size_t wrt_idx, const VarNodeArray& out_grad) const;

    // used for serialization
    struct Param {
        static constexpr uint32_t TAG = param_tag::DIMSHUFFLE;
        uint32_t pattern_len;
        int32_t pattern[TensorShape::MAX_NDIM];
        uint32_t ndim;
    };

    static SymbolVar make(
            SymbolVar inp, const Param& param, const OperatorNodeConfig& config) {
        return make(
                inp, {param.pattern, param.pattern + param.pattern_len}, param.ndim,
                config);
    }

    Param param() const {
        Param ret;
        ret.pattern_len = m_pattern.size();
        std::copy(m_pattern.begin(), m_pattern.end(), ret.pattern);
        ret.ndim = m_inp_ndim;
        return ret;
    }
};

/*!
 * \brief add or remove an axis with shape 1
 *
 * All the axis descs would be processed in order
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(AxisAddRemove, intl::AxisManipOprBase) // {
public:
    struct AxisDesc {
        enum class Method {
            //! add a dim with shape 1, just before axis
            ADD_1,
            //! remove this axis, which must be shape 1
            REMOVE
        };
        Method method;
        indexing::AxisNum axis;

        static AxisDesc make_add(indexing::AxisNum axis) {
            AxisDesc r;
            r.axis = axis;
            r.method = Method::ADD_1;
            return r;
        }

        static AxisDesc make_remove(indexing::AxisNum axis) {
            AxisDesc r;
            r.axis = axis;
            r.method = Method::REMOVE;
            return r;
        }
    };

    MGE_WIN_DECLSPEC_FUC AxisAddRemove(
            VarNode* inp, const std::vector<AxisDesc>& desc,
            const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, const std::vector<AxisDesc>& desc,
            const OperatorNodeConfig& config = {});

    // used for serialization
    struct Param {
        static constexpr uint32_t TAG = param_tag::AXIS_ADD_REMOVE,
                                  MAX_DESC_SIZE = TensorShape::MAX_NDIM * 2;
        uint32_t nr_desc;
        AxisDesc desc[MAX_DESC_SIZE];
    };

    static SymbolVar make(
            SymbolVar inp, const Param& param, const OperatorNodeConfig& config) {
        return make(inp, {param.desc, param.desc + param.nr_desc}, config);
    }

    Param param() const {
        mgb_assert(m_desc.size() <= Param::MAX_DESC_SIZE);
        Param ret;
        ret.nr_desc = m_desc.size();
        std::copy(m_desc.begin(), m_desc.end(), ret.desc);
        return ret;
    }

private:
    std::vector<AxisDesc> m_desc;

    MGE_WIN_DECLSPEC_FUC TensorLayout axis_manip_get_output_layout(
            const TensorLayout& inp_layout) const override;
};
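/*
 * Minimal usage sketch for Dimshuffle / AxisAddRemove (not part of the
 * original header; `x` is a hypothetical SymbolVar of shape (2, 3)):
 *
 *      // transpose: (2, 3) -> (3, 2)
 *      auto t = opr::Dimshuffle::make(x, {1, 0});
 *      // -1 inserts a new shape-1 dim: (2, 3) -> (3, 1, 2)
 *      auto t1 = opr::Dimshuffle::make(x, {1, -1, 0});
 *      // equivalent axis insertion via AxisAddRemove: (2, 3) -> (1, 2, 3)
 *      using Desc = opr::AxisAddRemove::AxisDesc;
 *      auto a = opr::AxisAddRemove::make(x, {Desc::make_add(0)});
 */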
namespace intl {

MGB_DEFINE_CLS_WITH_SUPER(ModifySubtensorImplHelper, FancyIndexingHelper) // {
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override final;
    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override final;

    /*!
     * \brief implement the actual modification
     *
     * Note that this method may be used both for exec and static value infer
     *
     * \param sub a view of the dest subtensor on target tensor
     */
    virtual void modify(DeviceTensorND& sub, const DeviceTensorND& val) = 0;

protected:
    using Super::Super;
};

}  // namespace intl

/*!
 * \brief get subtensor in a python-like way
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        Subtensor, intl::ReadonlyFwdHelper<intl::FancyIndexingHelper>) // {
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
    MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override;
    MGE_WIN_DECLSPEC_FUC void init_rt_force_dynamic_mem_alloc_imply_chain() override;
    MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;

public:
    MGE_WIN_DECLSPEC_FUC Subtensor(
            VarNode* inp, const IndexDesc& desc, const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, const IndexDesc& desc,
            const OperatorNodeConfig& config = {});
};

/*!
 * \brief replace the value of subtensor by another tensor
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(SetSubtensor, intl::ModifySubtensorImplHelper) // {
    void modify(DeviceTensorND& sub, const DeviceTensorND& val) override;
    NodeProp* do_make_node_prop() const override;

public:
    MGB_DECL_FANCY_INDEXING_OPR_MODIFY(SetSubtensor);
};

/*!
 * \brief increase the value of subtensor by another tensor
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(IncrSubtensor, intl::ModifySubtensorImplHelper) // {
    void modify(DeviceTensorND& sub, const DeviceTensorND& val) override;

public:
    MGB_DECL_FANCY_INDEXING_OPR_MODIFY(IncrSubtensor);
};

class IndexAt {
public:
    /*!
     * \brief helper for Subtensor with only index
     * \param index list of pairs of (axis, index)
     */
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar inp, const std::vector<std::pair<size_t, SymbolVar>>& index,
            const OperatorNodeConfig& config = {});
};
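/*
 * Minimal usage sketch for Subtensor / SetSubtensor (not part of the original
 * header; `x` and `v` are hypothetical SymbolVars, and the AxisIndexer
 * spelling is assumed to follow indexing_helper.h):
 *
 *      using AIdx = opr::indexing::AxisIndexer;
 *      // x[1:3] on axis 0
 *      auto sub = opr::Subtensor::make(
 *              x, {AIdx::make_interval(
 *                         0, x.make_scalar(1), x.make_scalar(3), None)});
 *      // x[1:3] = v, yielding the modified tensor as a new var
 *      auto set = opr::SetSubtensor::make(
 *              x, v, {AIdx::make_interval(
 *                            0, x.make_scalar(1), x.make_scalar(3), None)});
 */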
/*!
 * \brief split a tensor along one axis, possibly to different computing nodes
 *
 * Note that the computing nodes could be specified in one of the following
 * ways:
 * 1. If omitted in OperatorNodeConfig, it would be inferred from input
 * 2. Specify one comp_node in OperatorNodeConfig, and all output would reside
 *    on this comp_node
 * 3. Specify comp_node for each output in OperatorNodeConfig
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(Split, intl::OutshapeBySymvarOprBase) // {
public:
    struct Options {
        enum class Method {
            SPECIFY,   //!< specify output sizes
            CALL_BACK  //!< output sizes obtained from callback
        };
        Method method;
        size_t nr_part = 0;
        int axis = 0;

        using callback_t = thin_function<std::vector<size_t>(size_t tot_size)>;
        callback_t callback;

        SymbolVarArray partition;

        /*!
         * \brief make split option that splits into equally sized parts
         */
        static Options make_average(int axis, size_t nr_part);

        static Options make_partition(int axis, const SymbolVarArray& partition);

        static Options make_partition(
                SymbolVar inp, int axis, const std::vector<size_t>& partition);

        static Options make_callback(int axis, size_t nr_part, callback_t callback);
    };

    Split(VarNode* inp, const Options& opt, const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
            SymbolVar inp, Options opt, const OperatorNodeConfig& config = {});

    const Options& options() const { return m_opt; }

private:
    struct OutputSpec {
        TensorShape shape;  //!< recent inferred shape
        bool mem_fwd_success = false;
        SubTensorSpec subspec;
    };

    bool m_readonly_fwd_called = false;
    std::vector<OutputSpec> m_output_spec;
    Options m_opt;
    size_t m_output_shape_version = 0;

    MGE_WIN_DECLSPEC_FUC void init_output_comp_node() override;
    MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;
    MGE_WIN_DECLSPEC_FUC void do_execute(ExecEnv& env) override;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC void outshape_by_symvar_do_get_output_shape(
            TensorShape& dest, const ShapeInferInfo& shpinfo) override;
    MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override;
    MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;
    MGE_WIN_DECLSPEC_FUC bool infer_shape(
            size_t out_idx, TensorShape& dest, const cg::static_infer::InpVal& inp);
    MGE_WIN_DECLSPEC_FUC void on_mem_status_changed();
    MGE_WIN_DECLSPEC_FUC OprEventCallback get_opr_event_callback() override final;
    MGE_WIN_DECLSPEC_FUC void init_subspec(bool memfwd);
    MGE_WIN_DECLSPEC_FUC void on_output_comp_node_stream_changed() override;
    MGE_WIN_DECLSPEC_FUC void init_rt_force_dynamic_mem_alloc_imply_chain() override;
};

/*!
 * \brief concat tensors along one axis
 *
 * To concat to a different computing node, specify the destination in
 * OperatorNodeConfig
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        Concat, cg::SingleCNOutshapePureByInshapeOprBase) // {
public:
    using Param = megdnn::param::Axis;

    Concat(const VarNodeArrayView& inp, int axis, const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            const VarNodeArrayView& inp, int axis,
            const OperatorNodeConfig& config = {});

    //! for deserialization
    static SymbolVar make(
            const VarNodeArrayView& inp, const Param& param,
            const OperatorNodeConfig& config) {
        return make(inp, static_cast<int>(param.axis), config);
    }

    //! get axis for this concat
    int axis() const { return m_axis; }

    Param param() const { return m_axis; }

private:
    int m_axis;

    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
    MGE_WIN_DECLSPEC_FUC NodeProp* do_make_node_prop() const override;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;
    MGE_WIN_DECLSPEC_FUC void init_output_comp_node() override;
    MGE_WIN_DECLSPEC_FUC void get_output_var_shape(
            const TensorShapeArray& inp_shape,
            TensorShapeArray& out_shape) const override;
};
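/*
 * Minimal usage sketch for Split / Concat (not part of the original header;
 * `x` and `y` are hypothetical SymbolVars of shape (4, 6)):
 *
 *      // three equal parts of shape (4, 2) along axis 1
 *      auto parts = opr::Split::make(
 *              x, opr::Split::Options::make_average(1, 3));
 *      // concat along axis 0: (4, 6) + (4, 6) -> (8, 6)
 *      auto cat = opr::Concat::make({x, y}, 0);
 */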
/*!
 * \brief opr used to pack parameters; all input nodes must be on the same
 *      device and have the same dtype, while their shapes need not be the same
 * \param offsets size of 2 * inputs.size();
 *      offsets[i * 2] and offsets[i * 2 + 1] give the begin and the end of
 *      inputs[i]'s offsets in the output
 * \param offsets_val offsets value on cpu
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(ParamPackConcat, cg::SingleCNOperatorNodeBase) // {
    //! input pointer buffer
    SmallVector<void*> m_inp_ptr;
    std::vector<dt_int32> m_offsets;
    intl::UniqPtrWithCN<megdnn::ParamPackConcat> m_opr;

    MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;
    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC void init_output_dtype() override;
    MGE_WIN_DECLSPEC_FUC void on_output_comp_node_stream_changed() override;

public:
    using Param = megdnn::param::Empty;
    Param param() const { return {}; }

    MGE_WIN_DECLSPEC_FUC ParamPackConcat(
            VarNodeArray& inp, VarNode* offsets,
            const std::vector<dt_int32> offsets_val,
            const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            const SmallVector<SymbolVar>& inp, const SymbolVar& offsets,
            const std::vector<dt_int32> offsets_val,
            const OperatorNodeConfig& config = {});

    static SymbolVar make(
            const SmallVector<SymbolVar>& inp, const SymbolVar& offsets,
            const std::vector<dt_int32> offsets_val, const Param&,
            const OperatorNodeConfig& config) {
        return make(inp, offsets, offsets_val, config);
    }

    const std::vector<dt_int32>& get_offsets() const { return m_offsets; }
};

/*!
 * \brief opr used to split a packed parameter
 * \param offsets size of 2 * outputs.size();
 *      offsets[i * 2] and offsets[i * 2 + 1] give the begin and the end of
 *      outputs[i]'s offsets in the input
 * \param offsets_val offsets value on cpu
 * \param shapes shape of each output
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(ParamPackSplit, cg::SingleCNOperatorNodeBase) // {
    TensorShapeArray m_shapes;
    std::vector<dt_int32> m_offsets;

    MGE_WIN_DECLSPEC_FUC void scn_do_execute() override;
    MGE_WIN_DECLSPEC_FUC void init_output_static_infer_desc() override;
    MGE_WIN_DECLSPEC_FUC bool infer_shape(
            size_t index, TensorShape& dest, const cg::static_infer::InpVal& inp);
    MGE_WIN_DECLSPEC_FUC void init_output_dtype() override;
    MGE_WIN_DECLSPEC_FUC void mem_plan_fwd_in2out_readonly() override;
    MGE_WIN_DECLSPEC_FUC void add_input_layout_constraint() override;

public:
    MGE_WIN_DECLSPEC_FUC ParamPackSplit(
            VarNode* src, const std::vector<dt_int32> offsets,
            TensorShapeArray& shapes, const OperatorNodeConfig& config);

    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
            const SymbolVar& src, const std::vector<dt_int32> offsets,
            TensorShapeArray shapes, const OperatorNodeConfig& config = {});

    const std::vector<dt_int32>& get_offsets() const { return m_offsets; }

    const TensorShapeArray& get_output_shapes() const { return m_shapes; }

    MGE_WIN_DECLSPEC_FUC void init_rt_force_dynamic_mem_alloc_imply_chain() override;
};
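/*
 * Worked example of the offsets layout shared by ParamPackConcat and
 * ParamPackSplit (the concrete numbers are illustrative only): packing two
 * tensors of 3 and 4 elements with one element of alignment padding between
 * them gives a flat buffer of 8 elements and
 *
 *      offsets_val = {0, 3,    // tensor 0 occupies the range [0, 3)
 *                     4, 8};   // tensor 1 occupies the range [4, 8)
 *
 * i.e. offsets_val[i * 2] / offsets_val[i * 2 + 1] are the begin / end of the
 * i-th tensor inside the packed buffer.
 */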
/*!
 * \brief change the tensor layout to adapt to new format
 *
 * See docs of megdnn params for more details
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        RelayoutFormat, intl::MegDNNOprWrapperFwd<megdnn::RelayoutFormat>) // {
public:
    MGE_WIN_DECLSPEC_FUC RelayoutFormat(
            VarNode* src, const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, const Param& param, const OperatorNodeConfig& config = {});
    MGE_WIN_DECLSPEC_FUC void init_output_format() override final;
};

/*!
 * \brief pad the src tensor to the dst tensor
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        PaddingForward, intl::MegDNNOprWrapperFwd<megdnn::Padding>) // {
public:
    MGE_WIN_DECLSPEC_FUC PaddingForward(
            VarNode* src, const Param& param, const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, const Param& param = {},
            const OperatorNodeConfig& config = {});
};
using Padding = PaddingForward;

/*!
 * \brief padding backward
 */
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
        PaddingBackward, intl::MegDNNOprWrapperBwd<megdnn::PaddingBackward>) // {
public:
    MGE_WIN_DECLSPEC_FUC PaddingBackward(
            VarNode* src, VarNode* in_for_shape, const Param& param,
            const OperatorNodeConfig& config);
    MGE_WIN_DECLSPEC_FUC static SymbolVar make(
            SymbolVar src, SymbolVar in_for_shape, const Param& param = {},
            const OperatorNodeConfig& config = {});
};

}  // namespace opr
}  // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}