/**
 * \file dnn/src/arm_common/elemwise/ternary/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

MIDOUT_DECL(megdnn_arm_common_elemwise_ternary)

using namespace megdnn;
using namespace arm_common;

/*
 * Availability-time per-dtype hook.
 *
 * DISPATCH_TYPE is not defined in this file (presumably it comes from the
 * included algo.h -- NOTE(review): confirm). It is expected to expand
 * DISPATCH_MODE_FLOAT / DISPATCH_MODE_INT for each supported dtype. Here the
 * hook only has to report whether the requested mode is FUSE_MUL_ADD3, the
 * only mode the FMA3 algos below handle. _case/_type/_type_midout_id are
 * accepted only so the signature matches the exec-time variant.
 */
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
    if (kern_param.mode == Mode::FUSE_MUL_ADD3)            \
        return true;
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT

/*
 * DECL_AVAILABLE(case, type) defines
 * ElemwiseImpl::AlgoTernaryFma3<case>::is_available().
 *
 * The algo reports itself available only when the kernel's broadcast pattern
 * equals `type`, and the DISPATCH_TYPE/DISPATCH_MODE_* expansion returns true.
 * Otherwise it returns false. `src0` is declared because DISPATCH_TYPE
 * presumably inspects its dtype. The stringized case name is pasted onto the
 * `_hash` literal suffix to form the identifier handed to DISPATCH_TYPE.
 */
#define DECL_AVAILABLE(case, type)                                       \
    bool ElemwiseImpl::AlgoTernaryFma3##case ::is_available(             \
            const KernParam& kern_param) const {                         \
        if (type == kern_param.broad_cast_type) {                        \
            auto& elparam = kern_param.ternary_elparam;                  \
            auto& src0 = elparam[0];                                     \
            DISPATCH_TYPE("AlgoTernaryFma3::is_available" #case##_hash); \
        }                                                                \
        return false;                                                    \
    }

DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast111CVecBcast111C, BcastType::BCAST111C_VEC_BCAST111C);
DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast111CVec, BcastType::VEC_BCAST111C_VEC);
DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
// Release the availability helpers. DISPATCH_MODE_* is redefined below with
// exec-time semantics, and DECL_AVAILABLE must not leak past this point.
#undef DECL_AVAILABLE
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

/*
 * Exec-time per-dtype hook, expanded through DISPATCH_TYPE inside each exec()
 * below.
 *
 * It switches on the requested mode. FUSE_MUL_ADD3 is emitted by the
 * DISPATCH_TERNARY macro that every exec() defines locally for its own
 * broadcast pattern. Any other mode throws.
 */
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)                             \
    switch (kern_param.mode) {                                                         \
        DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, FuseMulAdd3Op); \
        default:                                                                       \
            megdnn_throw(ssprintf(                                                     \
                    "No avaiable algo find for: %d",                                   \
                    static_cast<int>(kern_param.mode)));                               \
    }
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT

void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const {
66 67 68 69
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 1: shape of (src0, src2) and src1 are exactly match
M
Megvii Engine Team 已提交
70 71 72 73 74 75 76 77 78 79 80 81
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t)>                                          \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_VEC_VEC>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
82 83 84 85
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
M
Megvii Engine Team 已提交
86 87 88 89
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        src0.layout.total_nr_elems()));                             \
        }                                                                           \
        MIDOUT_END();                                                               \
90 91 92 93 94 95 96 97 98 99 100 101 102 103
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecVecVec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 2: (src2 is a scalar) && (src0 and src1 has the same shape)
M
Megvii Engine Team 已提交
104 105 106 107 108 109 110 111 112 113 114 115
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, const _type, _type*, DType, DType, \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
116 117 118 119
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr()),                 \
                        static_cast<const _type*>(src2.raw_ptr())[0],              \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
M
Megvii Engine Team 已提交
120 121 122 123
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecVecScalar::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 3: shape of src0 and src2 is {1, C, 1, 1}
    BroadcastChannelInfo binfo;
    is_broadcasted_channel_like(src0.layout, binfo);
