/*
 * $File: misc.i
 *
 * This file is part of MegBrain, a deep learning framework developed by Megvii.
 *
 * $Copyright: Copyright (c) 2014-2017 Megvii Inc. All rights reserved.
 */


%{
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/serialization/helper.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/plugin/opr_footprint.h"
using _PyStackExtracter = PyStackExtracter;
using _PersistentCache = mgb::PersistentCache;
using _PersistentCacheBlob = _PersistentCache::Blob;
using _MaybePersistentCacheBlob = mgb::Maybe<_PersistentCacheBlob>;
using _OptimizeForInferenceOptions = mgb::gopt::OptimizeForInferenceOptions;
%}

// Director-enabled so Python code can subclass _PyStackExtracter and
// override extract(); calls made from C++ are dispatched to the Python
// override.
%feature("director") _PyStackExtracter;
// SWIG view of PyStackExtracter (aliased in the %{ %} block).
// NOTE(review): judging by the name, extract() presumably returns a textual
// dump of the current Python call stack -- confirm against its C++ users.
class _PyStackExtracter {
    public:
        virtual ~_PyStackExtracter() = default;
        // Pure virtual: a (Python) subclass must supply the text to return.
        virtual std::string extract() = 0;
        // Register `p` as the extractor implementation.
        // NOTE(review): ownership/lifetime of `p` is not visible here; the
        // caller presumably has to keep it alive while registered.
        static void reg(_PyStackExtracter *p);
};

// Python bytes -> _PersistentCacheBlob, used when Python calls a wrapped
// method taking `const _PersistentCacheBlob&`.
// For reference parameters SWIG's $1 is a pointer that starts out null, so it
// must be pointed at real storage: `temp` is a typemap local that lives for
// the duration of the wrapped call. The blob only borrows the bytes buffer,
// which stays valid while the argument object ($input) is alive.
%typemap(in) const _PersistentCacheBlob& (_PersistentCacheBlob temp) {
    mgb_assert(PyBytes_Check($input), "persistent cache blob must be bytes");
    temp.ptr = PyBytes_AsString($input);
    temp.size = PyBytes_Size($input);
    $1 = &temp;
}
// _PersistentCacheBlob -> Python bytes: used when C++ invokes a Python
// director override (e.g. put()/get()) and must pass a blob argument.
// PyBytes_FromStringAndSize builds a new bytes object holding a copy of the
// blob's ($1.ptr, $1.size) contents.
%typemap(directorin) const _PersistentCacheBlob& {
    $input = PyBytes_FromStringAndSize(
        static_cast<const char*>($1.ptr), $1.size);
}
// Python -> _MaybePersistentCacheBlob: converts the value returned by a
// Python get() override. None maps to mgb::None; a bytes object maps to a
// blob that borrows the bytes object's internal buffer (no copy is made).
%typemap(directorout) _MaybePersistentCacheBlob {
    // Since the blob borrows the buffer, some other reference must keep the
    // returned object alive after this wrapper drops its own reference;
    // requiring refcnt >= 2 presumably checks exactly that.
    // NOTE(review): fragile -- relies on the Python implementation keeping
    // the returned bytes alive (e.g. in its own storage); confirm.
    mgb_assert($1->ob_refcnt >= 2, "persistent cache result refcnt too small");
    if ($1 == Py_None) {
        $result = mgb::None;
    } else {
        // NOTE(review): $input and $1 both appear to name the returned
        // Python object in this typemap; the mix looks unintentional.
        mgb_assert(PyBytes_Check($input));
        _PersistentCacheBlob blob;
        blob.ptr = PyBytes_AsString($1);
        blob.size = PyBytes_Size($1);
        $result = blob;
    }
}

// Director-enabled so a Python subclass can implement put()/get(); C++
// callers of the cache then reach the Python implementation.
%feature("director") _PersistentCache;
// SWIG view of mgb::PersistentCache (aliased in the %{ %} block).
class _PersistentCache {
    public:
        virtual ~_PersistentCache() = default;

        // Store `value` for (`category`, `key`); exact semantics are defined
        // by mgb::PersistentCache.
        virtual void put(const std::string &category,
                const _PersistentCacheBlob &key,
                const _PersistentCacheBlob &value) = 0;

        // Look up (`category`, `key`). An empty Maybe means "no entry"; a
        // Python override signals that by returning None.
        virtual _MaybePersistentCacheBlob get(
                const std::string &category,
                const _PersistentCacheBlob &key) = 0;

        %extend {
            // Install `p` as the cache implementation via set_impl().
            // The brace-init pairs `p` with a no-op deleter (presumably
            // building a shared_ptr), so the installed handle does NOT own
            // `p`: the caller (e.g. the Python object) must keep it alive
            // for as long as it stays registered.
            static void reg(_PersistentCache *p) {
                _PersistentCache::set_impl({p, [](_PersistentCache*){}});
            }
        }
};

