// rand.sereg.h: serialization registrations for the oprs from megbrain/opr/rand.h.
#include "megbrain/opr/rand.h"
#include "megbrain/serialization/sereg.h"

namespace mgb {
namespace serialization {

template <>
struct OprMaker<opr::ShuffleRNG, 1> {
    using Opr = opr::ShuffleRNG;
    using Param = Opr::Param;
M
Megvii Engine Team 已提交
12 13 14
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& inputs, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
15 16 17 18 19
        MGB_MARK_USED_VAR(graph);
        auto out = Opr::make(inputs[0], param, config);
        return out[0].node()->owner_opr();
    }
};
20 21 22 23 24 25 26 27 28 29 30 31 32

// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::DropoutForward, 1> {
    using Param = opr::DropoutForward::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::DropoutForward::make(i[0], param, config)[0].node()->owner_opr();
    }
};

33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
template <>
struct OprMaker<opr::MultiHeadAttn, 0> {
    using Param = opr::MultiHeadAttn::Param;
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return opr::MultiHeadAttn::make(i[0], i[1], i[2], i[3], param, config)[0]
                .node()
                ->owner_opr();
    }
};

// The generic OprMaker used by MGB_SEREG_OPR only supports single-output oprs.
// MultiHeadAttnBackward::make() returns an array of output vars, so this
// specialization returns the operator that owns output 0. It is registered with
// arity 0, so the input count is checked explicitly before indexing.
template <>
struct OprMaker<opr::MultiHeadAttnBackward, 0> {
    using Param = opr::MultiHeadAttnBackward::Param;

    //! Rebuild a MultiHeadAttnBackward opr from its param and its six input vars.
    static cg::OperatorNodeBase* make(
            const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
            const OperatorNodeConfig& config) {
        // `graph` is required by the OprMaker interface but unused here.
        MGB_MARK_USED_VAR(graph);
        // Guard the i[0]..i[5] accesses below against a malformed input list.
        mgb_assert(
                i.size() == 6, "MultiHeadAttnBackward expects 6 inputs, got %zu",
                i.size());
        return opr::MultiHeadAttnBackward::make(
                       i[0], i[1], i[2], i[3], i[4], i[5], param, config)[0]
                .node()
                ->owner_opr();
    }
};

}  // namespace serialization

64 65
namespace opr {

66 67 68 69 70 71 72 73
using UniformRNGV1 = opr::UniformRNG;
MGB_SEREG_OPR(UniformRNGV1, 1);
using GaussianRNGV1 = opr::GaussianRNG;
MGB_SEREG_OPR(GaussianRNGV1, 1);
MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);
74 75
MGB_SEREG_OPR(ShuffleRNG, 1);
MGB_SEREG_OPR(ShuffleRNGBackward, 3);
76 77
MGB_SEREG_OPR(Dropout, 1);
MGB_SEREG_OPR(DropoutBackward, 2);
78 79
MGB_SEREG_OPR(MultiHeadAttn, 0);
MGB_SEREG_OPR(MultiHeadAttnBackward, 0);
80

81 82
}  // namespace opr
}  // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}