M
Megvii Engine Team 已提交
140 141 142 143 144 145 146 147 148 149 150 151
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                    \
    case Mode::_mode:                                                                  \
        MIDOUT_BEGIN(                                                                  \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                  \
                midout_iv(Mode::_mode), _type_midout_id) {                             \
            thin_function<void(                                                        \
                    const _type*, const _type*, const _type*, _type*, DType, DType,    \
                    DType, DType, size_t, size_t, size_t)>                             \
                    run = OpCallerTernary<                                             \
                            _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                                  \
                    static_cast<naive::HandleImpl*>(kern_param.handle),                \
152 153 154 155
                    run(static_cast<const _type*>(src0.raw_ptr()),                     \
                        static_cast<const _type*>(src1.raw_ptr()),                     \
                        static_cast<const _type*>(src2.raw_ptr()),                     \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,         \
M
Megvii Engine Team 已提交
156 157 158 159
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,        \
                        binfo.x, binfo.y, binfo.z));                                   \
        }                                                                              \
        MIDOUT_END();                                                                  \
160 161 162 163 164 165 166 167
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
168

169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
void ElemwiseImpl::AlgoTernaryFma3Bcast111CVecBcast111C::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 3: shape of src0 and src2 is {1, 1, 1, C}
    BroadcastChannelInfo binfo;
    is_NHWC_broadcasted_channel_like(src0.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                   \
    case Mode::_mode:                                                                 \
        MIDOUT_BEGIN(                                                                 \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                 \
                midout_iv(Mode::_mode), _type_midout_id) {                            \
            thin_function<void(                                                       \
                    const _type*, const _type*, size_t, const _type*, _type*, DType,  \
                    DType, DType, DType, size_t, size_t, size_t)>                     \
                    run = OpCallerTernary<                                            \
                            _op<_type, _type>,                                        \
                            BcastType::BCAST111C_VEC_BCAST111C>::run;                 \
            MEGDNN_DISPATCH_CPU_KERN(                                                 \
                    static_cast<naive::HandleImpl*>(kern_param.handle),               \
190 191
                    run(static_cast<const _type*>(src0.raw_ptr()),                    \
                        static_cast<const _type*>(src1.raw_ptr()),                    \
192
                        is_vector(src1.layout) ? 0 : src1.layout.stride[0] - binfo.z, \
193 194
                        static_cast<const _type*>(src2.raw_ptr()),                    \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,        \
195 196 197 198 199 200 201 202 203 204 205 206 207
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,       \
                        binfo.x, binfo.y, binfo.z));                                  \
        }                                                                             \
        MIDOUT_END();                                                                 \
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3Bcast111CVecBcast111C::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}

208
void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
209 210 211 212 213
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    BroadcastChannelInfo binfo;
M
Megvii Engine Team 已提交
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
    megdnn_assert(
            is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
                    is_broadcastedx_channel_like<8>(src0.layout, binfo),
            "only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>,                                      \
                            BcastType::BCAST101xX_VEC_BCAST101xX>::run;             \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
231 232 233 234
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
M
Megvii Engine Team 已提交
235 236 237 238
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        batch_size, binfo.x, binfo.y, binfo.z));                    \
        }                                                                           \
        MIDOUT_END();                                                               \
239 240 241 242
        return

    size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
    auto&& dst = *(kern_param.m_dst);
243
    DISPATCH_TYPE("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash);
244 245 246 247 248
#undef DISPATCH_TERNARY

    return;
}

249
void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
250 251 252 253 254
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    BroadcastChannelInfo binfo;
M
Megvii Engine Team 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270
    megdnn_assert(
            is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
                    is_broadcastedx_channel_like<8>(src1.layout, binfo),
            "only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
271 272 273 274
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
M
Megvii Engine Team 已提交
275 276 277 278
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        batch_size, binfo.x, binfo.y, binfo.z));                    \
        }                                                                           \
        MIDOUT_END();                                                               \
