/**
 * \file src/opr/impl/tensor_manip.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/internal/indexing_helper_sereg.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/sereg.h"

#if MGB_ENABLE_FBS_SERIALIZATION
#include "megbrain/serialization/internal/mgb_cpp_opr_generated.h"
#endif

MGB_SEREG_GET_SUBTENSOR_OPR(Subtensor);
MGB_SEREG_MODIFY_SUBTENSOR_OPR(SetSubtensor);
MGB_SEREG_MODIFY_SUBTENSOR_OPR(IncrSubtensor);

namespace mgb {

namespace serialization {

//! Maker for opr::Padding: accepts exactly one input var.
template <>
struct OprMaker {
    using Opr = opr::Padding;
    using Param = Opr::Param;

    /*!
     * \brief build the operator from \p param and \p inputs
     * \return owner operator of the created var, or nullptr when the number
     *      of inputs is not 1
     */
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (inputs.size() == 1) {
            return Opr::make(inputs[0], param, config).node()->owner_opr();
        } else {
            return nullptr;
        }
    }
};

//! Maker for opr::PaddingBackward: accepts exactly two input vars.
template <>
struct OprMaker {
    using Opr = opr::PaddingBackward;
    using Param = Opr::Param;

    /*!
     * \brief build the operator from \p param and \p inputs
     * \return owner operator of the created var, or nullptr when the number
     *      of inputs is not 2
     */
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, config).node()->owner_opr();
        } else {
            return nullptr;
        }
    }
};

//! Maker that defers entirely to OprMakerVariadic.
template <>
struct OprMaker : public OprMakerVariadic {};

//! Maker that defers entirely to OprMakerVariadic.
template <>
struct OprMaker : public OprMakerVariadic {};

//! Custom dump/load for opr::Split; only the SPECIFY method is supported.
template <>
struct OprLoadDumpImpl {
    using Split = opr::Split;
    using Options = Split::Options;
    using Method = Options::Method;

    /*!
     * \brief write the split axis of \p opr_ to \p ctx
     *
     * Asserts that the operator was built with Method::SPECIFY; only the axis
     * is written here.
     */
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe();
        auto&& opt = opr.options();
        mgb_assert(
                opt.method == Method::SPECIFY,
                "only Split with SPECIFY output shapes can be serialized");
        ctx.write_param(opt.axis);
    }

    /*!
     * \brief rebuild a Split operator from \p ctx and \p inputs
     *
     * inputs[0] is the var to be split; inputs[1..] become the partition
     * entries, so at least two inputs are required.
     *
     * \return owner operator of the first output of the rebuilt Split
     */
    static cg::OperatorNodeBase* load(
            OprLoadContext& ctx, const cg::VarNodeArray& inputs,
            const OperatorNodeConfig& config) {
        auto param = ctx.read_param();
        opr::Split::Options opt;
        opt.method = Method::SPECIFY;
        opt.axis = param.axis;
        mgb_assert(inputs.size() > 1);
        opt.nr_part = inputs.size() - 1;
        opt.partition.resize(opt.nr_part);
        for (size_t i = 1; i < inputs.size(); ++i)
            opt.partition[i - 1] = inputs[i];
        return Split::make(inputs[0], opt, config)[0].node()->owner_opr();
    }
};

#if MGB_ENABLE_FBS_SERIALIZATION
namespace fbs {

//! Conversion between opr::Dimshuffle::Param and its flatbuffer table.
template <>
struct ParamConverter {
    using FlatBufferType = param::Dimshuffle;

    /*!
     * \brief read ndim and the pattern from \p fb
     *
     * A missing pattern vector yields pattern_len == 0. The pattern length is
     * asserted against the capacity of Param::pattern before copying.
     */
    static opr::Dimshuffle::Param to_param(const FlatBufferType* fb) {
        opr::Dimshuffle::Param param;
        param.ndim = fb->ndim();
        if (fb->pattern()) {
            param.pattern_len = fb->pattern()->size();
            mgb_assert(
                    param.pattern_len <=
                    sizeof(param.pattern) / sizeof(param.pattern[0]));
            memcpy(param.pattern, fb->pattern()->data(),
                   sizeof(param.pattern[0]) * param.pattern_len);
        } else {
            param.pattern_len = 0;
        }
        return param;
    }

    //! write the first pattern_len pattern entries and ndim of \p p
    static flatbuffers::Offset to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder, const opr::Dimshuffle::Param& p) {
        return param::CreateDimshuffle(
                builder, builder.CreateVector(p.pattern, p.pattern_len), p.ndim);
    }
};

//! Conversion between opr::AxisAddRemove::Param and its flatbuffer table.
template <>
struct ParamConverter {
    using FlatBufferType = param::AxisAddRemove;