// SWIG-visible declaration of mgb::gopt::OptimizeForInferenceOptions (the C++
// alias is set up in the %{ %} block). Only the no-argument enable_*()
// switches are exposed to Python.
struct _OptimizeForInferenceOptions {
// Declares `void enable_<n>()`; the call site supplies the semicolon so the
// expansion does not produce an empty `;;` declaration.
#define SET(n) void enable_##n()
    // Precision / operator-fusion switches.
    SET(f16_io_f32_comp);
    SET(f16_io_comp);
    SET(fuse_conv_bias_nonlinearity);
    SET(fuse_conv_bias_with_z);

    // Tensor layout transforms.
    SET(nchw4);
    SET(nhwcd4);
    SET(nchw88);
    SET(nchw44);
    SET(nchw32);
    SET(chwn4);
#undef SET
};

%inline {
    static SymbolVarArray _optimize_for_inference(
            const SymbolVarArray& dest_vars,
            const _OptimizeForInferenceOptions& opt) {
        // Thin forwarder to mgb::gopt::optimize_for_inference: rewrite the
        // graph feeding `dest_vars` according to the switches set in `opt`
        // and hand back the replacement output vars.
        SymbolVarArray optimized =
                mgb::gopt::optimize_for_inference(dest_vars, opt);
        return optimized;
    }

    // defined in function_replace.cpp
    // (Prototypes only: SWIG generates wrappers for them and the bodies are
    // linked in from the source files named in these headings.)
    void _register_logger(PyObject *logger);
    void _timed_func_set_fork_exec_path(const char *arg0, const char *arg1);
    void _timed_func_exec_cb(const char *user_data);

    // defined in megbrain_wrap.cpp
    void _mgb_global_finalize();
    std::vector<size_t> _get_mgb_version();
    SymbolVarArray _grad(SymbolVar target, SymbolVarArray wrts,
            bool warn_mid_wrt, int use_virtual_grad,
            bool return_zero_for_nodep);
    SymbolVar _inter_graph_trans_var(
            CompGraph &dest_graph, SymbolVar src);
    SymbolVar _get_graph_optimizer_replaced_var(SymbolVar src);
    // Two overloads: `delta` is either a SharedND or a value received in a
    // graph callback. NOTE(review): presumably computes
    // dest = alpha * dest + beta * delta + bias -- confirm in
    // megbrain_wrap.cpp.
    void _add_update_fastpath(SharedND& dest, SharedND& delta,
            float alpha, float beta, float bias);
    void _add_update_fastpath(SharedND& dest,
            CompGraphCallbackValueProxy& delta,
            float alpha, float beta, float bias);

    static SymbolVar _current_grad_target(CompGraph &graph) {
        // Unwrap the Python-facing CompGraph and forward to
        // mgb::cg::current_grad_target on the underlying graph.
        auto&& underlying = graph.get();
        return mgb::cg::current_grad_target(underlying);
    }

    uint32_t _get_dtype_num(PyObject *dtype) {
        // Translate the numpy dtype into an mgb DType, then expose that
        // DType's enum value to Python as a plain integer.
        auto mgb_dtype = npy::dtype_np2mgb(dtype);
        auto enum_value = mgb_dtype.enumv();
        return static_cast<uint32_t>(enum_value);
    }

    PyObject* _get_serialized_dtype(PyObject *dtype) {
        // Serialize the numpy dtype (converted to an mgb DType) with
        // mgb::serialization::serialize_dtype and return the resulting
        // bytes as a Python bytes object (NULL with a Python error set if
        // that allocation fails).
        std::string sdtype;
        // serialize_dtype emits its output through this sink. append() grows
        // the buffer and copies in one step -- no zero-fill followed by
        // memcpy, and no memcpy from a possibly-null pointer when size == 0.
        auto write = [&sdtype](const void* data, size_t size) {
            sdtype.append(static_cast<const char*>(data), size);
        };
        mgb::serialization::serialize_dtype(npy::dtype_np2mgb(dtype), write);
        return PyBytes_FromStringAndSize(sdtype.data(), sdtype.size());
    }

    size_t max_size_t() {
        // Every bit set is the largest value an unsigned size_t can hold.
        return ~static_cast<size_t>(0);
    }

    std::string _get_opr_fp_graph_exec(
        CompGraph& cg, const SymbolVarArray& outputs) {
        // Collect operator footprints for executing `outputs` on the graph
        // wrapped by `cg`, and return their to_string() rendering
        // (presumably JSON -- TODO confirm in OprFootprint).
        return mgb::OprFootprint::get_opr_fp_graph_exec(cg.get(), outputs)
                ->to_string();
    }
}

// vim: ft=swig