279 280 281 282
        return

    size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
    auto&& dst = *(kern_param.m_dst);
283
    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101xXVec::exec"_hash);
284 285 286 287 288
#undef DISPATCH_TERNARY

    return;
}

289 290 291 292 293 294 295 296
void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 4: shape of src1 is {1, C, 1, 1}, and src0 and src2 are contig
    BroadcastChannelInfo binfo;
    is_broadcasted_channel_like(src1.layout, binfo);
M
Megvii Engine Team 已提交
297 298 299 300 301 302 303 304 305 306 307 308
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t)>                          \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run;   \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
309 310 311 312
                    run(static_cast<const _type*>(src0.raw_ptr()),                  \
                        static_cast<const _type*>(src1.raw_ptr()),                  \
                        static_cast<const _type*>(src2.raw_ptr()),                  \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,      \
M
Megvii Engine Team 已提交
313 314 315 316
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        binfo.x, binfo.y, binfo.z));                                \
        }                                                                           \
        MIDOUT_END();                                                               \
317 318 319 320 321 322 323 324
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101Vec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
325

326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345
void ElemwiseImpl::AlgoTernaryFma3VecBcast111CVec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 4: shape of src1 is {1, 1, 1, C}, and src0 and src2 are contig
    BroadcastChannelInfo binfo;
    is_NHWC_broadcasted_channel_like(src1.layout, binfo);
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                   \
    case Mode::_mode:                                                                 \
        MIDOUT_BEGIN(                                                                 \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                 \
                midout_iv(Mode::_mode), _type_midout_id) {                            \
            thin_function<void(                                                       \
                    const _type*, size_t, const _type*, const _type*, size_t, _type*, \
                    DType, DType, DType, DType, size_t, size_t, size_t)>              \
                    run = OpCallerTernary<                                            \
                            _op<_type, _type>, BcastType::VEC_BCAST111C_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                                 \
                    static_cast<naive::HandleImpl*>(kern_param.handle),               \
346
                    run(static_cast<const _type*>(src0.raw_ptr()),                    \
347
                        is_vector(src0.layout) ? 0 : src0.layout.stride[0] - binfo.z, \
348 349
                        static_cast<const _type*>(src1.raw_ptr()),                    \
                        static_cast<const _type*>(src2.raw_ptr()),                    \
350
                        is_vector(src2.layout) ? 0 : src2.layout.stride[0] - binfo.z, \
351
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,        \
352 353 354 355 356 357 358 359 360 361 362 363 364
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,       \
                        binfo.x, binfo.y, binfo.z));                                  \
        }                                                                             \
        MIDOUT_END();                                                                 \
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecBcast111CVec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}

365 366 367 368 369 370
void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 5: (src1 is a scalar) && (src0 and src2 has the same shape)
M
Megvii Engine Team 已提交
371 372 373 374 375 376 377 378 379 380 381 382
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type, const _type*, _type*, DType, DType, \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
383 384 385 386
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr())[0],              \
                        static_cast<const _type*>(src2.raw_ptr()),                 \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
M
Megvii Engine Team 已提交
387 388 389 390
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
391 392 393 394 395 396 397 398 399 400 401 402 403 404
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecScalarVec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 6: (src1 and src2 is scalar) && (src0 is vector)
M
Megvii Engine Team 已提交
405 406 407 408 409 410 411 412 413 414 415 416
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type, const _type, _type*, DType, DType,  \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
417 418 419 420
                    run(static_cast<const _type*>(src0.raw_ptr()),                 \
                        static_cast<const _type*>(src1.raw_ptr())[0],              \
                        static_cast<const _type*>(src2.raw_ptr())[0],              \
                        static_cast<_type*>(dst.raw_ptr()), src0.layout.dtype,     \
M
Megvii Engine Team 已提交
421 422 423 424
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
425 426 427 428 429 430 431 432 433 434 435 436
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecScalarScalar::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
// Drop the exec-time DISPATCH_MODE_* helpers now that every exec() in this
// file has been defined.
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen