/**
 * \file src/opr/impl/misc.sereg.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/opr/misc.h"
#include "megbrain/serialization/sereg.h"

namespace mgb {

namespace serialization {

    //! Serialization maker for Argsort: rebuilds the operator from its
    //! saved param and a single input var (inputs[0]), then hands back the
    //! operator node that owns the first output var of Opr::make.
    template <>
    struct OprMaker<opr::Argsort, 1> {
        using Opr = opr::Argsort;
        using Param = Opr::Param;

        static cg::OperatorNodeBase* make(const Param& param,
                                          const cg::VarNodeArray& inputs,
                                          ComputingGraph& graph,
                                          const OperatorNodeConfig& config) {
            // graph is not consulted here; mark it used to keep the
            // signature without unused-parameter warnings
            MGB_MARK_USED_VAR(graph);
            auto outputs = Opr::make(inputs[0], param, config);
            auto* first_var = outputs[0].node();
            return first_var->owner_opr();
        }
    };

    //! Serialization maker for CondTake: rebuilds the operator from its
    //! saved param and two input vars (inputs[0], inputs[1]), then hands
    //! back the operator node that owns the first output var of Opr::make.
    template <>
    struct OprMaker<opr::CondTake, 2> {
        using Opr = opr::CondTake;
        using Param = Opr::Param;

        static cg::OperatorNodeBase* make(const Param& param,
                                          const cg::VarNodeArray& inputs,
                                          ComputingGraph& graph,
                                          const OperatorNodeConfig& config) {
            // graph is not consulted here; mark it used to keep the
            // signature without unused-parameter warnings
            MGB_MARK_USED_VAR(graph);
            auto outputs = Opr::make(inputs[0], inputs[1], param, config);
            auto* first_var = outputs[0].node();
            return first_var->owner_opr();
        }
    };

    //! Serialization maker for TopK: rebuilds the operator from its saved
    //! param and two input vars (inputs[0], inputs[1]), then hands back the
    //! operator node that owns the first output var of Opr::make.
    template <>
    struct OprMaker<opr::TopK, 2> {
        using Opr = opr::TopK;
        using Param = Opr::Param;

        static cg::OperatorNodeBase* make(const Param& param,
                                          const cg::VarNodeArray& inputs,
                                          ComputingGraph& graph,
                                          const OperatorNodeConfig& config) {
            // graph is not consulted here; mark it used to keep the
            // signature without unused-parameter warnings
            MGB_MARK_USED_VAR(graph);
            auto outputs = Opr::make(inputs[0], inputs[1], param, config);
            auto* first_var = outputs[0].node();
            return first_var->owner_opr();
        }
    };

} // namespace serialization


namespace opr {

    //! Serialization registrations for the misc operators.
    //! NOTE(review): the second MGB_SEREG_OPR argument appears to be the
    //! operator's input arity (Argsort=1, CondTake=2, TopK=2 match the
    //! OprMaker<..., N> specializations in namespace serialization) --
    //! confirm against megbrain/serialization/sereg.h.
    MGB_SEREG_OPR(Argmax, 1);
    MGB_SEREG_OPR(Argmin, 1);
    MGB_SEREG_OPR(Argsort, 1);
    MGB_SEREG_OPR(ArgsortBackward, 3);
    MGB_SEREG_OPR(CondTake, 2);
    MGB_SEREG_OPR(TopK, 2);

    //! current cumsum version, registered under a versioned alias
    //! (presumably so a later Cumsum revision can be registered as a new
    //! version alongside it -- verify)
    using CumsumV1 = opr::Cumsum;
    MGB_SEREG_OPR(CumsumV1, 1);

#if MGB_CUDA
    //! NvOf is only registered in CUDA-enabled builds
    MGB_SEREG_OPR(NvOf, 1);
#endif

} // namespace opr
} // namespace mgb


// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}