/**
 * \file dnn/src/arm_common/elemwise/ternary/algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "src/arm_common/elemwise/ternary/algo.h"
#include "src/arm_common/elemwise_op.h"

#include "src/common/utils.h"
#include "src/naive/handle.h"

#include "midout.h"

// Declares the midout tag that every MIDOUT_BEGIN(...) region in this file
// passes as its first argument.
MIDOUT_DECL(megdnn_arm_common_elemwise_ternary)

// File-local convenience: names such as ElemwiseImpl, BcastType and DType are
// used unqualified throughout this translation unit.
using namespace megdnn;
using namespace arm_common;

// is_available()-time flavour of the mode dispatcher: makes the enclosing
// function return true as soon as the requested mode is FUSE_MUL_ADD3.
// The three macro arguments are accepted only so the signature matches the
// exec()-time flavour defined later in this file; they are not used here.
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id) \
    if (kern_param.mode == Mode::FUSE_MUL_ADD3) {          \
        return true;                                       \
    }
// Integer dtypes take exactly the same path as floating-point ones.
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT

// Generates ElemwiseImpl::AlgoTernaryFma3<_case>::is_available().
//   _case       : suffix of the algorithm class; it is also stringized and
//                 appended to the "AlgoTernaryFma3::is_available" tag handed
//                 to DISPATCH_TYPE.
//   _bcast_type : the BcastType value this algorithm accepts.
// The generated function returns false when the broadcast type differs, and
// also when DISPATCH_TYPE falls through without returning.
// NOTE(review): `src0` is not referenced directly in this macro; presumably
// DISPATCH_TYPE (defined outside this file) reads it - confirm before
// renaming or removing it.
#define DECL_AVAILABLE(_case, _bcast_type)                                \
    bool ElemwiseImpl::AlgoTernaryFma3##_case ::is_available(             \
            const KernParam& kern_param) const {                          \
        if (kern_param.broad_cast_type != _bcast_type) {                  \
            return false;                                                 \
        }                                                                 \
        auto& elparam = kern_param.ternary_elparam;                       \
        auto& src0 = elparam[0];                                          \
        DISPATCH_TYPE("AlgoTernaryFma3::is_available" #_case##_hash);     \
        return false;                                                     \
    }

// Instantiate is_available() for every ternary FUSE_MUL_ADD3 algorithm,
// pairing each algorithm class suffix with the broadcast type it accepts.
DECL_AVAILABLE(VecVecVec, BcastType::VEC_VEC_VEC);
DECL_AVAILABLE(VecVecScalar, BcastType::VEC_VEC_SCALAR);
DECL_AVAILABLE(Bcast101VecBcast101, BcastType::BCAST101_VEC_BCAST101);
DECL_AVAILABLE(Bcast101xXVecBcast101xX, BcastType::BCAST101xX_VEC_BCAST101xX);
DECL_AVAILABLE(VecBcast101Vec, BcastType::VEC_BCAST101_VEC);
DECL_AVAILABLE(VecBcast101xXVec, BcastType::VEC_BCAST101xX_VEC);
DECL_AVAILABLE(VecScalarVec, BcastType::VEC_SCALAR_VEC);
DECL_AVAILABLE(VecScalarScalar, BcastType::VEC_SCALAR_SCALAR);
// DECL_CB is not defined anywhere in this file; its #undef is kept as a
// harmless no-op in case an included header defines it.
#undef DECL_CB
// The generator macro is no longer needed once all instances are emitted.
#undef DECL_AVAILABLE
// Drop the is_available()-time dispatchers; the exec()-time versions are
// defined next under the same names.
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// exec()-time flavour of the mode dispatcher: switches on kern_param.mode.
// FUSE_MUL_ADD3 is forwarded to whichever DISPATCH_TERNARY is defined at the
// point of expansion (each exec() below defines its own and #undef's it
// afterwards); any other mode throws via megdnn_throw with the numeric mode
// value in the message.
#define DISPATCH_MODE_FLOAT(_case, _type, _type_midout_id)                             \
    switch (kern_param.mode) {                                                         \
        DISPATCH_TERNARY(FUSE_MUL_ADD3, _case, _type, _type_midout_id, FuseMulAdd3Op); \
        default:                                                                       \
            megdnn_throw(ssprintf(                                                     \
                    "No avaiable algo find for: %d",                                   \
                    static_cast<int>(kern_param.mode)));                               \
    }
// Integer dtypes take exactly the same path as floating-point ones.
#define DISPATCH_MODE_INT DISPATCH_MODE_FLOAT

void ElemwiseImpl::AlgoTernaryFma3VecVecVec::exec(const KernParam& kern_param) const {
64 65 66 67
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 1: shape of (src0, src2) and src1 are exactly match
M
Megvii Engine Team 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t)>                                          \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_VEC_VEC>::run;        \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr),                    \
                        static_cast<const _type*>(src1.raw_ptr),                    \
                        static_cast<const _type*>(src2.raw_ptr),                    \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,        \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        src0.layout.total_nr_elems()));                             \
        }                                                                           \
        MIDOUT_END();                                                               \
88 89 90 91 92 93 94 95 96 97 98 99 100 101
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecVecVec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
void ElemwiseImpl::AlgoTernaryFma3VecVecScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 2: (src2 is a scalar) && (src0 and src1 has the same shape)
M
Megvii Engine Team 已提交
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type*, const _type, _type*, DType, DType, \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_VEC_SCALAR>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr),                   \
                        static_cast<const _type*>(src1.raw_ptr),                   \
                        static_cast<const _type*>(src2.raw_ptr)[0],                \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,       \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecVecScalar::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
void ElemwiseImpl::AlgoTernaryFma3Bcast101VecBcast101::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 3: shape of src0 and src2 is {1, C, 1, 1}
    BroadcastChannelInfo binfo;
    is_broadcasted_channel_like(src0.layout, binfo);
M
Megvii Engine Team 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                    \
    case Mode::_mode:                                                                  \
        MIDOUT_BEGIN(                                                                  \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),                  \
                midout_iv(Mode::_mode), _type_midout_id) {                             \
            thin_function<void(                                                        \
                    const _type*, const _type*, const _type*, _type*, DType, DType,    \
                    DType, DType, size_t, size_t, size_t)>                             \
                    run = OpCallerTernary<                                             \
                            _op<_type, _type>, BcastType::BCAST101_VEC_BCAST101>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                                  \
                    static_cast<naive::HandleImpl*>(kern_param.handle),                \
                    run(static_cast<const _type*>(src0.raw_ptr),                       \
                        static_cast<const _type*>(src1.raw_ptr),                       \
                        static_cast<const _type*>(src2.raw_ptr),                       \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,           \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,        \
                        binfo.x, binfo.y, binfo.z));                                   \
        }                                                                              \
        MIDOUT_END();                                                                  \
158 159 160 161 162 163 164 165
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3Bcast101VecBcast101::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
166

167
void ElemwiseImpl::AlgoTernaryFma3Bcast101xXVecBcast101xX::exec(
168 169 170 171 172
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    BroadcastChannelInfo binfo;
M
Megvii Engine Team 已提交
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
    megdnn_assert(
            is_broadcastedx_channel_like<4>(src0.layout, binfo) ||
                    is_broadcastedx_channel_like<8>(src0.layout, binfo),
            "only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>,                                      \
                            BcastType::BCAST101xX_VEC_BCAST101xX>::run;             \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr),                    \
                        static_cast<const _type*>(src1.raw_ptr),                    \
                        static_cast<const _type*>(src2.raw_ptr),                    \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,        \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        batch_size, binfo.x, binfo.y, binfo.z));                    \
        }                                                                           \
        MIDOUT_END();                                                               \
198 199 200 201
        return

    size_t batch_size = src1.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
    auto&& dst = *(kern_param.m_dst);
202
    DISPATCH_TYPE("AlgoTernaryFma3Bcast101xXVecBcast101xX::exec"_hash);
203 204 205 206 207
#undef DISPATCH_TERNARY

    return;
}

208
void ElemwiseImpl::AlgoTernaryFma3VecBcast101xXVec::exec(
209 210 211 212 213
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    BroadcastChannelInfo binfo;
M
Megvii Engine Team 已提交
214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
    megdnn_assert(
            is_broadcastedx_channel_like<4>(src1.layout, binfo) ||
                    is_broadcastedx_channel_like<8>(src1.layout, binfo),
            "only nchw44 and nchw88 supported");
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t, size_t)>                  \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101xX_VEC>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr),                    \
                        static_cast<const _type*>(src1.raw_ptr),                    \
                        static_cast<const _type*>(src2.raw_ptr),                    \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,        \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        batch_size, binfo.x, binfo.y, binfo.z));                    \
        }                                                                           \
        MIDOUT_END();                                                               \
238 239 240 241
        return

    size_t batch_size = src0.layout.shape[0] / (binfo.x * binfo.y * binfo.z);
    auto&& dst = *(kern_param.m_dst);
242
    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101xXVec::exec"_hash);
243 244 245 246 247
#undef DISPATCH_TERNARY

    return;
}

248 249 250 251 252 253 254 255
void ElemwiseImpl::AlgoTernaryFma3VecBcast101Vec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 4: shape of src1 is {1, C, 1, 1}, and src0 and src2 are contig
    BroadcastChannelInfo binfo;
    is_broadcasted_channel_like(src1.layout, binfo);
M
Megvii Engine Team 已提交
256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                 \
    case Mode::_mode:                                                               \
        MIDOUT_BEGIN(                                                               \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),               \
                midout_iv(Mode::_mode), _type_midout_id) {                          \
            thin_function<void(                                                     \
                    const _type*, const _type*, const _type*, _type*, DType, DType, \
                    DType, DType, size_t, size_t, size_t)>                          \
                    run = OpCallerTernary<                                          \
                            _op<_type, _type>, BcastType::VEC_BCAST101_VEC>::run;   \
            MEGDNN_DISPATCH_CPU_KERN(                                               \
                    static_cast<naive::HandleImpl*>(kern_param.handle),             \
                    run(static_cast<const _type*>(src0.raw_ptr),                    \
                        static_cast<const _type*>(src1.raw_ptr),                    \
                        static_cast<const _type*>(src2.raw_ptr),                    \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,        \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,     \
                        binfo.x, binfo.y, binfo.z));                                \
        }                                                                           \
        MIDOUT_END();                                                               \
276 277 278 279 280 281 282 283
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecBcast101Vec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
284

285 286 287 288 289 290
void ElemwiseImpl::AlgoTernaryFma3VecScalarVec::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 5: (src1 is a scalar) && (src0 and src2 has the same shape)
M
Megvii Engine Team 已提交
291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type, const _type*, _type*, DType, DType, \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_SCALAR_VEC>::run;    \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr),                   \
                        static_cast<const _type*>(src1.raw_ptr)[0],                \
                        static_cast<const _type*>(src2.raw_ptr),                   \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,       \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
311 312 313 314 315 316 317 318 319 320 321 322 323 324
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecScalarVec::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
void ElemwiseImpl::AlgoTernaryFma3VecScalarScalar::exec(
        const KernParam& kern_param) const {
    auto& elparam = kern_param.ternary_elparam;
    auto &src0 = elparam[0], &src1 = elparam[1], &src2 = elparam[2];

    // Case 6: (src1 and src2 is scalar) && (src0 is vector)
M
Megvii Engine Team 已提交
325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
#define DISPATCH_TERNARY(_mode, _case, _type, _type_midout_id, _op)                \
    case Mode::_mode:                                                              \
        MIDOUT_BEGIN(                                                              \
                megdnn_arm_common_elemwise_ternary, midout_iv(_case),              \
                midout_iv(Mode::_mode), _type_midout_id) {                         \
            thin_function<void(                                                    \
                    const _type*, const _type, const _type, _type*, DType, DType,  \
                    DType, DType, size_t)>                                         \
                    run = OpCallerTernary<                                         \
                            _op<_type, _type>, BcastType::VEC_SCALAR_SCALAR>::run; \
            MEGDNN_DISPATCH_CPU_KERN(                                              \
                    static_cast<naive::HandleImpl*>(kern_param.handle),            \
                    run(static_cast<const _type*>(src0.raw_ptr),                   \
                        static_cast<const _type*>(src1.raw_ptr)[0],                \
                        static_cast<const _type*>(src2.raw_ptr)[0],                \
                        static_cast<_type*>(dst.raw_ptr), src0.layout.dtype,       \
                        src1.layout.dtype, src2.layout.dtype, dst.layout.dtype,    \
                        src0.layout.total_nr_elems()));                            \
        }                                                                          \
        MIDOUT_END();                                                              \
345 346 347 348 349 350 351 352 353 354 355 356
        return

    auto&& dst = *(kern_param.m_dst);
    DISPATCH_TYPE("AlgoTernaryFma3VecScalarScalar::exec"_hash);
#undef DISPATCH_TERNARY

    return;
}
// Drop the exec()-time mode dispatchers so they do not outlive the exec()
// definitions above.
#undef DISPATCH_MODE_FLOAT
#undef DISPATCH_MODE_INT

// vim: syntax=cpp.doxygen