/**
 * \file python_module/src/cpp/opr_defs.h
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * \brief extra opr definitions
 *
 * \copyright Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 */

#ifndef SWIG
#pragma once

#include "./megbrain_wrap.h"
#include "./opr_helper.h"

#if MGB_ENABLE_OPR_MM
#include "megbrain/opr/collective_comm.h"
#endif
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/tensor_manip.h"
using mgb::SymbolVar;
using mgb::SymbolVarArray;
using mgb::OperatorNodeConfig;

#endif

class _Opr {

public:

// basic arith

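//! in-place update: dest <- dest * alpha + delta * beta + bias, skipped at
//! runtime when disable evaluates to a nonzero scalar (the four SharedScalars
//! map onto mgb::opr::AddUpdate::Param in that order)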
static SymbolVar add_update(SymbolVar dest, SymbolVar delta,
        const SharedScalar &alpha, const SharedScalar &beta,
        const SharedScalar &bias, const SharedScalar &disable,
        const OperatorNodeConfig &config) {
    return mgb::opr::AddUpdate::make(dest, delta,
            {alpha.get_val(), beta.get_val(), bias.get_val(), disable.get_val()},
            config);
}

// tensor manip

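//! split a flat parameter tensor src into vars of the given shapes (the
//! inverse of packing parameters into one contiguous buffer)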
static SymbolVarArray param_pack_split(
        SymbolVar src, const std::vector<std::vector<size_t>>& shapes,
        const OperatorNodeConfig& config);

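//! permute the axes of src according to pattern; a -1 entry in pattern
//! inserts a new (broadcast) axis, and ndim gives the expected input rank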
static SymbolVar dimshuffle(SymbolVar src,
        const std::vector<int> &pattern, size_t ndim,
        const OperatorNodeConfig &config) {
    return mgb::opr::Dimshuffle::make(src, pattern, ndim, config);
}

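//! backend for the add_axis/remove_axis helpers defined in the SWIG block
//! below: add (is_add=true) or remove the given axes of src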
static SymbolVar _axis_add_remove(SymbolVar src,
        const std::vector<int>& axis, bool is_add,
        const OperatorNodeConfig &config);

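//! forward src while invoking callback on the computed value during graph
//! execution; note that config is accepted but currently unused by both
//! overloads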
static SymbolVar callback_injector(SymbolVar src, _CompGraphCallback &callback,
        const OperatorNodeConfig &config) {
    return mgb::opr::CallbackInjector::make(src, callback.make_callback());
}

static SymbolVar callback_injector(SymbolVarArray src, _CompGraphCallback &callback,
                                   const OperatorNodeConfig &config) {
    return mgb::opr::CallbackInjector::make(src, callback.make_multi_input_callback());
}

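//! forward src while replacing its gradient with the one produced by
//! grad_getter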
static SymbolVar set_grad(SymbolVar src, _SetGradCallback &grad_getter,
        const OperatorNodeConfig &config) {
    return mgb::opr::SetGrad::make(src, grad_getter.make_callback(), config);
}

// multi machine

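//! acquire/release a global lock identified by (lock_id, group_id) during
//! graph execution, e.g. to serialize cross-graph parameter updates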
static SymbolVar lock_acquire(SymbolVar var, size_t lock_id, size_t group_id,
        const OperatorNodeConfig &config);

static SymbolVar lock_release(SymbolVar var, size_t lock_id, size_t group_id,
        const OperatorNodeConfig &config);

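//! send var to a remote peer via the server at server_addr:port, with key
//! naming the transfer channel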
static SymbolVar remote_send(
        const std::string& server_addr, const int port,
        const std::string& key, SymbolVar var,
        const bool is_grad,
        const OperatorNodeConfig& config);

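//! receive a tensor with the given shape and dtype from the channel named
//! key, adding the opr to graph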
static SymbolVar remote_recv(const std::string& server_addr, const int port,
                             const std::string& key,
                             CompGraph& graph,
                             const std::vector<size_t>& shape, PyObject* dtype,
                             const OperatorNodeConfig& config);

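//! build a collective communication opr (mode and other settings come from
//! params) across nr_devices workers coordinated by the server at
//! server_addr:port; this variant takes a local input var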
static SymbolVar collective_comm_with_input(
        SymbolVar inpvar, const std::string& key, const size_t nr_devices,
        const bool is_root, const int rank, const bool local_grad,
        const std::string& server_addr, const int port, PyObject* params,
        PyObject* dtype, const std::string& backend, SharedND* output_buf,
        const OperatorNodeConfig& config, const SharedScalar& disable);

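//! same as collective_comm_with_input, but for modes that take no local
//! input var, so the target graph must be passed explicitly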
static SymbolVar collective_comm_without_input(
        CompGraph& graph, const std::string& key, const size_t nr_devices,
        const bool is_root, const int rank, const bool local_grad,
        const std::string& server_addr, const int port, PyObject* params,
        PyObject* dtype, const std::string& backend, SharedND* output_buf,
        const OperatorNodeConfig& config, const SharedScalar& disable);

// misc
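//! placeholder for an extern C opr: outputs take the given shapes and
//! dtypes, while dump_name and data_bytes identify the opr when the graph
//! is dumped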
static SymbolVarArray extern_c_opr_placeholder(
        const SymbolVarArray& inputs,
        const std::vector<std::vector<size_t>>& output_shapes,
        PyObject* dtypes,
        const char* dump_name, PyObject* data_bytes,
        const OperatorNodeConfig& config);

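//! run a pre-built TensorRT engine (serialized in data_bytes) on the given
//! inputs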
static SymbolVarArray tensor_rt_runtime(const SymbolVarArray& inputs,
                                        PyObject* data_bytes,
                                        const OperatorNodeConfig& config);

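//! write an execution timestamp into the host buffer dest at offset
//! dest_off when this opr runs, forwarding input unchanged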
static SymbolVar timestamp(SymbolVar input, PyObject* dest, size_t dest_off,
                           const OperatorNodeConfig& config);

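//! construct a virtual loss whose gradients w.r.t. ys are exactly the given
//! y_grads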
static SymbolVar virtual_loss(const SymbolVarArray& ys,
                              const SymbolVarArray& y_grads,
                              const OperatorNodeConfig& config);

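//! forward symvars[0] while adding execution dependencies on the remaining
//! vars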
static SymbolVar virtual_dep(const SymbolVarArray& symvars,
                             const OperatorNodeConfig& config);


#ifdef SWIG
%pythoncode {

@classmethod
def _make_axis_vec(cls, axis):
    ret = _VectorInt()
    # collections.Iterable was removed in Python 3.10; use collections.abc
    if isinstance(axis, collections.abc.Iterable):
        for i in axis:
            ret.push_back(i)
    else:
        ret.push_back(axis)
    return ret

@classmethod
def add_axis(cls, src, axis, config):
    return cls._axis_add_remove(src, cls._make_axis_vec(axis), True, config)

@classmethod
def remove_axis(cls, src, axis, config):
    return cls._axis_add_remove(src, cls._make_axis_vec(axis), False, config)

} // %pythoncode
#endif // SWIG

};

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}