// libacl_cblas-wrap.h
// generated by wraplib.py
// --- begin functions to be implemented
// Decoration placed between the return type and the name of every generated
// function and function-pointer type below; expands to nothing unless the
// including file defines it first.
#ifndef _WRAPLIB_API_CALL
#define _WRAPLIB_API_CALL
#endif
// Not referenced by the code visible in this portion of the file; expands to
// nothing unless the including file defines it first.
#ifndef _WRAPLIB_CALLBACK
#define _WRAPLIB_CALLBACK
#endif
// Hook expanded at the top of every public wrapper with the wrapped function's
// name as its argument; a no-op unless the including file defines it first.
#ifndef ON_ENTRY
#define ON_ENTRY(x)
#endif
// Called by load_library() during its one-time initialisation. A null result
// makes every dispatch-table slot fall back to its *_error stub.
static void* get_library_handle();
// Called by load_library() with the handle above and a name from g_func_name;
// a null result makes that one slot fall back to its *_error stub.
static void* resolve_library_func(void* , const char*);
namespace {
// Called by each *_error stub with the function's index in the tables; its
// return value is handed back to the original caller.
template<typename T> T on_init_failed(int func_idx);
}
// --- end functions to be implemented
#include <mutex>
#include <cstddef>
// Defined after the dispatch tables; declared here because the *_init
// trampolines below call it.
static void load_library();
// Initial g_func_table[0] entry for aclblasGemvEx. Runs load_library() to fill
// in the dispatch table, then re-enters the public aclblasGemvEx wrapper so the
// call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasGemvEx_init(
        aclTransType p0, int p1, int p2, const void *p3, const void *p4,
        int p5, aclDataType p6, const void *p7, int p8, aclDataType p9,
        const void *p10, void *p11, int p12, aclDataType p13,
        aclComputeType p14, aclrtStream p15) {
    load_library();
    const aclError status = aclblasGemvEx(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14,
            p15);
    return status;
}
// Fallback for aclblasGemvEx (function index 0), installed by load_library()
// when the symbol cannot be resolved. Ignores every argument and returns
// whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasGemvEx_error(
        aclTransType, int, int, const void *, const void *, int, aclDataType,
        const void *, int, aclDataType, const void *, void *, int,
        aclDataType, aclComputeType, aclrtStream) {
    constexpr int kFuncIdx = 0;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[1] entry for aclblasCreateHandleForGemvEx. Runs
// load_library() to fill in the dispatch table, then re-enters the public
// wrapper so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForGemvEx_init(
        aclTransType p0, int p1, int p2, aclDataType p3, aclDataType p4,
        aclDataType p5, aclComputeType p6, aclopHandle **p7) {
    load_library();
    const aclError status =
            aclblasCreateHandleForGemvEx(p0, p1, p2, p3, p4, p5, p6, p7);
    return status;
}
// Fallback for aclblasCreateHandleForGemvEx (function index 1), installed by
// load_library() when the symbol cannot be resolved. Ignores every argument
// and returns whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForGemvEx_error(
        aclTransType, int, int, aclDataType, aclDataType, aclDataType,
        aclComputeType, aclopHandle **) {
    constexpr int kFuncIdx = 1;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[2] entry for aclblasHgemv. Runs load_library() to fill
// in the dispatch table, then re-enters the public aclblasHgemv wrapper so the
// call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasHgemv_init(
        aclTransType p0, int p1, int p2, const aclFloat16 *p3,
        const aclFloat16 *p4, int p5, const aclFloat16 *p6, int p7,
        const aclFloat16 *p8, aclFloat16 *p9, int p10, aclComputeType p11,
        aclrtStream p12) {
    load_library();
    const aclError status = aclblasHgemv(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12);
    return status;
}
// Fallback for aclblasHgemv (function index 2), installed by load_library()
// when the symbol cannot be resolved. Ignores every argument and returns
// whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasHgemv_error(
        aclTransType, int, int, const aclFloat16 *, const aclFloat16 *, int,
        const aclFloat16 *, int, const aclFloat16 *, aclFloat16 *, int,
        aclComputeType, aclrtStream) {
    constexpr int kFuncIdx = 2;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[3] entry for aclblasCreateHandleForHgemv. Runs
// load_library() to fill in the dispatch table, then re-enters the public
// wrapper so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForHgemv_init(
        aclTransType p0, int p1, int p2, aclComputeType p3,
        aclopHandle **p4) {
    load_library();
    const aclError status = aclblasCreateHandleForHgemv(p0, p1, p2, p3, p4);
    return status;
}
// Fallback for aclblasCreateHandleForHgemv (function index 3), installed by
// load_library() when the symbol cannot be resolved. Ignores every argument
// and returns whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForHgemv_error(
        aclTransType, int, int, aclComputeType, aclopHandle **) {
    constexpr int kFuncIdx = 3;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[4] entry for aclblasS8gemv. Runs load_library() to fill
// in the dispatch table, then re-enters the public aclblasS8gemv wrapper so the
// call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasS8gemv_init(
        aclTransType p0, int p1, int p2, const int32_t *p3, const int8_t *p4,
        int p5, const int8_t *p6, int p7, const int32_t *p8, int32_t *p9,
        int p10, aclComputeType p11, aclrtStream p12) {
    load_library();
    const aclError status = aclblasS8gemv(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12);
    return status;
}
// Fallback for aclblasS8gemv (function index 4), installed by load_library()
// when the symbol cannot be resolved. Ignores every argument and returns
// whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasS8gemv_error(
        aclTransType, int, int, const int32_t *, const int8_t *, int,
        const int8_t *, int, const int32_t *, int32_t *, int, aclComputeType,
        aclrtStream) {
    constexpr int kFuncIdx = 4;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[5] entry for aclblasCreateHandleForS8gemv. Runs
// load_library() to fill in the dispatch table, then re-enters the public
// wrapper so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForS8gemv_init(
        aclTransType p0, int p1, int p2, aclComputeType p3,
        aclopHandle **p4) {
    load_library();
    const aclError status = aclblasCreateHandleForS8gemv(p0, p1, p2, p3, p4);
    return status;
}
// Fallback for aclblasCreateHandleForS8gemv (function index 5), installed by
// load_library() when the symbol cannot be resolved. Ignores every argument
// and returns whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForS8gemv_error(
        aclTransType, int, int, aclComputeType, aclopHandle **) {
    constexpr int kFuncIdx = 5;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[6] entry for aclblasGemmEx. Runs load_library() to fill
// in the dispatch table, then re-enters the public aclblasGemmEx wrapper so the
// call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasGemmEx_init(
        aclTransType p0, aclTransType p1, aclTransType p2, int p3, int p4,
        int p5, const void *p6, const void *p7, int p8, aclDataType p9,
        const void *p10, int p11, aclDataType p12, const void *p13,
        void *p14, int p15, aclDataType p16, aclComputeType p17,
        aclrtStream p18) {
    load_library();
    const aclError status = aclblasGemmEx(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14,
            p15, p16, p17, p18);
    return status;
}
// Fallback for aclblasGemmEx (function index 6), installed by load_library()
// when the symbol cannot be resolved. Ignores every argument and returns
// whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasGemmEx_error(
        aclTransType, aclTransType, aclTransType, int, int, int, const void *,
        const void *, int, aclDataType, const void *, int, aclDataType,
        const void *, void *, int, aclDataType, aclComputeType, aclrtStream) {
    constexpr int kFuncIdx = 6;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[7] entry for aclblasCreateHandleForGemmEx. Runs
// load_library() to fill in the dispatch table, then re-enters the public
// wrapper so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForGemmEx_init(
        aclTransType p0, aclTransType p1, aclTransType p2, int p3, int p4,
        int p5, aclDataType p6, aclDataType p7, aclDataType p8,
        aclComputeType p9, aclopHandle **p10) {
    load_library();
    const aclError status = aclblasCreateHandleForGemmEx(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10);
    return status;
}
// Fallback for aclblasCreateHandleForGemmEx (function index 7), installed by
// load_library() when the symbol cannot be resolved. Ignores every argument
// and returns whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForGemmEx_error(
        aclTransType, aclTransType, aclTransType, int, int, int, aclDataType,
        aclDataType, aclDataType, aclComputeType, aclopHandle **) {
    constexpr int kFuncIdx = 7;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[8] entry for aclblasHgemm. Runs load_library() to fill
// in the dispatch table, then re-enters the public aclblasHgemm wrapper so the
// call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasHgemm_init(
        aclTransType p0, aclTransType p1, aclTransType p2, int p3, int p4,
        int p5, const aclFloat16 *p6, const aclFloat16 *p7, int p8,
        const aclFloat16 *p9, int p10, const aclFloat16 *p11,
        aclFloat16 *p12, int p13, aclComputeType p14, aclrtStream p15) {
    load_library();
    const aclError status = aclblasHgemm(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14,
            p15);
    return status;
}
// Fallback for aclblasHgemm (function index 8), installed by load_library()
// when the symbol cannot be resolved. Ignores every argument and returns
// whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasHgemm_error(
        aclTransType, aclTransType, aclTransType, int, int, int,
        const aclFloat16 *, const aclFloat16 *, int, const aclFloat16 *, int,
        const aclFloat16 *, aclFloat16 *, int, aclComputeType, aclrtStream) {
    constexpr int kFuncIdx = 8;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[9] entry for aclblasCreateHandleForHgemm. Runs
// load_library() to fill in the dispatch table, then re-enters the public
// wrapper so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForHgemm_init(
        aclTransType p0, aclTransType p1, aclTransType p2, int p3, int p4,
        int p5, aclComputeType p6, aclopHandle **p7) {
    load_library();
    const aclError status =
            aclblasCreateHandleForHgemm(p0, p1, p2, p3, p4, p5, p6, p7);
    return status;
}
// Fallback for aclblasCreateHandleForHgemm (function index 9), installed by
// load_library() when the symbol cannot be resolved. Ignores every argument
// and returns whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForHgemm_error(
        aclTransType, aclTransType, aclTransType, int, int, int,
        aclComputeType, aclopHandle **) {
    constexpr int kFuncIdx = 9;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[10] entry for aclblasS8gemm. Runs load_library() to
// fill in the dispatch table, then re-enters the public aclblasS8gemm wrapper
// so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasS8gemm_init(
        aclTransType p0, aclTransType p1, aclTransType p2, int p3, int p4,
        int p5, const int32_t *p6, const int8_t *p7, int p8,
        const int8_t *p9, int p10, const int32_t *p11, int32_t *p12, int p13,
        aclComputeType p14, aclrtStream p15) {
    load_library();
    const aclError status = aclblasS8gemm(
            p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11, p12, p13, p14,
            p15);
    return status;
}
// Fallback for aclblasS8gemm (function index 10), installed by load_library()
// when the symbol cannot be resolved. Ignores every argument and returns
// whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasS8gemm_error(
        aclTransType, aclTransType, aclTransType, int, int, int,
        const int32_t *, const int8_t *, int, const int8_t *, int,
        const int32_t *, int32_t *, int, aclComputeType, aclrtStream) {
    constexpr int kFuncIdx = 10;
    return on_init_failed<aclError>(kFuncIdx);
}
// Initial g_func_table[11] entry for aclblasCreateHandleForS8gemm. Runs
// load_library() to fill in the dispatch table, then re-enters the public
// wrapper so the call is served by whatever the table holds afterwards.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForS8gemm_init(
        aclTransType p0, aclTransType p1, aclTransType p2, int p3, int p4,
        int p5, aclComputeType p6, aclopHandle **p7) {
    load_library();
    const aclError status =
            aclblasCreateHandleForS8gemm(p0, p1, p2, p3, p4, p5, p6, p7);
    return status;
}
// Fallback for aclblasCreateHandleForS8gemm (function index 11), installed by
// load_library() when the symbol cannot be resolved. Ignores every argument
// and returns whatever on_init_failed<aclError>() yields for that index.
static aclError _WRAPLIB_API_CALL aclblasCreateHandleForS8gemm_error(
        aclTransType, aclTransType, aclTransType, int, int, int,
        aclComputeType, aclopHandle **) {
    constexpr int kFuncIdx = 11;
    return on_init_failed<aclError>(kFuncIdx);
}
// Number of wrapped entry points; sizes g_func_table, g_func_table_error and
// g_func_name, which all share the same per-function index.
static constexpr size_t NR_FUNC = 12;
// Live dispatch table read by the public wrappers. Every slot starts at its
// *_init trampoline; load_library() later overwrites each slot with the
// resolved library symbol or the matching *_error stub.
static void* g_func_table[NR_FUNC] = {
    reinterpret_cast<void*>(&aclblasGemvEx_init),                 // 0
    reinterpret_cast<void*>(&aclblasCreateHandleForGemvEx_init),  // 1
    reinterpret_cast<void*>(&aclblasHgemv_init),                  // 2
    reinterpret_cast<void*>(&aclblasCreateHandleForHgemv_init),   // 3
    reinterpret_cast<void*>(&aclblasS8gemv_init),                 // 4
    reinterpret_cast<void*>(&aclblasCreateHandleForS8gemv_init),  // 5
    reinterpret_cast<void*>(&aclblasGemmEx_init),                 // 6
    reinterpret_cast<void*>(&aclblasCreateHandleForGemmEx_init),  // 7
    reinterpret_cast<void*>(&aclblasHgemm_init),                  // 8
    reinterpret_cast<void*>(&aclblasCreateHandleForHgemm_init),   // 9
    reinterpret_cast<void*>(&aclblasS8gemm_init),                 // 10
    reinterpret_cast<void*>(&aclblasCreateHandleForS8gemm_init)   // 11
};
// Per-function fallback stubs, in the same index order as g_func_table.
// load_library() copies entry i into g_func_table[i] when the library handle
// is null or symbol i cannot be resolved.
static void* g_func_table_error[NR_FUNC] = {
    reinterpret_cast<void*>(&aclblasGemvEx_error),                 // 0
    reinterpret_cast<void*>(&aclblasCreateHandleForGemvEx_error),  // 1
    reinterpret_cast<void*>(&aclblasHgemv_error),                  // 2
    reinterpret_cast<void*>(&aclblasCreateHandleForHgemv_error),   // 3
    reinterpret_cast<void*>(&aclblasS8gemv_error),                 // 4
    reinterpret_cast<void*>(&aclblasCreateHandleForS8gemv_error),  // 5
    reinterpret_cast<void*>(&aclblasGemmEx_error),                 // 6
    reinterpret_cast<void*>(&aclblasCreateHandleForGemmEx_error),  // 7
    reinterpret_cast<void*>(&aclblasHgemm_error),                  // 8
    reinterpret_cast<void*>(&aclblasCreateHandleForHgemm_error),   // 9
    reinterpret_cast<void*>(&aclblasS8gemm_error),                 // 10
    reinterpret_cast<void*>(&aclblasCreateHandleForS8gemm_error)   // 11
};
// Symbol names passed to resolve_library_func() by load_library(), in the same
// index order as g_func_table and g_func_table_error.
static const char* const g_func_name[NR_FUNC] = {"aclblasGemvEx",
    "aclblasCreateHandleForGemvEx",
    "aclblasHgemv",
    "aclblasCreateHandleForHgemv",
    "aclblasS8gemv",
    "aclblasCreateHandleForS8gemv",
    "aclblasGemmEx",
    "aclblasCreateHandleForGemmEx",
    "aclblasHgemm",
    "aclblasCreateHandleForHgemm",
    "aclblasS8gemm",
    "aclblasCreateHandleForS8gemm"};