    /*!
     * \brief read the axis descriptors from \p fb
     *
     * A missing desc vector yields nr_desc == 0.
     */
    static opr::AxisAddRemove::Param to_param(const FlatBufferType* fb) {
        opr::AxisAddRemove::Param param;
        if (fb->desc()) {
            param.nr_desc = fb->desc()->size();
            // nr_desc comes from serialized data: reject counts that would
            // write past the end of Param::desc (same guard as the Dimshuffle
            // converter applies to its pattern array)
            mgb_assert(param.nr_desc <= sizeof(param.desc) / sizeof(param.desc[0]));
            for (uint32_t i = 0; i < param.nr_desc; i++) {
                param.desc[i].axis = fb->desc()->Get(i)->axis();
                param.desc[i].method = static_cast(
                        fb->desc()->Get(i)->method());
            }
        } else {
            param.nr_desc = 0;
        }
        return param;
    }

    //! write the first nr_desc descriptors of \p p as (method, raw axis) pairs
    static flatbuffers::Offset to_flatbuffer(
            flatbuffers::FlatBufferBuilder& builder,
            const opr::AxisAddRemove::Param& p) {
        std::vector desc(p.nr_desc);
        for (uint32_t i = 0; i < p.nr_desc; i++) {
            desc[i] = {
                    static_cast(p.desc[i].method),
                    p.desc[i].axis.get_raw()};
        }
        return param::CreateAxisAddRemoveDirect(builder, &desc);
    }
};

}  // namespace fbs
#endif

}  // namespace serialization

namespace opr {

MGB_SEREG_OPR(Broadcast, 2);
MGB_SEREG_OPR(Dimshuffle, 1);
MGB_SEREG_OPR(AxisAddRemove, 1);
MGB_SEREG_OPR(Concat, 0);
using GetVarShapeV1 = opr::GetVarShape;
MGB_SEREG_OPR(GetVarShapeV1, 0);
using ReshapeV1 = opr::Reshape;
MGB_SEREG_OPR(ReshapeV1, 2);

/*!
 * \brief shallow copy of a Split operator onto new \p inputs
 *
 * The options of \p opr_ are copied. For Method::CALL_BACK exactly one input
 * is expected; for Method::SPECIFY the partition entries are replaced by
 * inputs[1..], whose count must equal the existing partition size.
 *
 * \return owner operator of the first output of the new Split
 */
cg::OperatorNodeBase* opr_shallow_copy_split(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto&& opr = opr_.cast_final_safe();
    auto option = opr.options();
    using Meth = Split::Options::Method;
    switch (option.method) {
        case Meth::CALL_BACK:
            mgb_assert(inputs.size() == 1);
            break;
        case Meth::SPECIFY:
            mgb_assert(inputs.size() == 1 + option.partition.size());
            for (size_t i = 0; i < option.partition.size(); ++i)
                option.partition[i] = inputs[i + 1];
            break;
    }
    return Split::make(inputs[0], option, config).at(0).node()->owner_opr();
}
MGB_SEREG_OPR(Split, 0);
MGB_REG_OPR_SHALLOW_COPY(Split, opr_shallow_copy_split);

/*!
 * \brief shallow copy of a ParamPackSplit operator onto new \p inputs
 *
 * Reuses the offsets and output shapes of \p opr_ and takes inputs[0] as the
 * new source var.
 *
 * \return owner operator of the first output of the new ParamPackSplit
 */
cg::OperatorNodeBase* opr_shallow_copy_param_pack_split(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto&& opr = opr_.cast_final_safe();
    auto&& offsets = opr.get_offsets();
    auto&& shape = opr.get_output_shapes();

    return ParamPackSplit::make(inputs[0], offsets, shape, config)
            .at(0)
            .node()
            ->owner_opr();
}

MGB_REG_OPR_SHALLOW_COPY(ParamPackSplit, opr_shallow_copy_param_pack_split);

/*!
 * \brief shallow copy of a ParamPackConcat operator onto new \p inputs
 *
 * All inputs except the last are forwarded as the vars to concatenate; the
 * last input is passed separately to ParamPackConcat::make together with the
 * offsets of \p opr_. \p inputs must therefore be non-empty.
 *
 * \return owner operator of the new ParamPackConcat output
 */
cg::OperatorNodeBase* opr_shallow_copy_param_pack_concat(
        const serialization::OprShallowCopyContext& ctx,
        const cg::OperatorNodeBase& opr_, const VarNodeArray& inputs,
        const OperatorNodeConfig& config) {
    MGB_MARK_USED_VAR(ctx);
    auto&& opr = opr_.cast_final_safe();
    auto&& offsets = opr.get_offsets();

    // guard the size() - 1 below against unsigned wrap-around and the
    // inputs.back() access against an empty array
    mgb_assert(!inputs.empty());
    const size_t nr_src = inputs.size() - 1;
    SymbolVarArray ivar{nr_src};
    for (size_t i = 0; i < nr_src; ++i)
        ivar[i] = inputs[i];
    return ParamPackConcat::make(ivar, inputs.back(), offsets, config)
            .node()
            ->owner_opr();
}

MGB_REG_OPR_SHALLOW_COPY(ParamPackConcat, opr_shallow_copy_param_pack_concat);

using RelayoutFormatV1 = opr::RelayoutFormat;
MGB_SEREG_OPR(RelayoutFormatV1, 1);

MGB_SEREG_OPR(Padding, 1);
MGB_SEREG_OPR(PaddingBackward, 2);

}  // namespace opr

}  // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}