/**
 * \file src/opr/impl/blas.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */

#include "megbrain/opr/blas.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/linalg.h"

namespace mgb {
namespace serialization {

/*!
 * \brief OprMaker specialization for opr::SVD.
 *
 * opr::SVD::make() yields several output vars; the operator that owns the
 * first of them is what gets returned.
 *
 * NOTE(review): the specialization arguments <opr::SVD, 1> were restored from
 * the body (opr::SVD::Param, i[0]) and from MGB_SEREG_OPR(SVD, 1) below --
 * confirm against the primary OprMaker template.
 */
template <>
struct OprMaker<opr::SVD, 1> {
    using Param = opr::SVD::Param;

    /*!
     * \param param SVD parameter forwarded to opr::SVD::make()
     * \param i input vars; only i[0] is read
     * \param graph not used by this maker
     * \param config operator config forwarded to opr::SVD::make()
     * \return owner operator of the first output of opr::SVD::make()
     */
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::SVD::make(i[0], param, config)[0].node()->owner_opr();
    }
};

/*!
 * \brief helper that invokes Opr::make() for a two-input matrix-mul style
 *      operator.
 *
 * \tparam MegDNNConv megdnn operator class whose nested Param type is the
 *      parameter accepted by make()
 *
 * NOTE(review): the default template argument is an assumption made while
 * restoring the stripped parameter list; every use in this file passes the
 * argument explicitly.
 */
template <class MegDNNConv = megdnn::MatrixMul>
struct MakeMatrixMulCaller {
    /*!
     * \tparam Opr operator class providing a static make()
     * \return the var node produced by Opr::make() when exactly two inputs
     *      are given; nullptr for any other input count
     */
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNConv::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() == 2) {
            return Opr::make(inputs[0], inputs[1], param, execution_policy,
                             config)
                    .node();
        }
        return nullptr;
    }
};

/*!
 * \brief shared dump/load implementation for matrix-mul style operators.
 *
 * \tparam Opr operator class being serialized
 * \tparam Maker class providing `template <typename Opr> make(...)`
 * \tparam MegDNNMatrixMul corresponding megdnn operator class; not referenced
 *      inside this template (NOTE(review): parameter name was reconstructed)
 */
template <class Opr, class Maker, class MegDNNMatrixMul>
struct MatrixMulLoadDumpImpl {
    //! write the operator's param to the dump context
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<megdnn::param::MatrixMul>(opr.param());
    }

    //! build the operator through Maker; asserts that Maker returned a
    //! non-null var
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const megdnn::param::MatrixMul& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker::template make<Opr>(inputs, param,
                                                 execution_policy, config);
        mgb_assert(ret);
        return ret;
    }

    //! read the param back and rebuild the operator; the execution policy is
    //! not read from the context, a value-initialized one is passed instead
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<megdnn::param::MatrixMul>();
        return make(inputs, param, {}, config)->owner_opr();
    }
};

//! load/dump implementation for opr::MatrixMul
template <>
struct OprLoadDumpImpl<opr::MatrixMul, 2>
        : public MatrixMulLoadDumpImpl<opr::MatrixMul,
                                       MakeMatrixMulCaller<megdnn::MatrixMul>,
                                       megdnn::MatrixMul> {};

//! load/dump implementation for opr::BatchedMatrixMul
template <>
struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
        : public MatrixMulLoadDumpImpl<
                  opr::BatchedMatrixMul,
                  MakeMatrixMulCaller<megdnn::BatchedMatrixMul>,
                  megdnn::BatchedMatrixMul> {};

}  // namespace serialization

namespace opr {

// The matrix-mul operators are registered under their "V2" aliases.
using MatrixMulV2 = MatrixMul;
using BatchedMatrixMulV2 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV2, 2);
MGB_SEREG_OPR(BatchedMatrixMulV2, 2);
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);

}  // namespace opr

}  // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}