blas.sereg.h 3.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12
/**
 * \file src/opr/impl/blas.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/blas.h"
13
#include "megbrain/opr/param_defs.h"
14
#include "megbrain/serialization/sereg.h"
15 16
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/linalg.h"
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32

namespace mgb {
namespace serialization {

template <>
struct OprMaker<opr::SVD, 1> {
    using Param = opr::SVD::Param;

    //! Rebuild an SVD opr from its single input var \p i[0].
    //!
    //! SVD::make yields several output vars (indexed below); every one of them
    //! is owned by the same operator, so the owner of output 0 is returned.
    //! \p graph is not needed here and is only marked as used.
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& i,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        auto&& svd_outputs = opr::SVD::make(i[0], param, config);
        VarNode* first_output = svd_outputs[0].node();
        return first_output->owner_opr();
    }
};

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
template <class MegDNNMatrixMul = megdnn::MatrixMul>
struct MakeMatrixMulCaller {
    //! Construct a matrix-mul style opr of type \p Opr and return its output
    //! var.
    //!
    //! \return nullptr when \p inputs does not hold exactly two vars; callers
    //!         are expected to check the result.
    template <typename Opr>
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const typename MegDNNMatrixMul::Param& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        if (inputs.size() != 2) {
            return nullptr;
        }
        auto&& lhs = inputs[0];
        auto&& rhs = inputs[1];
        return Opr::make(lhs, rhs, param, execution_policy, config).node();
    }
};

//! Shared serialization logic for matrix-mul style oprs.
//!
//! \tparam Opr              the mgb operator class that is dumped/loaded
//! \tparam Maker            factory whose templated `make<Opr>()` builds the
//!                          opr from its input vars; a nullptr result is
//!                          treated as a failure
//! \tparam MegDNNMatrixMul  matching megdnn operator; not referenced in this
//!                          struct (NOTE(review): presumably kept so all
//!                          specializations share one template shape)
template <class Opr, class Maker, class MegDNNMatrixMul>
struct MatrixMulLoadDumpImpl {
    //! Serialize the opr as its MatrixMul param followed by its
    //! ExecutionPolicy. load() reads them back in exactly this order, so the
    //! two must be kept in sync.
    static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr_) {
        auto&& opr = opr_.cast_final_safe<Opr>();
        ctx.write_param<megdnn::param::MatrixMul>(opr.param());
        ctx.write_param<megdnn::param::ExecutionPolicy>(opr.execution_policy());
    }

    //! Build the opr through \p Maker and return its output var.
    //! A nullptr from the maker (e.g. an unexpected number of inputs) is a
    //! hard error; the message reports how many inputs were supplied.
    static VarNode* make(const cg::VarNodeArray& inputs,
                         const megdnn::param::MatrixMul& param,
                         const megdnn::param::ExecutionPolicy& execution_policy,
                         const OperatorNodeConfig& config) {
        VarNode* ret = Maker::template make<Opr>(inputs, param,
                                                 execution_policy, config);
        mgb_assert(ret, "failed to make matrix mul opr from %zu input(s)",
                   inputs.size());
        return ret;
    }

    //! Deserialize: read the params in the order dump() wrote them, then
    //! rebuild the opr and return the operator owning its output var.
    static cg::OperatorNodeBase* load(OprLoadContext& ctx,
                                      const cg::VarNodeArray& inputs,
                                      const OperatorNodeConfig& config) {
        auto param = ctx.read_param<megdnn::param::MatrixMul>();
        auto execution_policy =
                ctx.read_param<megdnn::param::ExecutionPolicy>();
        return make(inputs, param, execution_policy, config)->owner_opr();
    }
};

// MatrixMul and BatchedMatrixMul share one load/dump implementation. Both
// serialize a MatrixMul param followed by an ExecutionPolicy, and both are
// rebuilt from exactly two input vars via MakeMatrixMulCaller.
// (Base access of a struct is public by default.)
template <>
struct OprLoadDumpImpl<opr::MatrixMul, 2>
        : MatrixMulLoadDumpImpl<opr::MatrixMul,
                                MakeMatrixMulCaller<megdnn::MatrixMul>,
                                megdnn::MatrixMul> {};

template <>
struct OprLoadDumpImpl<opr::BatchedMatrixMul, 2>
        : MatrixMulLoadDumpImpl<opr::BatchedMatrixMul,
                                MakeMatrixMulCaller<megdnn::BatchedMatrixMul>,
                                megdnn::BatchedMatrixMul> {};

89 90 91 92
}  // namespace serialization

namespace opr {

93 94 95 96
using MatrixMulV3 = MatrixMul;
using BatchedMatrixMulV3 = BatchedMatrixMul;
MGB_SEREG_OPR(MatrixMulV3, 2);
MGB_SEREG_OPR(BatchedMatrixMulV3, 2);
97 98 99 100 101 102 103 104 105 106
MGB_SEREG_OPR(Dot, 2);
MGB_SEREG_OPR(MatrixInverse, 1);
MGB_SEREG_OPR(SVD, 1);

}  // namespace opr


}  // namespace mgb

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}