// One-time population of g_func_table.
//
// Serialised by a function-local mutex; every call after the first completed
// one returns immediately. For each function index the slot is set to the
// symbol returned by resolve_library_func(), or to the g_func_table_error
// entry when the library handle is null or the symbol lookup returns null.
static void load_library() {
    static bool initialized = false;
    static std::mutex init_mutex;
    std::lock_guard<std::mutex> guard{init_mutex};

    if (initialized) {
        return;
    }

    void* const lib = get_library_handle();
    for (size_t idx = 0; idx != NR_FUNC; ++idx) {
        // Without a library handle there is nothing to look the symbol up in.
        void* target =
                lib ? resolve_library_func(lib, g_func_name[idx]) : nullptr;
        if (target == nullptr) {
            target = g_func_table_error[idx];
        }
        // Atomic store: the public wrappers read this slot without taking
        // init_mutex.
        __atomic_store_n(&g_func_table[idx], target, __ATOMIC_RELAXED);
    }
    initialized = true;
}

// Public entry point for aclblasGemvEx.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[0] and returns its aclError. That slot initially holds the
// aclblasGemvEx_init trampoline; load_library() replaces it with the symbol
// resolved from the library or with the aclblasGemvEx_error stub.
aclError _WRAPLIB_API_CALL aclblasGemvEx(aclTransType arg0, int arg1, int arg2, const void *arg3, const void *arg4, int arg5, aclDataType arg6, const void *arg7, int arg8, aclDataType arg9, const void *arg10, void *arg11, int arg12, aclDataType arg13, aclComputeType arg14, aclrtStream arg15) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, int, int, const void *, const void *, int, aclDataType, const void *, int, aclDataType, const void *, void *, int, aclDataType, aclComputeType, aclrtStream);
    ON_ENTRY(aclblasGemvEx);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[0], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15);
}
// Public entry point for aclblasCreateHandleForGemvEx.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[1] and returns its aclError. That slot initially holds the
// aclblasCreateHandleForGemvEx_init trampoline; load_library() replaces it
// with the symbol resolved from the library or with the
// aclblasCreateHandleForGemvEx_error stub.
aclError _WRAPLIB_API_CALL aclblasCreateHandleForGemvEx(aclTransType arg0, int arg1, int arg2, aclDataType arg3, aclDataType arg4, aclDataType arg5, aclComputeType arg6, aclopHandle **arg7) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, int, int, aclDataType, aclDataType, aclDataType, aclComputeType, aclopHandle **);
    ON_ENTRY(aclblasCreateHandleForGemvEx);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[1], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7);
}
// Public entry point for aclblasHgemv.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[2] and returns its aclError. That slot initially holds the
// aclblasHgemv_init trampoline; load_library() replaces it with the symbol
// resolved from the library or with the aclblasHgemv_error stub.
aclError _WRAPLIB_API_CALL aclblasHgemv(aclTransType arg0, int arg1, int arg2, const aclFloat16 *arg3, const aclFloat16 *arg4, int arg5, const aclFloat16 *arg6, int arg7, const aclFloat16 *arg8, aclFloat16 *arg9, int arg10, aclComputeType arg11, aclrtStream arg12) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, int, int, const aclFloat16 *, const aclFloat16 *, int, const aclFloat16 *, int, const aclFloat16 *, aclFloat16 *, int, aclComputeType, aclrtStream);
    ON_ENTRY(aclblasHgemv);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[2], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12);
}
// Public entry point for aclblasCreateHandleForHgemv.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[3] and returns its aclError. That slot initially holds the
// aclblasCreateHandleForHgemv_init trampoline; load_library() replaces it
// with the symbol resolved from the library or with the
// aclblasCreateHandleForHgemv_error stub.
aclError _WRAPLIB_API_CALL aclblasCreateHandleForHgemv(aclTransType arg0, int arg1, int arg2, aclComputeType arg3, aclopHandle **arg4) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, int, int, aclComputeType, aclopHandle **);
    ON_ENTRY(aclblasCreateHandleForHgemv);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[3], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4);
}
// Public entry point for aclblasS8gemv.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[4] and returns its aclError. That slot initially holds the
// aclblasS8gemv_init trampoline; load_library() replaces it with the symbol
// resolved from the library or with the aclblasS8gemv_error stub.
aclError _WRAPLIB_API_CALL aclblasS8gemv(aclTransType arg0, int arg1, int arg2, const int32_t *arg3, const int8_t *arg4, int arg5, const int8_t *arg6, int arg7, const int32_t *arg8, int32_t *arg9, int arg10, aclComputeType arg11, aclrtStream arg12) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, int, int, const int32_t *, const int8_t *, int, const int8_t *, int, const int32_t *, int32_t *, int, aclComputeType, aclrtStream);
    ON_ENTRY(aclblasS8gemv);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[4], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12);
}
// Public entry point for aclblasCreateHandleForS8gemv.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[5] and returns its aclError. That slot initially holds the
// aclblasCreateHandleForS8gemv_init trampoline; load_library() replaces it
// with the symbol resolved from the library or with the
// aclblasCreateHandleForS8gemv_error stub.
aclError _WRAPLIB_API_CALL aclblasCreateHandleForS8gemv(aclTransType arg0, int arg1, int arg2, aclComputeType arg3, aclopHandle **arg4) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, int, int, aclComputeType, aclopHandle **);
    ON_ENTRY(aclblasCreateHandleForS8gemv);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[5], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4);
}
// Public entry point for aclblasGemmEx.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[6] and returns its aclError. That slot initially holds the
// aclblasGemmEx_init trampoline; load_library() replaces it with the symbol
// resolved from the library or with the aclblasGemmEx_error stub.
aclError _WRAPLIB_API_CALL aclblasGemmEx(aclTransType arg0, aclTransType arg1, aclTransType arg2, int arg3, int arg4, int arg5, const void *arg6, const void *arg7, int arg8, aclDataType arg9, const void *arg10, int arg11, aclDataType arg12, const void *arg13, void *arg14, int arg15, aclDataType arg16, aclComputeType arg17, aclrtStream arg18) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, aclTransType, aclTransType, int, int, int, const void *, const void *, int, aclDataType, const void *, int, aclDataType, const void *, void *, int, aclDataType, aclComputeType, aclrtStream);
    ON_ENTRY(aclblasGemmEx);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[6], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15, arg16, arg17, arg18);
}
// Public entry point for aclblasCreateHandleForGemmEx.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[7] and returns its aclError. That slot initially holds the
// aclblasCreateHandleForGemmEx_init trampoline; load_library() replaces it
// with the symbol resolved from the library or with the
// aclblasCreateHandleForGemmEx_error stub.
aclError _WRAPLIB_API_CALL aclblasCreateHandleForGemmEx(aclTransType arg0, aclTransType arg1, aclTransType arg2, int arg3, int arg4, int arg5, aclDataType arg6, aclDataType arg7, aclDataType arg8, aclComputeType arg9, aclopHandle **arg10) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, aclTransType, aclTransType, int, int, int, aclDataType, aclDataType, aclDataType, aclComputeType, aclopHandle **);
    ON_ENTRY(aclblasCreateHandleForGemmEx);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[7], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10);
}
// Public entry point for aclblasHgemm.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[8] and returns its aclError. That slot initially holds the
// aclblasHgemm_init trampoline; load_library() replaces it with the symbol
// resolved from the library or with the aclblasHgemm_error stub.
aclError _WRAPLIB_API_CALL aclblasHgemm(aclTransType arg0, aclTransType arg1, aclTransType arg2, int arg3, int arg4, int arg5, const aclFloat16 *arg6, const aclFloat16 *arg7, int arg8, const aclFloat16 *arg9, int arg10, const aclFloat16 *arg11, aclFloat16 *arg12, int arg13, aclComputeType arg14, aclrtStream arg15) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, aclTransType, aclTransType, int, int, int, const aclFloat16 *, const aclFloat16 *, int, const aclFloat16 *, int, const aclFloat16 *, aclFloat16 *, int, aclComputeType, aclrtStream);
    ON_ENTRY(aclblasHgemm);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[8], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15);
}
// Public entry point for aclblasCreateHandleForHgemm.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[9] and returns its aclError. That slot initially holds the
// aclblasCreateHandleForHgemm_init trampoline; load_library() replaces it
// with the symbol resolved from the library or with the
// aclblasCreateHandleForHgemm_error stub.
aclError _WRAPLIB_API_CALL aclblasCreateHandleForHgemm(aclTransType arg0, aclTransType arg1, aclTransType arg2, int arg3, int arg4, int arg5, aclComputeType arg6, aclopHandle **arg7) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, aclTransType, aclTransType, int, int, int, aclComputeType, aclopHandle **);
    ON_ENTRY(aclblasCreateHandleForHgemm);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[9], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7);
}
// Public entry point for aclblasS8gemm.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[10] and returns its aclError. That slot initially holds the
// aclblasS8gemm_init trampoline; load_library() replaces it with the symbol
// resolved from the library or with the aclblasS8gemm_error stub.
aclError _WRAPLIB_API_CALL aclblasS8gemm(aclTransType arg0, aclTransType arg1, aclTransType arg2, int arg3, int arg4, int arg5, const int32_t *arg6, const int8_t *arg7, int arg8, const int8_t *arg9, int arg10, const int32_t *arg11, int32_t *arg12, int arg13, aclComputeType arg14, aclrtStream arg15) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, aclTransType, aclTransType, int, int, int, const int32_t *, const int8_t *, int, const int8_t *, int, const int32_t *, int32_t *, int, aclComputeType, aclrtStream);
    ON_ENTRY(aclblasS8gemm);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[10], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9, arg10, arg11, arg12, arg13, arg14, arg15);
}
// Public entry point for aclblasCreateHandleForS8gemm.
//
// Forwards every argument unchanged to the function currently stored in
// g_func_table[11] and returns its aclError. That slot initially holds the
// aclblasCreateHandleForS8gemm_init trampoline; load_library() replaces it
// with the symbol resolved from the library or with the
// aclblasCreateHandleForS8gemm_error stub.
aclError _WRAPLIB_API_CALL aclblasCreateHandleForS8gemm(aclTransType arg0, aclTransType arg1, aclTransType arg2, int arg3, int arg4, int arg5, aclComputeType arg6, aclopHandle **arg7) {
    typedef aclError (_WRAPLIB_API_CALL *f_ptr_t)(aclTransType, aclTransType, aclTransType, int, int, int, aclComputeType, aclopHandle **);
    ON_ENTRY(aclblasCreateHandleForS8gemm);
    // load_library() may overwrite this slot from another thread using
    // __atomic_store_n, so read it with a matching relaxed atomic load; a
    // plain read here would race with that store.
    void* slot = __atomic_load_n(&g_func_table[11], __ATOMIC_RELAXED);
    f_ptr_t f = (f_ptr_t)(slot);
    return f(arg0, arg1, arg2, arg3, arg4, arg5, arg6, arg7);
}